Merge pull request #676 from SuanmoSuanyangTechnology/feature/multimodel_memory
feat(memory, model): update multi-modal memory write and model list API
This commit is contained in:
@@ -118,142 +118,142 @@ async def download_log(
|
||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
api_logger.warning(
|
||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(
|
||||
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.end_user_id,
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id,
|
||||
language
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server_async(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(
|
||||
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
# @router.post("/writer_service", response_model=ApiResponse)
|
||||
# @cur_workspace_access_guard()
|
||||
# async def write_server(
|
||||
# user_input: Write_UserInput,
|
||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user)
|
||||
# ):
|
||||
# """
|
||||
# Write service endpoint - processes write operations synchronously
|
||||
#
|
||||
# Args:
|
||||
# user_input: Write request containing message and end_user_id
|
||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
#
|
||||
# Returns:
|
||||
# Response with write operation status
|
||||
# """
|
||||
# # 使用集中化的语言校验
|
||||
# language = get_language_from_header(language_type)
|
||||
#
|
||||
# config_id = user_input.config_id
|
||||
# workspace_id = current_user.current_workspace_id
|
||||
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
#
|
||||
# # 获取 storage_type,如果为 None 则使用默认值
|
||||
# storage_type = workspace_service.get_workspace_storage_type(
|
||||
# db=db,
|
||||
# workspace_id=workspace_id,
|
||||
# user=current_user
|
||||
# )
|
||||
# if storage_type is None: storage_type = 'neo4j'
|
||||
# user_rag_memory_id = ''
|
||||
#
|
||||
# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
# if storage_type == 'rag':
|
||||
# if workspace_id:
|
||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
# db=db,
|
||||
# name="USER_RAG_MERORY",
|
||||
# workspace_id=workspace_id
|
||||
# )
|
||||
# if knowledge:
|
||||
# user_rag_memory_id = str(knowledge.id)
|
||||
# else:
|
||||
# api_logger.warning(
|
||||
# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
# storage_type = 'neo4j'
|
||||
# else:
|
||||
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
# storage_type = 'neo4j'
|
||||
#
|
||||
# api_logger.info(
|
||||
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
# try:
|
||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
# result = await memory_agent_service.write_memory(
|
||||
# user_input.end_user_id,
|
||||
# messages_list,
|
||||
# config_id,
|
||||
# db,
|
||||
# storage_type,
|
||||
# user_rag_memory_id,
|
||||
# language
|
||||
# )
|
||||
#
|
||||
# return success(data=result, msg="写入成功")
|
||||
# except BaseException as e:
|
||||
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
# if hasattr(e, 'exceptions'):
|
||||
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
# detailed_error = "; ".join(error_messages)
|
||||
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
#
|
||||
#
|
||||
# @router.post("/writer_service_async", response_model=ApiResponse)
|
||||
# @cur_workspace_access_guard()
|
||||
# async def write_server_async(
|
||||
# user_input: Write_UserInput,
|
||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user)
|
||||
# ):
|
||||
# """
|
||||
# Async write service endpoint - enqueues write processing to Celery
|
||||
#
|
||||
# Args:
|
||||
# user_input: Write request containing message and end_user_id
|
||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
#
|
||||
# Returns:
|
||||
# Task ID for tracking async operation
|
||||
# Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
# """
|
||||
# # 使用集中化的语言校验
|
||||
# language = get_language_from_header(language_type)
|
||||
#
|
||||
# config_id = user_input.config_id
|
||||
# workspace_id = current_user.current_workspace_id
|
||||
# api_logger.info(
|
||||
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
#
|
||||
# # 获取 storage_type,如果为 None 则使用默认值
|
||||
# storage_type = workspace_service.get_workspace_storage_type(
|
||||
# db=db,
|
||||
# workspace_id=workspace_id,
|
||||
# user=current_user
|
||||
# )
|
||||
# if storage_type is None: storage_type = 'neo4j'
|
||||
# user_rag_memory_id = ''
|
||||
# if workspace_id:
|
||||
#
|
||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
# db=db,
|
||||
# name="USER_RAG_MERORY",
|
||||
# workspace_id=workspace_id
|
||||
# )
|
||||
# if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
# try:
|
||||
# # 获取标准化的消息列表
|
||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
#
|
||||
# task = celery_app.send_task(
|
||||
# "app.core.memory.agent.write_message",
|
||||
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
# )
|
||||
# api_logger.info(f"Write task queued: {task.id}")
|
||||
#
|
||||
# return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
# except Exception as e:
|
||||
# api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/read_service", response_model=ApiResponse)
|
||||
|
||||
@@ -54,8 +54,8 @@ router = APIRouter(
|
||||
|
||||
@router.get("/info", response_model=ApiResponse)
|
||||
async def get_storage_info(
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Example wrapper endpoint - retrieves storage information
|
||||
@@ -75,17 +75,12 @@ async def get_storage_info(
|
||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -107,9 +102,11 @@ def create_config(
|
||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {err_str}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||
@@ -119,9 +116,11 @@ def create_config(
|
||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||
@@ -129,10 +128,10 @@ def create_config(
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: UUID|int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
config_id: UUID | int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除记忆配置(带终端用户保护)
|
||||
|
||||
@@ -145,7 +144,7 @@ def delete_config(
|
||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
@@ -203,9 +202,9 @@ def delete_config(
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
@@ -217,7 +216,8 @@ def update_config(
|
||||
# 校验至少有一个字段需要更新
|
||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
|
||||
"config_name, config_desc, scene_id 均为空")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
@@ -231,9 +231,9 @@ def update_config(
|
||||
|
||||
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
||||
def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
@@ -256,11 +256,11 @@ def update_config_extracted(
|
||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
@@ -278,10 +278,11 @@ def read_config_extracted(
|
||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -303,10 +304,10 @@ def read_all_config(
|
||||
|
||||
@router.post("/pilot_run", response_model=None)
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
@@ -333,9 +334,9 @@ async def pilot_run(
|
||||
|
||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||
async def get_kb_type_distribution(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await kb_type_distribution(end_user_id)
|
||||
@@ -347,9 +348,9 @@ async def get_kb_type_distribution(
|
||||
|
||||
@router.get("/search/dialogue", response_model=ApiResponse)
|
||||
async def search_dialogues_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_dialogue(end_user_id)
|
||||
@@ -361,9 +362,9 @@ async def search_dialogues_num(
|
||||
|
||||
@router.get("/search/chunk", response_model=ApiResponse)
|
||||
async def search_chunks_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_chunk(end_user_id)
|
||||
@@ -375,9 +376,9 @@ async def search_chunks_num(
|
||||
|
||||
@router.get("/search/statement", response_model=ApiResponse)
|
||||
async def search_statements_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_statement(end_user_id)
|
||||
@@ -389,9 +390,9 @@ async def search_statements_num(
|
||||
|
||||
@router.get("/search/entity", response_model=ApiResponse)
|
||||
async def search_entities_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_entity(end_user_id)
|
||||
@@ -403,9 +404,9 @@ async def search_entities_num(
|
||||
|
||||
@router.get("/search", response_model=ApiResponse)
|
||||
async def search_all_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_all(end_user_id)
|
||||
@@ -417,9 +418,9 @@ async def search_all_num(
|
||||
|
||||
@router.get("/search/detials", response_model=ApiResponse)
|
||||
async def search_entities_detials(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_detials(end_user_id)
|
||||
@@ -431,9 +432,9 @@ async def search_entities_detials(
|
||||
|
||||
@router.get("/search/edges", response_model=ApiResponse)
|
||||
async def search_entity_edges(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_edges(end_user_id)
|
||||
@@ -443,14 +444,12 @@ async def search_entity_edges(
|
||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_api(
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取热门记忆标签(带Redis缓存)
|
||||
|
||||
@@ -505,8 +504,8 @@ async def get_hot_memory_tags_api(
|
||||
|
||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||
async def clear_hot_memory_tags_cache(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
清除热门标签缓存
|
||||
|
||||
@@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache(
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||
@@ -553,4 +552,3 @@ async def get_recent_activity_stats_api(
|
||||
except Exception as e:
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ def get_model_strategies():
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
@@ -74,10 +75,21 @@ def get_model_list(
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
capability_list = []
|
||||
if capability is not None:
|
||||
flat_capability = []
|
||||
for item in capability:
|
||||
split_items = [c.strip() for c in item.split(', ') if c.strip()]
|
||||
flat_capability.extend(split_items)
|
||||
|
||||
unique_flat_capability = list(dict.fromkeys(flat_capability))
|
||||
capability_list = unique_flat_capability
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
capability=capability_list,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends,Header
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -19,7 +19,7 @@ from app.services.user_memory_service import (
|
||||
analytics_graph_data,
|
||||
analytics_community_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
@@ -45,9 +45,9 @@ router = APIRouter(
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的记忆洞察报告
|
||||
@@ -73,10 +73,10 @@ async def get_memory_insight_report_api(
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的用户摘要
|
||||
@@ -102,7 +102,7 @@ async def get_user_summary_api(
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
@@ -117,10 +117,10 @@ async def get_user_summary_api(
|
||||
|
||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||
async def generate_cache_api(
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
手动触发缓存生成
|
||||
@@ -155,10 +155,12 @@ async def generate_cache_api(
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
@@ -209,9 +211,9 @@ async def generate_cache_api(
|
||||
|
||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||
async def get_node_statistics_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -220,7 +222,8 @@ async def get_node_statistics_api(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
api_logger.info(
|
||||
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
|
||||
try:
|
||||
# 调用新的记忆类型统计函数
|
||||
@@ -228,21 +231,23 @@ async def get_node_statistics_api(
|
||||
|
||||
# 计算总数用于日志
|
||||
total_count = sum(item["count"] for item in result)
|
||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
api_logger.info(
|
||||
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||
async def get_graph_data_api(
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -298,9 +303,9 @@ async def get_graph_data_api(
|
||||
|
||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||
async def get_community_graph_data_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -334,9 +339,9 @@ async def get_community_graph_data_api(
|
||||
|
||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||
async def get_end_user_profile(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
@@ -385,9 +390,9 @@ async def get_end_user_profile(
|
||||
|
||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
||||
async def update_end_user_profile(
|
||||
profile_update: EndUserProfileUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
profile_update: EndUserProfileUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户的基本信息
|
||||
@@ -427,15 +432,18 @@ async def update_end_user_profile(
|
||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
async def memory_space_timeline_of_shared_memories(
|
||||
id: str, label: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
@@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
|
||||
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
async def memory_space_relationship_evolution(id: str, label: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||
|
||||
|
||||
@@ -598,8 +598,10 @@ class LangChainAgent:
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||
0) if response_meta else 0
|
||||
total_tokens = response_meta.get("token_usage", {}).get(
|
||||
"total_tokens",
|
||||
0
|
||||
) if response_meta else 0
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
|
||||
@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
ref_id: str = "",
|
||||
config_id: str = None
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
files = msg.get("file_content", [])
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
|
||||
@@ -6,6 +6,7 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -13,7 +14,8 @@ from dotenv import load_dotenv
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||
memory_summary_generation
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
@@ -23,18 +25,17 @@ from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
language: str = "zh",
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "",
|
||||
language: str = "zh",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
@@ -43,9 +44,11 @@ async def write(
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
ref_id: Reference ID, defaults to ""
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
if not ref_id:
|
||||
ref_id = uuid.uuid4().hex
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
@@ -135,9 +138,11 @@ async def write(
|
||||
all_chunk_nodes,
|
||||
all_statement_nodes,
|
||||
all_entity_nodes,
|
||||
all_perceptual_nodes,
|
||||
all_statement_chunk_edges,
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
all_perceptual_edges,
|
||||
all_dedup_details,
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
@@ -162,9 +167,11 @@ async def write(
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
perceptual_nodes=all_perceptual_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
perceptual_edges=all_perceptual_edges,
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
if success:
|
||||
@@ -173,7 +180,8 @@ async def write(
|
||||
await _trigger_clustering_sync(
|
||||
all_entity_nodes,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
embedding_model_id=str(
|
||||
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
)
|
||||
break
|
||||
else:
|
||||
@@ -208,9 +216,8 @@ async def write(
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||
)
|
||||
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Any, List
|
||||
import re
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Fix tokenizer parallelism warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -246,6 +246,7 @@ class ChunkerClient:
|
||||
"total_sub_chunks": len(sub_chunks),
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
files=msg.files
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
else:
|
||||
@@ -258,6 +259,7 @@ class ChunkerClient:
|
||||
"message_role": msg.role,
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
files=msg.files
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ class Edge(BaseModel):
|
||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
|
||||
|
||||
|
||||
class ChunkEdge(Edge):
|
||||
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
class PerceptualEdge(Edge):
|
||||
"""Edge connecting perceptual nodes to their source chunks
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
"""Base class for all graph nodes in the knowledge graph.
|
||||
|
||||
@@ -206,7 +212,8 @@ class DialogueNode(Node):
|
||||
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
||||
content: str = Field(..., description="Dialogue content")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this dialogue (integer or string)")
|
||||
|
||||
|
||||
class StatementNode(Node):
|
||||
@@ -281,7 +288,8 @@ class StatementNode(Node):
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
@@ -416,7 +424,8 @@ class ExtractedEntityNode(Node):
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
@@ -453,7 +462,7 @@ class ExtractedEntityNode(Node):
|
||||
|
||||
@field_validator('aliases', mode='before')
|
||||
@classmethod
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
"""Validate and clean aliases field using utility function.
|
||||
|
||||
This validator ensures that the aliases field is always a valid list of strings.
|
||||
@@ -507,7 +516,8 @@ class MemorySummaryNode(Node):
|
||||
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
# ACT-R Forgetting Engine Properties
|
||||
original_statement_id: Optional[str] = Field(
|
||||
@@ -549,3 +559,18 @@ class MemorySummaryNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||
)
|
||||
|
||||
|
||||
class PerceptualNode(Node):
|
||||
"""Node representing a multimodal message in the knowledge graph.
|
||||
"""
|
||||
perceptual_type: int
|
||||
file_path: str
|
||||
file_name: str
|
||||
file_ext: str
|
||||
summary: str
|
||||
keywords: list[str]
|
||||
topic: str
|
||||
domain: str
|
||||
file_type: str
|
||||
summary_embedding: list[float] | None
|
||||
|
||||
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
|
||||
"""
|
||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||
msg: str = Field(..., description="The text content of the message.")
|
||||
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||
|
||||
|
||||
class TemporalValidityRange(BaseModel):
|
||||
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
|
||||
content: str = Field(..., description="The content of the chunk as a string.")
|
||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
||||
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def dedup_layers_and_merge_and_return(
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client = None,
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client=None,
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
dict, # 新增:返回去重详情
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
执行两层实体去重与融合:
|
||||
|
||||
@@ -32,10 +32,11 @@ from app.core.memory.models.graph_models import (
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
StatementNode,
|
||||
PerceptualEdge,
|
||||
PerceptualNode
|
||||
)
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
)
|
||||
@@ -46,7 +47,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb
|
||||
embedding_generation,
|
||||
generate_entity_embeddings_from_triplets,
|
||||
)
|
||||
|
||||
# 导入各个提取模块
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
|
||||
StatementExtractor,
|
||||
@@ -90,16 +90,16 @@ class ExtractionOrchestrator:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: LLMClient,
|
||||
embedder_client: OpenAIEmbedderClient,
|
||||
connector: Neo4jConnector,
|
||||
config: Optional[ExtractionPipelineConfig] = None,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||
embedding_id: Optional[str] = None,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
enable_general_types: bool = True,
|
||||
language: str = "zh",
|
||||
self,
|
||||
llm_client: LLMClient,
|
||||
embedder_client: OpenAIEmbedderClient,
|
||||
connector: Neo4jConnector,
|
||||
config: Optional[ExtractionPipelineConfig] = None,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||
embedding_id: Optional[str] = None,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
enable_general_types: bool = True,
|
||||
language: str = "zh",
|
||||
):
|
||||
"""
|
||||
初始化流水线编排器
|
||||
@@ -157,19 +157,27 @@ class ExtractionOrchestrator:
|
||||
llm_client=llm_client,
|
||||
config=self.config.statement_extraction,
|
||||
)
|
||||
self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language)
|
||||
self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types,
|
||||
language=language)
|
||||
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
|
||||
|
||||
logger.info("ExtractionOrchestrator 初始化完成")
|
||||
|
||||
async def run(
|
||||
self,
|
||||
dialog_data_list: List[DialogData],
|
||||
is_pilot_run: bool = False,
|
||||
) -> Tuple[
|
||||
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]],
|
||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
||||
self,
|
||||
dialog_data_list: List[DialogData],
|
||||
is_pilot_run: bool = False,
|
||||
) -> tuple[
|
||||
list[DialogueNode],
|
||||
list[ChunkNode],
|
||||
list[StatementNode],
|
||||
list[ExtractedEntityNode],
|
||||
list[PerceptualNode],
|
||||
list[StatementChunkEdge],
|
||||
list[StatementEntityEdge],
|
||||
list[EntityEntityEdge],
|
||||
list[PerceptualEdge],
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
运行完整的知识提取流水线(优化版:并行执行)
|
||||
@@ -208,7 +216,6 @@ class ExtractionOrchestrator:
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
all_statements_list.extend(chunk.statements)
|
||||
len(all_statements_list)
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
||||
@@ -230,10 +237,6 @@ class ExtractionOrchestrator:
|
||||
all_entities_list.extend(triplet_info.entities)
|
||||
all_triplets_list.extend(triplet_info.triplets)
|
||||
|
||||
len(all_entities_list)
|
||||
len(all_triplets_list)
|
||||
sum(len(temporal_map) for temporal_map in temporal_maps)
|
||||
|
||||
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
||||
logger.info("步骤 3/6: 生成实体嵌入")
|
||||
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
||||
@@ -260,9 +263,11 @@ class ExtractionOrchestrator:
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
perceptual_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
perceptual_edges
|
||||
) = await self._create_nodes_and_edges(dialog_data_list)
|
||||
|
||||
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
|
||||
@@ -276,7 +281,16 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
||||
|
||||
result = await self._run_dedup_and_write_summary(
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
dialog_data_list,
|
||||
) = await self._run_dedup_and_write_summary(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
@@ -287,17 +301,26 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
|
||||
|
||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||
return result
|
||||
return (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
perceptual_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
perceptual_edges,
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def _extract_statements(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[DialogData]:
|
||||
"""
|
||||
从对话中提取陈述句(流式输出版本:边提取边发送进度)
|
||||
@@ -395,7 +418,7 @@ class ExtractionOrchestrator:
|
||||
return dialog_data_list
|
||||
|
||||
async def _extract_triplets(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取三元组(流式输出版本:边提取边发送进度)
|
||||
@@ -478,7 +501,7 @@ class ExtractionOrchestrator:
|
||||
return triplet_maps
|
||||
|
||||
async def _extract_temporal(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取时间信息(流式输出版本:边提取边发送进度)
|
||||
@@ -585,7 +608,7 @@ class ExtractionOrchestrator:
|
||||
return temporal_maps
|
||||
|
||||
async def _extract_emotions(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
|
||||
@@ -706,7 +729,7 @@ class ExtractionOrchestrator:
|
||||
return emotion_maps
|
||||
|
||||
async def _parallel_extract_and_embed(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> Tuple[
|
||||
List[Dict[str, Any]],
|
||||
List[Dict[str, Any]],
|
||||
@@ -777,7 +800,7 @@ class ExtractionOrchestrator:
|
||||
)
|
||||
|
||||
async def _generate_basic_embeddings(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
|
||||
"""
|
||||
生成基础嵌入向量(陈述句、分块、对话)
|
||||
@@ -836,7 +859,7 @@ class ExtractionOrchestrator:
|
||||
)
|
||||
|
||||
async def _generate_entity_embeddings(
|
||||
self, triplet_maps: List[Dict[str, Any]]
|
||||
self, triplet_maps: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
生成实体嵌入向量
|
||||
@@ -874,17 +897,15 @@ class ExtractionOrchestrator:
|
||||
logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
|
||||
return triplet_maps
|
||||
|
||||
|
||||
|
||||
async def _assign_extracted_data(
|
||||
self,
|
||||
dialog_data_list: List[DialogData],
|
||||
temporal_maps: List[Dict[str, Any]],
|
||||
triplet_maps: List[Dict[str, Any]],
|
||||
emotion_maps: List[Dict[str, Any]],
|
||||
statement_embedding_maps: List[Dict[str, List[float]]],
|
||||
chunk_embedding_maps: List[Dict[str, List[float]]],
|
||||
dialog_embeddings: List[List[float]],
|
||||
self,
|
||||
dialog_data_list: List[DialogData],
|
||||
temporal_maps: List[Dict[str, Any]],
|
||||
triplet_maps: List[Dict[str, Any]],
|
||||
emotion_maps: List[Dict[str, Any]],
|
||||
statement_embedding_maps: List[Dict[str, List[float]]],
|
||||
chunk_embedding_maps: List[Dict[str, List[float]]],
|
||||
dialog_embeddings: List[List[float]],
|
||||
) -> List[DialogData]:
|
||||
"""
|
||||
将提取的数据赋值到语句
|
||||
@@ -906,12 +927,12 @@ class ExtractionOrchestrator:
|
||||
# 确保列表长度匹配
|
||||
expected_length = len(dialog_data_list)
|
||||
if (
|
||||
len(temporal_maps) != expected_length
|
||||
or len(triplet_maps) != expected_length
|
||||
or len(emotion_maps) != expected_length
|
||||
or len(statement_embedding_maps) != expected_length
|
||||
or len(chunk_embedding_maps) != expected_length
|
||||
or len(dialog_embeddings) != expected_length
|
||||
len(temporal_maps) != expected_length
|
||||
or len(triplet_maps) != expected_length
|
||||
or len(emotion_maps) != expected_length
|
||||
or len(statement_embedding_maps) != expected_length
|
||||
or len(chunk_embedding_maps) != expected_length
|
||||
or len(dialog_embeddings) != expected_length
|
||||
):
|
||||
logger.warning(
|
||||
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
|
||||
@@ -999,15 +1020,17 @@ class ExtractionOrchestrator:
|
||||
return dialog_data_list
|
||||
|
||||
async def _create_nodes_and_edges(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
List[StatementNode],
|
||||
List[ExtractedEntityNode],
|
||||
List[PerceptualNode],
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
List[PerceptualEdge]
|
||||
]:
|
||||
"""
|
||||
创建图数据库节点和边
|
||||
@@ -1031,6 +1054,8 @@ class ExtractionOrchestrator:
|
||||
statement_chunk_edges = []
|
||||
statement_entity_edges = []
|
||||
entity_entity_edges = []
|
||||
perceptual_nodes = []
|
||||
perceptual_edges = []
|
||||
|
||||
# 用于去重的集合
|
||||
entity_id_set = set()
|
||||
@@ -1074,6 +1099,46 @@ class ExtractionOrchestrator:
|
||||
metadata=chunk.metadata,
|
||||
)
|
||||
chunk_nodes.append(chunk_node)
|
||||
logger.error(f"chunk file: {chunk.files}")
|
||||
|
||||
for p, file_type in chunk.files:
|
||||
|
||||
meta = p.meta_data or {}
|
||||
content_meta = meta.get("content", {})
|
||||
|
||||
# 生成 summary embedding(如果有 embedder_client)
|
||||
summary_embedding = None
|
||||
if self.embedder_client and p.summary:
|
||||
try:
|
||||
summary_embedding = (await self.embedder_client.response([p.summary]))[0]
|
||||
except Exception as emb_err:
|
||||
print(f"Failed to embed perceptual summary: {emb_err}")
|
||||
|
||||
perceptual = PerceptualNode(
|
||||
name=f"Perceptual_{p.id}",
|
||||
**{
|
||||
"id": str(p.id),
|
||||
"end_user_id": str(p.end_user_id),
|
||||
"perceptual_type": p.perceptual_type,
|
||||
"file_path": p.file_path or "",
|
||||
"file_name": p.file_name or "",
|
||||
"file_ext": p.file_ext or "",
|
||||
"summary": p.summary or "",
|
||||
"keywords": content_meta.get("keywords", []),
|
||||
"topic": content_meta.get("topic", ""),
|
||||
"domain": content_meta.get("domain", ""),
|
||||
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||
"file_type": file_type,
|
||||
"summary_embedding": summary_embedding,
|
||||
})
|
||||
perceptual_nodes.append(perceptual)
|
||||
perceptual_edges.append(PerceptualEdge(
|
||||
source=perceptual.id,
|
||||
target=chunk.id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id,
|
||||
created_at=dialog_data.created_at,
|
||||
))
|
||||
|
||||
# 处理每个陈述句
|
||||
for statement in chunk.statements:
|
||||
@@ -1083,15 +1148,19 @@ class ExtractionOrchestrator:
|
||||
name=f"Statement_{statement.id}", # 添加必需的 name 字段
|
||||
chunk_id=chunk.id,
|
||||
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
||||
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
||||
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL),
|
||||
# 添加必需的 temporal_info 字段
|
||||
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong',
|
||||
# 添加必需的 connect_strength 字段
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
statement=statement.statement,
|
||||
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||
statement_embedding=statement.statement_embedding,
|
||||
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
||||
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
||||
valid_at=statement.temporal_validity.valid_at if hasattr(statement,
|
||||
'temporal_validity') and statement.temporal_validity else None,
|
||||
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
|
||||
'temporal_validity') and statement.temporal_validity else None,
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||
@@ -1141,7 +1210,8 @@ class ExtractionOrchestrator:
|
||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
||||
# 添加必需的 connect_strength 字段
|
||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||
name_embedding=getattr(entity, 'name_embedding', None),
|
||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||
@@ -1248,25 +1318,32 @@ class ExtractionOrchestrator:
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
perceptual_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
perceptual_edges
|
||||
)
|
||||
|
||||
async def _run_dedup_and_write_summary(
|
||||
self,
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
) -> Tuple[
|
||||
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]],
|
||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
||||
self,
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
) -> tuple[
|
||||
list[DialogueNode],
|
||||
list[ChunkNode],
|
||||
list[StatementNode],
|
||||
list[ExtractedEntityNode],
|
||||
list[StatementChunkEdge],
|
||||
list[StatementEntityEdge],
|
||||
list[EntityEntityEdge],
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
执行两阶段去重并写入汇总
|
||||
@@ -1415,7 +1492,6 @@ class ExtractionOrchestrator:
|
||||
len(entity_entity_edges), len(final_entity_entity_edges)
|
||||
)
|
||||
|
||||
|
||||
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
||||
try:
|
||||
from app.core.config import settings
|
||||
@@ -1436,10 +1512,10 @@ class ExtractionOrchestrator:
|
||||
raise
|
||||
|
||||
def _save_dedup_details(
|
||||
self,
|
||||
dedup_details: Dict[str, Any],
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
self,
|
||||
dedup_details: Dict[str, Any],
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
):
|
||||
"""
|
||||
保存去重消歧的详细记录到实例变量(基于内存数据结构)
|
||||
@@ -1537,15 +1613,16 @@ class ExtractionOrchestrator:
|
||||
except Exception as e:
|
||||
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
|
||||
|
||||
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
||||
logger.info(
|
||||
f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
|
||||
|
||||
async def _analyze_entity_merges(
|
||||
self,
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
self,
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
|
||||
@@ -1585,9 +1662,9 @@ class ExtractionOrchestrator:
|
||||
return []
|
||||
|
||||
async def _analyze_entity_disambiguation(
|
||||
self,
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
self,
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
|
||||
@@ -1645,9 +1722,9 @@ class ExtractionOrchestrator:
|
||||
return type_mapping.get(entity_type, f"{entity_type}实体节点")
|
||||
|
||||
async def _output_relationship_creation_results(
|
||||
self,
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
self,
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
):
|
||||
"""
|
||||
输出关系创建结果
|
||||
@@ -1681,13 +1758,13 @@ class ExtractionOrchestrator:
|
||||
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
|
||||
|
||||
async def _send_dedup_progress_callback(
|
||||
self,
|
||||
original_entities: int,
|
||||
final_entities: int,
|
||||
original_stmt_edges: int,
|
||||
final_stmt_edges: int,
|
||||
original_ent_edges: int,
|
||||
final_ent_edges: int,
|
||||
self,
|
||||
original_entities: int,
|
||||
final_entities: int,
|
||||
original_stmt_edges: int,
|
||||
final_stmt_edges: int,
|
||||
original_ent_edges: int,
|
||||
final_ent_edges: int,
|
||||
):
|
||||
"""
|
||||
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
|
||||
@@ -1715,7 +1792,8 @@ class ExtractionOrchestrator:
|
||||
"original_count": original_entities,
|
||||
"final_count": final_entities,
|
||||
"reduced_count": entities_reduced,
|
||||
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0,
|
||||
"reduction_rate": round(entities_reduced / original_entities * 100,
|
||||
1) if original_entities > 0 else 0,
|
||||
},
|
||||
"statement_entity_edges": {
|
||||
"original_count": original_stmt_edges,
|
||||
@@ -1790,7 +1868,8 @@ class ExtractionOrchestrator:
|
||||
|
||||
disamb_examples.append({
|
||||
"entity1_name": entity_name,
|
||||
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知",
|
||||
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:",
|
||||
"").strip() if "vs" in disamb_type else "未知",
|
||||
"entity2_name": entity_name,
|
||||
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
|
||||
"description": f"{entity_name},消歧区分成功"
|
||||
@@ -1815,9 +1894,9 @@ class ExtractionOrchestrator:
|
||||
|
||||
|
||||
async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
end_user_id: str = "group_1",
|
||||
indices: Optional[List[int]] = None,
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
end_user_id: str = "group_1",
|
||||
indices: Optional[List[int]] = None,
|
||||
) -> List[DialogData]:
|
||||
"""从测试数据生成分块对话
|
||||
|
||||
@@ -1924,10 +2003,10 @@ async def get_chunked_dialogs(
|
||||
|
||||
|
||||
def preprocess_data(
|
||||
input_path: Optional[str] = None,
|
||||
output_path: Optional[str] = None,
|
||||
skip_cleaning: bool = True,
|
||||
indices: Optional[List[int]] = None
|
||||
input_path: Optional[str] = None,
|
||||
output_path: Optional[str] = None,
|
||||
skip_cleaning: bool = True,
|
||||
indices: Optional[List[int]] = None
|
||||
) -> List[DialogData]:
|
||||
"""数据预处理
|
||||
|
||||
@@ -1946,7 +2025,8 @@ def preprocess_data(
|
||||
)
|
||||
preprocessor = DataPreprocessor()
|
||||
try:
|
||||
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
|
||||
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path,
|
||||
skip_cleaning=skip_cleaning, indices=indices)
|
||||
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
||||
return cleaned_data
|
||||
except Exception as e:
|
||||
@@ -1955,9 +2035,9 @@ def preprocess_data(
|
||||
|
||||
|
||||
async def get_chunked_dialogs_from_preprocessed(
|
||||
data: List[DialogData],
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
llm_client: Optional[Any] = None,
|
||||
data: List[DialogData],
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
llm_client: Optional[Any] = None,
|
||||
) -> List[DialogData]:
|
||||
"""从预处理后的数据中生成分块
|
||||
|
||||
@@ -1988,15 +2068,15 @@ async def get_chunked_dialogs_from_preprocessed(
|
||||
|
||||
|
||||
async def get_chunked_dialogs_with_preprocessing(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
end_user_id: str = "default",
|
||||
user_id: str = "default",
|
||||
apply_id: str = "default",
|
||||
indices: Optional[List[int]] = None,
|
||||
input_data_path: Optional[str] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
skip_cleaning: bool = True,
|
||||
pruning_config: Optional[Dict] = None,
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
end_user_id: str = "default",
|
||||
user_id: str = "default",
|
||||
apply_id: str = "default",
|
||||
indices: Optional[List[int]] = None,
|
||||
input_data_path: Optional[str] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
skip_cleaning: bool = True,
|
||||
pruning_config: Optional[Dict] = None,
|
||||
) -> List[DialogData]:
|
||||
"""包含数据预处理步骤的完整分块流程
|
||||
|
||||
@@ -2046,7 +2126,8 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
if pruning_config:
|
||||
# 使用传入的配置
|
||||
config = PruningConfig(**pruning_config)
|
||||
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||
logger.debug(
|
||||
f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||
else:
|
||||
# 使用默认配置(关闭剪枝)
|
||||
config = None
|
||||
|
||||
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
|
||||
response_model=MemorySummaryResponse,
|
||||
)
|
||||
summary_text = structured.summary.strip()
|
||||
|
||||
# Generate title and type for the summary
|
||||
title = None
|
||||
episodic_type = None
|
||||
|
||||
@@ -390,6 +390,8 @@ class GraphBuilder:
|
||||
for edge in self.edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
if source not in self.reachable_nodes or target not in self.reachable_nodes:
|
||||
continue
|
||||
condition = edge.get("condition")
|
||||
edge_type = edge.get("type")
|
||||
|
||||
|
||||
@@ -12,14 +12,26 @@ class ExecutionContext(BaseModel):
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
memory_storage_type: str
|
||||
user_rag_memory_id: str
|
||||
checkpoint_config: RunnableConfig
|
||||
|
||||
@classmethod
|
||||
def create(cls, execution_id: str, workspace_id: str, user_id: str):
|
||||
def create(
|
||||
cls,
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
):
|
||||
return cls(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
|
||||
checkpoint_config=RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
|
||||
@@ -33,6 +33,8 @@ class WorkflowState(dict):
|
||||
"workspace_id",
|
||||
"user_id",
|
||||
"activate",
|
||||
"memory_storage_type",
|
||||
"user_rag_memory_id"
|
||||
})
|
||||
__optional_keys__ = frozenset({
|
||||
"error",
|
||||
@@ -62,6 +64,9 @@ class WorkflowState(dict):
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
memory_storage_type: str
|
||||
user_rag_memory_id: str
|
||||
|
||||
|
||||
class WorkflowStateManager:
|
||||
def create_initial_state(
|
||||
@@ -85,7 +90,9 @@ class WorkflowStateManager:
|
||||
looping=0,
|
||||
activate={
|
||||
start_node_id: True
|
||||
}
|
||||
},
|
||||
memory_storage_type=execution_context.memory_storage_type,
|
||||
user_rag_memory_id=execution_context.user_rag_memory_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -13,7 +13,7 @@ from pydantic import BaseModel
|
||||
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -373,6 +373,16 @@ class VariablePool:
|
||||
def copy(self, pool: 'VariablePool'):
|
||||
self.variables = deepcopy(pool.variables)
|
||||
|
||||
def is_file_variable(self, selector):
|
||||
variable_struct = self.get_instance(selector, default=None, strict=False)
|
||||
if variable_struct is None:
|
||||
return False
|
||||
if isinstance(variable_struct, FileVariable):
|
||||
return True
|
||||
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
||||
return True
|
||||
return False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""导出为字典
|
||||
|
||||
|
||||
@@ -409,7 +409,9 @@ async def execute_workflow(
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a workflow (convenience function, non-streaming).
|
||||
@@ -420,6 +422,8 @@ async def execute_workflow(
|
||||
execution_id (str): Execution ID.
|
||||
workspace_id (str): Workspace ID.
|
||||
user_id (str): User ID.
|
||||
user_rag_memory_id: rag knowledge db id
|
||||
memory_storage_type: neo4j / rag
|
||||
|
||||
Returns:
|
||||
dict: Workflow execution result.
|
||||
@@ -427,7 +431,9 @@ async def execute_workflow(
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
@@ -441,7 +447,9 @@ async def execute_workflow_stream(
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
):
|
||||
"""
|
||||
Execute a workflow in streaming mode (convenience function).
|
||||
@@ -452,6 +460,8 @@ async def execute_workflow_stream(
|
||||
execution_id (str): Execution ID.
|
||||
workspace_id (str): Workspace ID.
|
||||
user_id (str): User ID.
|
||||
user_rag_memory_id: rag knowledge db id
|
||||
memory_storage_type: neo4j / rag
|
||||
|
||||
Yields:
|
||||
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
||||
@@ -459,7 +469,9 @@ async def execute_workflow_stream(
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
|
||||
@@ -623,7 +623,6 @@ class BaseNode(ABC):
|
||||
async def process_message(
|
||||
api_config: ModelInfo,
|
||||
content: str | dict | FileObject,
|
||||
end_user_id: str,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
provider = api_config.provider
|
||||
@@ -642,8 +641,8 @@ class BaseNode(ABC):
|
||||
return content
|
||||
|
||||
elif isinstance(content, FileObject):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
|
||||
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||
file_obj = FileInput(
|
||||
@@ -655,12 +654,11 @@ class BaseNode(ABC):
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
end_user_id,
|
||||
[file_obj],
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
content.content_cache[provider] = message
|
||||
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
|
||||
return message
|
||||
return None
|
||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||
|
||||
@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
|
||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({
|
||||
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
|
||||
"content": await self.process_message(
|
||||
model_info,
|
||||
content,
|
||||
user_id,
|
||||
self.typed_config.vision,
|
||||
)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file.value, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
|
||||
message["content"] = await self.process_message(
|
||||
model_info,
|
||||
message["content"],
|
||||
user_id,
|
||||
self.typed_config.vision
|
||||
)
|
||||
history_message.append(message)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
@@ -5,7 +6,9 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
from app.db import get_db_read
|
||||
from app.schemas import FileInput
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
|
||||
@@ -36,8 +39,8 @@ class MemoryReadNode(BaseNode):
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
user_rag_memory_id=""
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
)
|
||||
|
||||
|
||||
@@ -49,6 +52,19 @@ class MemoryWriteNode(BaseNode):
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
@staticmethod
|
||||
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
|
||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
variable_pattern = re.compile(variable_pattern_string)
|
||||
variables = variable_pattern.findall(content)
|
||||
file_variables = []
|
||||
for variable in variables:
|
||||
if variable_pool.is_file_variable(variable):
|
||||
file_variables.append(variable)
|
||||
for var in file_variables:
|
||||
content = content.replace(var, "")
|
||||
return file_variables, content
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
@@ -63,17 +79,42 @@ class MemoryWriteNode(BaseNode):
|
||||
})
|
||||
|
||||
for message in self.typed_config.messages:
|
||||
file_variables, content = self._extract_multimodal_memory_variables(
|
||||
message.content,
|
||||
variable_pool
|
||||
)
|
||||
file_info = []
|
||||
for var in file_variables:
|
||||
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
|
||||
if isinstance(instence, FileVariable):
|
||||
file_info.append(FileInput(
|
||||
type=instence.value.type,
|
||||
transfer_method=instence.value.transfer_method,
|
||||
upload_file_id=instence.value.file_id,
|
||||
url=instence.value.url,
|
||||
file_type=instence.value.origin_file_type
|
||||
).model_dump())
|
||||
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
|
||||
for file_instence in instence.value:
|
||||
file_info.append(FileInput(
|
||||
type=file_instence.value.type,
|
||||
transfer_method=file_instence.value.transfer_method,
|
||||
upload_file_id=file_instence.value.file_id,
|
||||
url=file_instence.value.url,
|
||||
file_type=file_instence.value.origin_file_type
|
||||
).model_dump())
|
||||
messages.append({
|
||||
"role": message.role,
|
||||
"content": self._render_template(message.content, variable_pool)
|
||||
"content": self._render_template(content, variable_pool),
|
||||
"files": file_info
|
||||
})
|
||||
|
||||
write_message_task.delay(
|
||||
end_user_id,
|
||||
messages,
|
||||
str(self.typed_config.config_id),
|
||||
"neo4j",
|
||||
""
|
||||
end_user_id=end_user_id,
|
||||
message=messages,
|
||||
config_id=str(self.typed_config.config_id),
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
)
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -30,6 +30,9 @@ class MemoryConfig(Base):
|
||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||
vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
|
||||
audio_id = Column(String, nullable=True, comment="语音模型配置ID")
|
||||
video_id = Column(String, nullable=True, comment="视频模型配置ID")
|
||||
|
||||
# 记忆萃取引擎配置
|
||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||
|
||||
@@ -2,10 +2,11 @@ import datetime
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
|
||||
@@ -9,21 +9,22 @@ Classes:
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_config_logger, get_db_logger
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
)
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取数据库专用日志器
|
||||
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
|
||||
return memory_config_obj
|
||||
|
||||
@staticmethod
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
|
||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
@@ -309,57 +310,21 @@ class MemoryConfigRepository:
|
||||
|
||||
Returns:
|
||||
Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
|
||||
|
||||
try:
|
||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
|
||||
stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id)
|
||||
db_config = db.execute(stmt).scalar_one_or_none()
|
||||
if not db_config:
|
||||
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
# 更新字段映射
|
||||
field_mapping = {
|
||||
# 模型选择
|
||||
"llm_id": "llm_id",
|
||||
"embedding_id": "embedding_id",
|
||||
"rerank_id": "rerank_id",
|
||||
# 记忆萃取引擎
|
||||
"enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise",
|
||||
"enable_llm_disambiguation": "enable_llm_disambiguation",
|
||||
"deep_retrieval": "deep_retrieval",
|
||||
"t_type_strict": "t_type_strict",
|
||||
"t_name_strict": "t_name_strict",
|
||||
"t_overall": "t_overall",
|
||||
"state": "state",
|
||||
"chunker_strategy": "chunker_strategy",
|
||||
# 句子提取
|
||||
"statement_granularity": "statement_granularity",
|
||||
"include_dialogue_context": "include_dialogue_context",
|
||||
"max_context": "max_context",
|
||||
# 剪枝配置
|
||||
"pruning_enabled": "pruning_enabled",
|
||||
"pruning_scene": "pruning_scene",
|
||||
"pruning_threshold": "pruning_threshold",
|
||||
# 自我反思配置
|
||||
"enable_self_reflexion": "enable_self_reflexion",
|
||||
"iteration_period": "iteration_period",
|
||||
"reflexion_range": "reflexion_range",
|
||||
"baseline": "baseline",
|
||||
}
|
||||
update_data = update.model_dump(exclude_unset=True)
|
||||
update_data.pop("config_id", None)
|
||||
|
||||
has_update = False
|
||||
for api_field, db_field in field_mapping.items():
|
||||
value = getattr(update, api_field, None)
|
||||
if value is not None:
|
||||
setattr(db_config, db_field, value)
|
||||
has_update = True
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
for field, value in update_data.items():
|
||||
setattr(db_config, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
@@ -443,6 +408,9 @@ class MemoryConfigRepository:
|
||||
"llm_id": db_config.llm_id,
|
||||
"embedding_id": db_config.embedding_id,
|
||||
"rerank_id": db_config.rerank_id,
|
||||
"vision_id": db_config.vision_id,
|
||||
"audio_id": db_config.audio_id,
|
||||
"video_id": db_config.video_id,
|
||||
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
|
||||
"enable_llm_disambiguation": db_config.enable_llm_disambiguation,
|
||||
"deep_retrieval": db_config.deep_retrieval,
|
||||
@@ -527,7 +495,10 @@ class MemoryConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
|
||||
def get_config_with_workspace(
|
||||
db: Session,
|
||||
config_id: uuid.UUID | int | str
|
||||
) -> Optional[tuple[MemoryConfig, Workspace]]:
|
||||
"""Get memory config and its associated workspace information
|
||||
|
||||
Args:
|
||||
@@ -542,8 +513,6 @@ class MemoryConfigRepository:
|
||||
"""
|
||||
import time
|
||||
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
start_time = time.time()
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
@@ -630,7 +599,7 @@ class MemoryConfigRepository:
|
||||
|
||||
db_logger.debug(
|
||||
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||
return (config, workspace)
|
||||
return config, workspace
|
||||
|
||||
except ValueError:
|
||||
# Re-raise known business exceptions
|
||||
@@ -775,9 +744,9 @@ class MemoryConfigRepository:
|
||||
|
||||
@staticmethod
|
||||
def get_with_fallback(
|
||||
db: Session,
|
||||
config_id: Optional[uuid.UUID],
|
||||
workspace_id: uuid.UUID
|
||||
db: Session,
|
||||
config_id: Optional[uuid.UUID],
|
||||
workspace_id: uuid.UUID
|
||||
) -> Optional[MemoryConfig]:
|
||||
"""获取记忆配置,支持回退到工作空间默认配置
|
||||
|
||||
@@ -807,4 +776,3 @@ class MemoryConfigRepository:
|
||||
)
|
||||
|
||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
from sqlalchemy import and_, or_, func, desc, select
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import uuid
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
from sqlalchemy import and_, or_, func, desc
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery, ModelConfigQueryNew
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
@@ -137,6 +138,9 @@ class ModelConfigRepository:
|
||||
type_values.append(ModelType.LLM)
|
||||
filters.append(ModelConfig.type.in_(type_values))
|
||||
|
||||
if query.capability:
|
||||
filters.append(ModelConfig.capability.contains(query.capability))
|
||||
|
||||
if query.is_active is not None:
|
||||
filters.append(ModelConfig.is_active == query.is_active)
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
||||
MEMORY_SUMMARY_NODE_SAVE
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
@@ -12,9 +13,10 @@ logger = logging.getLogger(__name__)
|
||||
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
||||
"""Delete all nodes in the database."""
|
||||
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
|
||||
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||
logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||
return result
|
||||
|
||||
|
||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
"""Add dialogue nodes to Neo4j database.
|
||||
|
||||
@@ -26,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
||||
List of created node UUIDs or None if failed
|
||||
"""
|
||||
if not dialogues:
|
||||
print("No dialogues to save")
|
||||
logger.info("No dialogues to save")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -51,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
||||
)
|
||||
|
||||
created_uuids = [record["uuid"] for record in result]
|
||||
print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
|
||||
logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating dialogue nodes: {e}")
|
||||
logger.error(f"Error creating dialogue nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -70,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
List of created node UUIDs or None if failed
|
||||
"""
|
||||
if not statements:
|
||||
print("No statements to save")
|
||||
logger.info("No statements to save")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -123,13 +125,14 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
)
|
||||
|
||||
created_uuids = [record["uuid"] for record in result]
|
||||
print(f"Successfully created {len(created_uuids)} statement nodes")
|
||||
logger.info(f"Successfully created {len(created_uuids)} statement nodes")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating statement nodes: {e}")
|
||||
logger.error(f"Error creating statement nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
"""Add chunk nodes to Neo4j in batch.
|
||||
|
||||
@@ -141,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
List of created chunk UUIDs or None if failed
|
||||
"""
|
||||
if not chunks:
|
||||
print("No chunk nodes to add")
|
||||
logger.info("No chunk nodes to add")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -174,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
)
|
||||
|
||||
created_uuids = [record["uuid"] for record in result]
|
||||
print(f"Successfully created {len(created_uuids)} chunk nodes")
|
||||
logger.info(f"Successfully created {len(created_uuids)} chunk nodes")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating chunk nodes: {e}")
|
||||
logger.error(f"Error creating chunk nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
async def add_memory_summary_nodes(
|
||||
summaries: List[MemorySummaryNode],
|
||||
connector: Neo4jConnector
|
||||
) -> Optional[List[str]]:
|
||||
"""Add memory summary nodes to Neo4j in batch.
|
||||
|
||||
Args:
|
||||
@@ -194,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
List of created summary node ids or None if failed
|
||||
"""
|
||||
if not summaries:
|
||||
print("No memory summary nodes to add")
|
||||
logger.info("No memory summary nodes to add")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -225,5 +230,3 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id,
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
|
||||
# Entity Merge Query
|
||||
MERGE_ENTITIES = """
|
||||
MATCH (canonical:ExtractedEntity {id: $canonical_id})
|
||||
@@ -829,9 +828,8 @@ neo4j_query_all = """
|
||||
other as entity2
|
||||
"""
|
||||
|
||||
|
||||
'''针对当前节点下扩长的句子,实体和总结'''
|
||||
Memory_Timeline_ExtractedEntity="""
|
||||
Memory_Timeline_ExtractedEntity = """
|
||||
MATCH (n)-[r1]-(e)-[r2]-(ms)
|
||||
WHERE elementId(n) = $id
|
||||
AND (ms:ExtractedEntity OR ms:MemorySummary)
|
||||
@@ -869,7 +867,7 @@ RETURN
|
||||
|
||||
|
||||
"""
|
||||
Memory_Timeline_MemorySummary="""
|
||||
Memory_Timeline_MemorySummary = """
|
||||
MATCH (n)-[r1]-(e)-[r2]-(ms)
|
||||
WHERE elementId(n) =$id
|
||||
AND (ms:MemorySummary OR ms:ExtractedEntity)
|
||||
@@ -904,7 +902,7 @@ RETURN
|
||||
}
|
||||
) AS statement;
|
||||
"""
|
||||
Memory_Timeline_Statement="""
|
||||
Memory_Timeline_Statement = """
|
||||
MATCH (n)
|
||||
WHERE elementId(n) = $id
|
||||
|
||||
@@ -947,7 +945,7 @@ RETURN
|
||||
"""
|
||||
|
||||
'''针对当前节点,主要获取更加完整的句子节点'''
|
||||
Memory_Space_Emotion_Statement="""
|
||||
Memory_Space_Emotion_Statement = """
|
||||
MATCH (n)
|
||||
WHERE elementId(n) = $id
|
||||
RETURN
|
||||
@@ -957,7 +955,7 @@ RETURN
|
||||
n.statement AS statement;
|
||||
|
||||
"""
|
||||
Memory_Space_Emotion_MemorySummary="""
|
||||
Memory_Space_Emotion_MemorySummary = """
|
||||
MATCH (n)-[]-(e)
|
||||
WHERE elementId(n) = $id
|
||||
AND EXISTS {
|
||||
@@ -970,7 +968,7 @@ RETURN DISTINCT
|
||||
e.emotion_type AS emotion_type,
|
||||
e.statement AS statement;
|
||||
"""
|
||||
Memory_Space_Emotion_ExtractedEntity="""
|
||||
Memory_Space_Emotion_ExtractedEntity = """
|
||||
MATCH (n)-[]-(e)
|
||||
WHERE elementId(n) = $id
|
||||
AND EXISTS {
|
||||
@@ -985,18 +983,18 @@ RETURN DISTINCT
|
||||
|
||||
'''获取实体'''
|
||||
|
||||
Memory_Space_User="""
|
||||
Memory_Space_User = """
|
||||
MATCH (n)-[r]->(m)
|
||||
WHERE n.end_user_id = $end_user_id AND m.name="用户"
|
||||
return DISTINCT elementId(m) as id
|
||||
"""
|
||||
Memory_Space_Entity="""
|
||||
Memory_Space_Entity = """
|
||||
MATCH (n)-[]-(m)
|
||||
WHERE elementId(m) = $id AND m.entity_type = "Person"
|
||||
RETURN
|
||||
DISTINCT m.name as name,m.end_user_id as end_user_id
|
||||
"""
|
||||
Memory_Space_Associative="""
|
||||
Memory_Space_Associative = """
|
||||
MATCH (u)-[]-(x)-[]-(h)
|
||||
WHERE elementId(u) = $user_id
|
||||
AND elementId(h) = $id
|
||||
@@ -1005,61 +1003,69 @@ RETURN DISTINCT
|
||||
"""
|
||||
|
||||
Graph_Node_query = """
|
||||
MATCH (n:MemorySummary)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
0 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:MemorySummary)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
0 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Dialogue)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT 1
|
||||
MATCH (n:Dialogue)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT 1
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Statement)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:Statement)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:ExtractedEntity)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
2 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:ExtractedEntity)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
2 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Chunk)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
3 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:Chunk)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
3 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
"""
|
||||
UNION ALL
|
||||
MATCH (n:Perceptual)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
4 AS priority
|
||||
|
||||
"""
|
||||
|
||||
# ============================================================
|
||||
# Community 节点 & BELONGS_TO_COMMUNITY 边
|
||||
@@ -1363,3 +1369,36 @@ RETURN s.statement AS statement,
|
||||
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
# 感知记忆节点保存
|
||||
PERCEPTUAL_NODE_SAVE = """
|
||||
UNWIND $perceptuals AS p
|
||||
MERGE (n:Perceptual {id: p.id})
|
||||
SET n += {
|
||||
id: p.id,
|
||||
end_user_id: p.end_user_id,
|
||||
perceptual_type: p.perceptual_type,
|
||||
file_path: p.file_path,
|
||||
file_name: p.file_name,
|
||||
file_ext: p.file_ext,
|
||||
summary: p.summary,
|
||||
keywords: p.keywords,
|
||||
topic: p.topic,
|
||||
domain: p.domain,
|
||||
created_at: p.created_at,
|
||||
file_type: p.file_type,
|
||||
summary_embedding: p.summary_embedding
|
||||
}
|
||||
RETURN n.id AS uuid
|
||||
"""
|
||||
|
||||
# 感知记忆与对话的关联边
|
||||
PERCEPTUAL_CHUNK_EDGE_SAVE = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
|
||||
MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id})
|
||||
MERGE (c)-[r:HAS_PERCEPTUAL]->(p)
|
||||
ON CREATE SET r.end_user_id = edge.end_user_id,
|
||||
r.created_at = edge.created_at
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
@@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import (
|
||||
StatementNode,
|
||||
ExtractedEntityNode,
|
||||
EntityEntityEdge,
|
||||
PerceptualNode,
|
||||
PerceptualEdge,
|
||||
)
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def save_entities_and_relationships(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save entities and their relationships using graph models"""
|
||||
all_entities = [entity.model_dump() for entity in entity_nodes]
|
||||
@@ -73,8 +78,8 @@ async def save_entities_and_relationships(
|
||||
|
||||
|
||||
async def save_chunk_nodes(
|
||||
chunk_nodes: List[ChunkNode],
|
||||
connector: Neo4jConnector
|
||||
chunk_nodes: List[ChunkNode],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save chunk nodes using graph models"""
|
||||
if not chunk_nodes:
|
||||
@@ -89,8 +94,8 @@ async def save_chunk_nodes(
|
||||
|
||||
|
||||
async def save_statement_chunk_edges(
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
connector: Neo4jConnector
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save statement-chunk edges using graph models"""
|
||||
if not statement_chunk_edges:
|
||||
@@ -118,8 +123,8 @@ async def save_statement_chunk_edges(
|
||||
|
||||
|
||||
async def save_statement_entity_edges(
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save statement-entity edges using graph models"""
|
||||
if not statement_entity_edges:
|
||||
@@ -154,9 +159,11 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
perceptual_nodes: List[PerceptualNode],
|
||||
entity_edges: List[EntityEntityEdge],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
perceptual_edges: List[PerceptualEdge],
|
||||
connector: Neo4jConnector,
|
||||
) -> bool:
|
||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||
@@ -169,9 +176,11 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
chunk_nodes: List of ChunkNode objects to save
|
||||
statement_nodes: List of StatementNode objects to save
|
||||
entity_nodes: List of ExtractedEntityNode objects to save
|
||||
perceptual_nodes: List of PerceptualNode objects to save
|
||||
entity_edges: List of EntityEntityEdge objects to save
|
||||
statement_chunk_edges: List of StatementChunkEdge objects to save
|
||||
statement_entity_edges: List of StatementEntityEdge objects to save
|
||||
perceptual_edges: List of PerceptualEdge objects to save
|
||||
connector: Neo4j connector instance
|
||||
|
||||
Returns:
|
||||
@@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
||||
dialogue_uuids = [record["uuid"] async for record in result]
|
||||
results['dialogues'] = dialogue_uuids
|
||||
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||
logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||
|
||||
# 2. Save all chunk nodes in batch
|
||||
if chunk_nodes:
|
||||
@@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
results['chunks'] = chunk_uuids
|
||||
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
||||
|
||||
if perceptual_nodes:
|
||||
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE
|
||||
perceptual_data = [node.model_dump() for node in perceptual_nodes]
|
||||
result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data)
|
||||
perceptual_uuids = [record["uuid"] async for record in result]
|
||||
results["perceptuals"] = perceptual_uuids
|
||||
logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j")
|
||||
|
||||
# 3. Save all statement nodes in batch
|
||||
if statement_nodes:
|
||||
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
||||
@@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
results['statement_entity_edges'] = se_uuids
|
||||
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
||||
|
||||
if perceptual_edges:
|
||||
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE
|
||||
perceptual_edge_data = []
|
||||
for edge in perceptual_edges:
|
||||
print(edge.source, edge.target)
|
||||
perceptual_edge_data.append({
|
||||
"perceptual_id": edge.source,
|
||||
"chunk_id": edge.target,
|
||||
"end_user_id": edge.end_user_id,
|
||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||
})
|
||||
result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data)
|
||||
perceptual_edges_uuids = [record["uuid"] async for record in result]
|
||||
results['perceptual_chunk_edges'] = perceptual_edges_uuids
|
||||
logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j")
|
||||
|
||||
return results
|
||||
|
||||
try:
|
||||
@@ -304,9 +337,9 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
|
||||
|
||||
async def _trigger_clustering_sync(
|
||||
entity_nodes: List,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
entity_nodes: List,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
同步等待聚类完成,避免与其他 LLM 任务并发冲突。
|
||||
@@ -322,14 +355,15 @@ async def _trigger_clustering_sync(
|
||||
end_user_id = entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in entity_nodes]
|
||||
logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||
await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
|
||||
await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
|
||||
embedding_model_id=embedding_model_id)
|
||||
|
||||
|
||||
async def _trigger_clustering(
|
||||
new_entity_ids: List[str],
|
||||
end_user_id: str,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
new_entity_ids: List[str],
|
||||
end_user_id: str,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
聚类触发函数,自动判断全量初始化还是增量更新。
|
||||
|
||||
@@ -387,6 +387,12 @@ class MemoryConfig:
|
||||
|
||||
rerank_model_id: Optional[UUID] = None
|
||||
rerank_model_name: Optional[str] = None
|
||||
video_model_id: Optional[UUID] = None
|
||||
video_model_name: Optional[str] = None
|
||||
vision_model_id: Optional[UUID] = None
|
||||
vision_model_name: Optional[str] = None
|
||||
audio_model_id: Optional[UUID] = None
|
||||
audio_model_name: Optional[str] = None
|
||||
|
||||
llm_params: Dict[str, Any] = field(default_factory=dict)
|
||||
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -8,9 +8,6 @@ import uuid
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 从 json_schema.py 迁移的 Schema
|
||||
# ============================================================================
|
||||
@@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel):
|
||||
|
||||
class ConflictResultSchema(BaseModel):
|
||||
"""Schema for the conflict result data in the reflexion_data.json file."""
|
||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.")
|
||||
data: List[BaseDataSchema] = Field(...,
|
||||
description="The conflict memory data. Only contains conflicting records when conflict is True.")
|
||||
conflict: bool = Field(..., description="Whether the memory is in conflict.")
|
||||
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
|
||||
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
|
||||
quality_assessment: Optional[QualityAssessmentSchema] = Field(None,
|
||||
description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
|
||||
memory_verify: Optional[MemoryVerifySchema] = Field(None,
|
||||
description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _normalize_data(cls, v):
|
||||
@@ -105,12 +105,15 @@ class ChangeRecordSchema(BaseModel):
|
||||
description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
|
||||
)
|
||||
|
||||
|
||||
class ResolvedSchema(BaseModel):
|
||||
"""Schema for the resolved memory data in the reflexion_data"""
|
||||
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
|
||||
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
|
||||
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
|
||||
change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.")
|
||||
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None,
|
||||
description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
|
||||
change: Optional[List[ChangeRecordSchema]] = Field(None,
|
||||
description="List of detailed change records with IDs and field information.")
|
||||
|
||||
|
||||
class SingleReflexionResultSchema(BaseModel):
|
||||
@@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel):
|
||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
|
||||
type: str = Field("reflexion_result", description="The type identifier.")
|
||||
|
||||
|
||||
class ReflexionResultSchema(BaseModel):
|
||||
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
||||
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
|
||||
results: List[SingleReflexionResultSchema] = Field(...,
|
||||
description="List of individual conflict resolution results, grouped by conflict type.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _normalize_resolved(cls, v):
|
||||
@@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel):
|
||||
# Composite key identifying a config row
|
||||
class ConfigKey(BaseModel): # 配置参数键模型
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
user_id: str = Field("user_id", description="用户标识(字符串)")
|
||||
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
user_id: str | None = Field(default=None, description="用户标识(字符串)")
|
||||
apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)")
|
||||
|
||||
|
||||
# Allowed chunking strategies (extendable later)
|
||||
@@ -241,10 +246,12 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
||||
|
||||
|
||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
||||
|
||||
|
||||
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||
@@ -255,8 +262,11 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用
|
||||
|
||||
|
||||
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||
config_id:Union[uuid.UUID, int, str] = None
|
||||
config_id: Union[uuid.UUID, int, str] = None
|
||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||
audio_id: Optional[str] = Field(None, description="语音模型ID")
|
||||
vision_id: Optional[str] = Field(None, description="视觉模型ID")
|
||||
video_id: Optional[str] = Field(None, description="视频模型ID")
|
||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
enable_llm_dedup_blockwise: Optional[bool] = None
|
||||
@@ -322,14 +332,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
|
||||
|
||||
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
|
||||
# 遗忘引擎配置参数更新模型
|
||||
config_id:Union[uuid.UUID, int, str] = None
|
||||
config_id: Union[uuid.UUID, int, str] = None
|
||||
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
|
||||
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
|
||||
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
|
||||
|
||||
|
||||
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
|
||||
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
||||
custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行")
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
@@ -364,11 +374,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None)
|
||||
|
||||
|
||||
def fail(
|
||||
msg: str,
|
||||
error_code: str = "ERROR",
|
||||
data: Optional[Any] = None,
|
||||
time: Optional[int] = None,
|
||||
query_preview: Optional[str] = None,
|
||||
msg: str,
|
||||
error_code: str = "ERROR",
|
||||
data: Optional[Any] = None,
|
||||
time: Optional[int] = None,
|
||||
query_preview: Optional[str] = None,
|
||||
) -> ApiResponse:
|
||||
payload = data
|
||||
if query_preview is not None:
|
||||
@@ -387,6 +397,7 @@ def fail(
|
||||
time=time or _now_ms(),
|
||||
)
|
||||
|
||||
|
||||
class GenerateCacheRequest(BaseModel):
|
||||
"""缓存生成请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
@@ -432,7 +443,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
|
||||
"""遗忘引擎配置更新请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
||||
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
||||
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
||||
@@ -472,7 +483,8 @@ class ForgettingStatsResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
|
||||
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
|
||||
description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
|
||||
timestamp: int = Field(..., description="统计时间(时间戳)")
|
||||
|
||||
|
||||
@@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase):
|
||||
updated_at: datetime.datetime
|
||||
api_keys: List["ModelApiKey"] = []
|
||||
|
||||
@staticmethod
|
||||
def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str:
|
||||
if not key or len(key) <= prefix + suffix:
|
||||
return "*" * len(key)
|
||||
return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[-suffix:]
|
||||
|
||||
@field_validator("api_keys", mode="after")
|
||||
@classmethod
|
||||
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
|
||||
@@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase):
|
||||
def _serialize_created_at(self, dt: datetime.datetime | None):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("api_keys", when_used="json")
|
||||
def _serialize_api_keys(self, api_keys: List["ModelApiKey"]):
|
||||
result = []
|
||||
for api_key in api_keys:
|
||||
data = api_key.model_dump()
|
||||
data["api_key"] = self.mask_api_key(api_key.api_key)
|
||||
result.append(data)
|
||||
return result
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -166,8 +181,8 @@ class ModelApiKey(ModelApiKeyBase):
|
||||
self.model_config_ids = [
|
||||
mc.id for mc in self.model_configs
|
||||
if hasattr(mc, 'id')
|
||||
and not getattr(mc, 'is_composite', False)
|
||||
and getattr(mc, 'name', None) == self.model_name
|
||||
and not getattr(mc, 'is_composite', False)
|
||||
and getattr(mc, 'name', None) == self.model_name
|
||||
]
|
||||
# 情况2:字典列表
|
||||
elif isinstance(self.model_configs, list):
|
||||
@@ -193,7 +208,6 @@ class ModelApiKey(ModelApiKeyBase):
|
||||
validate_assignment=True # 确保赋值触发校验
|
||||
)
|
||||
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel):
|
||||
"""模型配置查询Schema"""
|
||||
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
||||
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
|
||||
capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)")
|
||||
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
||||
is_public: Optional[bool] = Field(None, description="公开状态筛选")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
@@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel):
|
||||
is_composite: Optional[bool] = Field(None, description="组合模型筛选")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
|
||||
class ModelMarketplace(BaseModel):
|
||||
"""模型广场响应Schema"""
|
||||
llm_models: List[ModelConfig] = []
|
||||
@@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel):
|
||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""模型信息Schema"""
|
||||
model_name: str = Field(..., description="模型名称")
|
||||
@@ -336,4 +353,3 @@ class ModelInfo(BaseModel):
|
||||
is_omni: bool = Field(default=False, description="是否为omni模型")
|
||||
model_type: ModelType = Field(..., description="模型类型")
|
||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ class AppChatService:
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 调用 Agent(支持多模态)
|
||||
@@ -343,7 +343,7 @@ class AppChatService:
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||
|
||||
@@ -603,7 +603,7 @@ class AgentRunService:
|
||||
# 获取 provider 信息
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -846,7 +846,7 @@ class AgentRunService:
|
||||
# 获取 provider 信息
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
|
||||
@@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.cache import InterestMemoryCache
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.messages_tools import (
|
||||
merge_multiple_search_results,
|
||||
reorder_output_results,
|
||||
)
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas import FileInput
|
||||
from app.schemas.memory_agent_schema import Write_UserInput
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
@@ -267,8 +270,16 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
|
||||
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
||||
async def write_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
messages: list[dict],
|
||||
config_id: Optional[uuid.UUID] | int,
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
language: str = "zh"
|
||||
) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
@@ -297,8 +308,8 @@ class MemoryAgentService:
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||
if config_id is None and workspace_id is None:
|
||||
raise ValueError(
|
||||
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
|
||||
f"Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
@@ -334,45 +345,57 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
perceptual_serivce = MemoryPerceptualService(db)
|
||||
for message in messages:
|
||||
message["file_content"] = []
|
||||
for file in message["files"]:
|
||||
file_object = await perceptual_serivce.generate_perceptual_memory(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
file=FileInput(**file)
|
||||
)
|
||||
if file_object is None:
|
||||
continue
|
||||
message["file_content"].append((file_object, file["type"]))
|
||||
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
try:
|
||||
if storage_type == "rag":
|
||||
# For RAG storage, convert messages to single string
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
return result
|
||||
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
return "success"
|
||||
else:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"end_user_id": end_user_id,
|
||||
"memory_config": memory_config,
|
||||
"language": language
|
||||
await write_neo4j(
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
memory_config=memory_config,
|
||||
ref_id='',
|
||||
language=language
|
||||
)
|
||||
for lang in ["zh", "en"]:
|
||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||
end_user_id, lang
|
||||
)
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||
for message in messages:
|
||||
message["file_content"] = [
|
||||
perceptual[0].file_path for perceptual in message["file_content"]
|
||||
]
|
||||
return self.writer_messages_deal(
|
||||
"success",
|
||||
start_time,
|
||||
end_user_id,
|
||||
config_id,
|
||||
message_text,
|
||||
{
|
||||
"status": "success",
|
||||
"data": messages,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name
|
||||
}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
||||
contents)
|
||||
)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
|
||||
@@ -38,9 +38,9 @@ class MemoryAPIService:
|
||||
self.db = db
|
||||
|
||||
def validate_end_user(
|
||||
self,
|
||||
end_user_id: str,
|
||||
workspace_id: uuid.UUID
|
||||
self,
|
||||
end_user_id: str,
|
||||
workspace_id: uuid.UUID
|
||||
) -> EndUser:
|
||||
"""Validate that end_user exists and belongs to the workspace.
|
||||
|
||||
@@ -125,13 +125,13 @@ class MemoryAPIService:
|
||||
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
||||
|
||||
async def write_memory(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
config_id: str,
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
config_id: str,
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Write memory with validation.
|
||||
|
||||
@@ -171,7 +171,7 @@ class MemoryAPIService:
|
||||
config_id=config_id,
|
||||
db=self.db,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id or ""
|
||||
user_rag_memory_id=user_rag_memory_id or "",
|
||||
)
|
||||
|
||||
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
||||
@@ -206,14 +206,14 @@ class MemoryAPIService:
|
||||
)
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
search_switch: str = "0",
|
||||
config_id: str = "",
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
search_switch: str = "0",
|
||||
config_id: str = "",
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Read memory with validation.
|
||||
|
||||
@@ -244,7 +244,6 @@ class MemoryAPIService:
|
||||
# Update end user's memory_config_id
|
||||
self._update_end_user_config(end_user_id, config_id)
|
||||
|
||||
|
||||
try:
|
||||
# Delegate to MemoryAgentService
|
||||
result = await MemoryAgentService().read_memory(
|
||||
@@ -282,8 +281,8 @@ class MemoryAPIService:
|
||||
)
|
||||
|
||||
def list_memory_configs(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
) -> Dict[str, Any]:
|
||||
"""List all memory configs for a workspace.
|
||||
|
||||
|
||||
@@ -154,10 +154,10 @@ class MemoryConfigService:
|
||||
self.db = db
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
service_name: str = "MemoryConfigService",
|
||||
self,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database with optional fallback.
|
||||
@@ -194,8 +194,8 @@ class MemoryConfigService:
|
||||
try:
|
||||
# Use get_config_with_fallback if workspace_id is provided
|
||||
memory_config = None
|
||||
validated_config_id = None
|
||||
if workspace_id:
|
||||
validated_config_id = None
|
||||
if config_id:
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id, self.db)
|
||||
@@ -243,10 +243,10 @@ class MemoryConfigService:
|
||||
|
||||
# Helper function to validate model with workspace fallback
|
||||
def _validate_model_with_fallback(
|
||||
model_id: str,
|
||||
model_type: str,
|
||||
workspace_default: str,
|
||||
required: bool = False
|
||||
model_id: str,
|
||||
model_type: str,
|
||||
workspace_default: str,
|
||||
required: bool = False
|
||||
) -> tuple:
|
||||
"""Validate model ID, falling back to workspace default if invalid.
|
||||
|
||||
@@ -343,6 +343,35 @@ class MemoryConfigService:
|
||||
if memory_config.rerank_id or workspace.rerank:
|
||||
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||
|
||||
vision_uuid, vision_name = validate_and_resolve_model_id(
|
||||
memory_config.vision_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
audio_uuid, audio_name = validate_and_resolve_model_id(
|
||||
memory_config.audio_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
video_uuid, video_name = validate_and_resolve_model_id(
|
||||
memory_config.video_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
# Create immutable MemoryConfig object
|
||||
config = MemoryConfig(
|
||||
config_id=memory_config.config_id,
|
||||
@@ -356,6 +385,12 @@ class MemoryConfigService:
|
||||
embedding_model_name=embedding_name,
|
||||
rerank_model_id=rerank_uuid,
|
||||
rerank_model_name=rerank_name,
|
||||
video_model_id=video_uuid,
|
||||
video_model_name=video_name,
|
||||
vision_model_id=vision_uuid,
|
||||
vision_model_name=vision_name,
|
||||
audio_model_id=audio_uuid,
|
||||
audio_model_name=audio_name,
|
||||
storage_type=workspace.storage_type or "neo4j",
|
||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||
@@ -364,24 +399,31 @@ class MemoryConfigService:
|
||||
reflexion_baseline=memory_config.baseline or "Time",
|
||||
loaded_at=datetime.now(),
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
enable_llm_dedup_blockwise=bool(
|
||||
memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(
|
||||
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
statement_granularity=int(
|
||||
memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(
|
||||
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(
|
||||
memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_enabled=bool(
|
||||
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
pruning_threshold=float(
|
||||
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
# Ontology scene association
|
||||
scene_id=memory_config.scene_id,
|
||||
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||
@@ -598,8 +640,8 @@ class MemoryConfigService:
|
||||
return None
|
||||
|
||||
def get_workspace_default_config(
|
||||
self,
|
||||
workspace_id: UUID
|
||||
self,
|
||||
workspace_id: UUID
|
||||
) -> Optional["MemoryConfigModel"]:
|
||||
"""Get workspace default memory config.
|
||||
|
||||
@@ -623,9 +665,9 @@ class MemoryConfigService:
|
||||
return config
|
||||
|
||||
def get_config_with_fallback(
|
||||
self,
|
||||
memory_config_id: Optional[UUID],
|
||||
workspace_id: UUID
|
||||
self,
|
||||
memory_config_id: Optional[UUID],
|
||||
workspace_id: UUID
|
||||
) -> Optional["MemoryConfigModel"]:
|
||||
"""Get memory config with fallback to workspace default.
|
||||
|
||||
@@ -663,9 +705,9 @@ class MemoryConfigService:
|
||||
return config
|
||||
|
||||
def delete_config(
|
||||
self,
|
||||
config_id: UUID | int,
|
||||
force: bool = False
|
||||
self,
|
||||
config_id: UUID | int,
|
||||
force: bool = False
|
||||
) -> dict:
|
||||
"""Delete memory config with protection against in-use configs.
|
||||
|
||||
@@ -800,9 +842,9 @@ class MemoryConfigService:
|
||||
# ==================== 记忆配置提取方法 ====================
|
||||
|
||||
def extract_memory_config_id(
|
||||
self,
|
||||
app_type: str,
|
||||
config: dict
|
||||
self,
|
||||
app_type: str,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
||||
|
||||
@@ -828,8 +870,8 @@ class MemoryConfigService:
|
||||
return None, False
|
||||
|
||||
def _extract_memory_config_id_from_agent(
|
||||
self,
|
||||
config: dict
|
||||
self,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从 Agent 应用配置中提取 memory_config_id
|
||||
|
||||
@@ -888,8 +930,8 @@ class MemoryConfigService:
|
||||
return None, False
|
||||
|
||||
def _extract_memory_config_id_from_workflow(
|
||||
self,
|
||||
config: dict
|
||||
self,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从 Workflow 应用配置中提取 memory_config_id
|
||||
|
||||
|
||||
@@ -341,7 +341,7 @@ async def memory_konwledges_up(
|
||||
)
|
||||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||||
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||
return db_document
|
||||
|
||||
|
||||
async def create_document_chunk(
|
||||
@@ -350,7 +350,7 @@ async def create_document_chunk(
|
||||
create_data: ChunkCreate,
|
||||
db: Session,
|
||||
current_user: User
|
||||
):
|
||||
) -> DocumentChunk:
|
||||
"""
|
||||
创建文档块
|
||||
|
||||
@@ -439,10 +439,10 @@ async def create_document_chunk(
|
||||
db_document.chunk_num += 1
|
||||
db.commit()
|
||||
|
||||
return success(data=chunk, msg="文档块创建成功")
|
||||
return chunk
|
||||
|
||||
|
||||
async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||
async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk:
|
||||
"""
|
||||
将消息写入 RAG 知识库
|
||||
|
||||
@@ -482,11 +482,11 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
||||
print('======', document)
|
||||
api_logger.info(f"查找文档结果: document_id={document}")
|
||||
create_chunks = ChunkCreate(content=message)
|
||||
if document is not None:
|
||||
# 文档已存在,直接添加新块
|
||||
api_logger.info(f"文档已存在,添加新块: document_id={document}")
|
||||
|
||||
create_chunks = ChunkCreate(content=message)
|
||||
result = await create_document_chunk(
|
||||
kb_id=kb_uuid,
|
||||
document_id=uuid.UUID(document),
|
||||
@@ -498,13 +498,20 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||
else:
|
||||
# 文档不存在,创建新文档
|
||||
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
|
||||
result = await memory_konwledges_up(
|
||||
document = await memory_konwledges_up(
|
||||
kb_id=user_rag_memory_id,
|
||||
parent_id=user_rag_memory_id,
|
||||
create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
result = await create_document_chunk(
|
||||
kb_id=kb_uuid,
|
||||
document_id=document.id,
|
||||
create_data=create_chunks,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
# 重新查询刚创建的文档ID
|
||||
new_document_id = find_document_id_by_kb_and_filename(
|
||||
db=db,
|
||||
|
||||
@@ -12,11 +12,12 @@ from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models import FileMetadata
|
||||
from app.models import FileMetadata, ModelApiKey, ModelType
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||
from app.schemas import FileType
|
||||
from app.schemas import FileType, FileInput
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualTimelineResponse,
|
||||
@@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import (
|
||||
AudioModal, Content, VideoModal, TextModal
|
||||
)
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
@@ -195,21 +198,58 @@ class MemoryPerceptualService:
|
||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||
|
||||
def _get_mutlimodal_client(
|
||||
self,
|
||||
file_type: FileType,
|
||||
config: MemoryConfig
|
||||
) -> tuple[RedBearLLM | None, ModelApiKey | None]:
|
||||
model_config = None
|
||||
if file_type == FileType.AUDIO:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.audio_model_id
|
||||
)
|
||||
elif file_type == FileType.VIDEO:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.video_model_id
|
||||
)
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.llm_model_id
|
||||
)
|
||||
elif file_type == FileType.IMAGE:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.vision_model_id
|
||||
)
|
||||
llm = None
|
||||
if model_config:
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
api_key=model_config.api_key,
|
||||
base_url=model_config.api_base,
|
||||
is_omni=model_config.is_omni
|
||||
)
|
||||
)
|
||||
return llm, model_config
|
||||
|
||||
async def generate_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_config: ModelInfo,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict,
|
||||
memory_config: MemoryConfig,
|
||||
file: FileInput
|
||||
):
|
||||
memories = self.repository.get_by_url(file_url)
|
||||
memories = self.repository.get_by_url(file.url)
|
||||
if memories:
|
||||
business_logger.info(f"Perceptual memory already exists: {file_url}")
|
||||
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||
memory_cache = memories[0]
|
||||
self.repository.create_perceptual_memory(
|
||||
memory = self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||
file_path=memory_cache.file_path,
|
||||
@@ -219,20 +259,33 @@ class MemoryPerceptualService:
|
||||
meta_data=memory_cache.meta_data
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
return
|
||||
llm = RedBearLLM(RedBearModelConfig(
|
||||
return memory
|
||||
else:
|
||||
for memory in memories:
|
||||
if memory.end_user_id == uuid.UUID(end_user_id):
|
||||
return memory
|
||||
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
api_key=model_config.api_key,
|
||||
base_url=model_config.api_base,
|
||||
is_omni=model_config.is_omni
|
||||
), type=model_config.model_type)
|
||||
api_base=model_config.api_base,
|
||||
is_omni=model_config.is_omni,
|
||||
capability=model_config.capability,
|
||||
model_type=ModelType.LLM
|
||||
))
|
||||
file_message = await multimodel_service.process_files(
|
||||
files=[file]
|
||||
)
|
||||
if not file_message:
|
||||
logger.warning(f"Unsupport file type {file}, model capability: {model_config.capability}")
|
||||
return None
|
||||
file_message = file_message[0]
|
||||
try:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
messages = [
|
||||
@@ -242,8 +295,22 @@ class MemoryPerceptualService:
|
||||
]}
|
||||
]
|
||||
result = await llm.ainvoke(messages)
|
||||
content = json_repair.repair_json(result.content, return_objects=True)
|
||||
path = urlparse(file_url).path
|
||||
content = result.content
|
||||
final_output = ""
|
||||
if isinstance(content, list):
|
||||
for msg in content:
|
||||
if isinstance(msg, dict):
|
||||
final_output += msg.get("text", "")
|
||||
elif isinstance(msg, str):
|
||||
final_output += msg
|
||||
elif isinstance(content, dict):
|
||||
final_output += content.get("text", "")
|
||||
elif isinstance(content, str):
|
||||
final_output = content
|
||||
else:
|
||||
raise ValueError(f"Unexcept Model Output Type: {result.content}")
|
||||
content = json_repair.repair_json(final_output, return_objects=True)
|
||||
path = urlparse(file.url).path
|
||||
filename = os.path.basename(path)
|
||||
filename = unquote(filename)
|
||||
file_ext = os.path.splitext(filename)[1]
|
||||
@@ -252,21 +319,21 @@ class MemoryPerceptualService:
|
||||
stmt = select(FileMetadata).where(
|
||||
FileMetadata.id == file_id
|
||||
)
|
||||
file = self.db.execute(stmt).scalar_one_or_none()
|
||||
file_obj = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if file:
|
||||
filename = file.file_name
|
||||
file_ext = file.file_ext
|
||||
if file_obj:
|
||||
filename = file_obj.file_name
|
||||
file_ext = file_obj.file_ext
|
||||
except ValueError:
|
||||
business_logger.debug(f"Remote file, file_id={filename}")
|
||||
if not file_ext:
|
||||
if file_type == FileType.AUDIO:
|
||||
if file.type == FileType.AUDIO:
|
||||
file_ext = ".mp3"
|
||||
elif file_type == FileType.VIDEO:
|
||||
elif file.type == FileType.VIDEO:
|
||||
file_ext = ".mp4"
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
file_ext = ".txt"
|
||||
elif file_type == FileType.IMAGE:
|
||||
elif file.type == FileType.IMAGE:
|
||||
file_ext = ".jpg"
|
||||
filename += file_ext
|
||||
file_content = {
|
||||
@@ -274,11 +341,11 @@ class MemoryPerceptualService:
|
||||
"topic": content.get("topic"),
|
||||
"domain": content.get("domain")
|
||||
}
|
||||
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
||||
if file.type in [FileType.IMAGE, FileType.VIDEO]:
|
||||
file_modalities = {
|
||||
"scene": content.get("scene", [])
|
||||
}
|
||||
elif file_type in [FileType.DOCUMENT]:
|
||||
elif file.type in [FileType.DOCUMENT]:
|
||||
file_modalities = {
|
||||
"section_count": content.get("section_count", 0),
|
||||
"title": content.get("title", ""),
|
||||
@@ -288,10 +355,10 @@ class MemoryPerceptualService:
|
||||
file_modalities = {
|
||||
"speaker_count": content.get("speaker_count", 0)
|
||||
}
|
||||
self.repository.create_perceptual_memory(
|
||||
memory = self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType.trans_from_file_type(file_type),
|
||||
file_path=file_url,
|
||||
perceptual_type=PerceptualType.trans_from_file_type(file.type),
|
||||
file_path=file.url,
|
||||
file_name=filename,
|
||||
file_ext=file_ext,
|
||||
summary=content.get('summary', ""),
|
||||
@@ -301,3 +368,4 @@ class MemoryPerceptualService:
|
||||
}
|
||||
)
|
||||
self.db.commit()
|
||||
return memory
|
||||
|
||||
@@ -11,9 +11,11 @@ import time
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.analytics.hot_memory_tags import (
|
||||
get_hot_memory_tags,
|
||||
get_raw_tags_from_db,
|
||||
filter_tags_with_llm,
|
||||
)
|
||||
@@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
@@ -69,7 +69,7 @@ class MemoryStorageService:
|
||||
return result
|
||||
|
||||
|
||||
class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"""Service layer for config params CRUD.
|
||||
|
||||
使用 SQLAlchemy ORM 进行数据库操作。
|
||||
@@ -114,7 +114,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
return data_list
|
||||
|
||||
# --- Create ---
|
||||
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
||||
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
||||
# 业务层检查同一工作空间下是否已存在同名配置
|
||||
if params.workspace_id and params.config_name:
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
@@ -183,20 +183,20 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
return None
|
||||
|
||||
# --- Delete ---
|
||||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
|
||||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
|
||||
success = MemoryConfigRepository.delete(self.db, key.config_id)
|
||||
if not success:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": 1}
|
||||
|
||||
# --- Update ---
|
||||
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
|
||||
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
|
||||
config = MemoryConfigRepository.update(self.db, update)
|
||||
if not config:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": 1}
|
||||
|
||||
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
|
||||
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
|
||||
config = MemoryConfigRepository.update_extracted(self.db, update)
|
||||
if not config:
|
||||
raise ValueError("未找到配置")
|
||||
@@ -207,14 +207,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
# 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config()
|
||||
|
||||
# --- Read ---
|
||||
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
|
||||
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
|
||||
result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id)
|
||||
if not result:
|
||||
raise ValueError("未找到配置")
|
||||
return result
|
||||
|
||||
# --- Read All ---
|
||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||
def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||||
|
||||
# 检查并修正 pruning_scene 与 scene_name 不一致的记录
|
||||
@@ -241,11 +241,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
except (ValueError, TypeError):
|
||||
config_id_old = None
|
||||
|
||||
|
||||
if config_id_old:
|
||||
memory_config=config_id_old
|
||||
memory_config = config_id_old
|
||||
else:
|
||||
memory_config=config.config_id
|
||||
memory_config = config.config_id
|
||||
config_dict = {
|
||||
"config_id": memory_config,
|
||||
"config_name": config.config_name,
|
||||
@@ -289,7 +288,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
||||
return self._convert_timestamps_to_format(data_list)
|
||||
|
||||
|
||||
async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
流式执行试运行,产生 SSE 格式的进度事件
|
||||
@@ -344,11 +342,13 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
# 关联了本体场景,优先使用 custom_text
|
||||
if hasattr(payload, 'custom_text') and payload.custom_text:
|
||||
dialogue_text = payload.custom_text.strip()
|
||||
logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
|
||||
logger.info(
|
||||
f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
|
||||
else:
|
||||
# 如果没有提供 custom_text,回退到 dialogue_text
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
|
||||
logger.info(
|
||||
f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
|
||||
else:
|
||||
# 没有关联本体场景,使用 dialogue_text
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
@@ -360,7 +360,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
|
||||
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
|
||||
|
||||
|
||||
# 步骤 2: 创建进度回调函数捕获管线进度
|
||||
# 使用队列在回调和生成器之间传递进度事件
|
||||
progress_queue: asyncio.Queue = asyncio.Queue()
|
||||
@@ -382,7 +381,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
try:
|
||||
from app.services.pilot_run_service import run_pilot_extraction
|
||||
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
|
||||
logger.info(
|
||||
f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
|
||||
await run_pilot_extraction(
|
||||
memory_config=memory_config,
|
||||
dialogue_text=dialogue_text,
|
||||
@@ -483,11 +483,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
|
||||
async def _compute_ontology_coverage(
|
||||
self,
|
||||
extracted_result: Dict[str, Any],
|
||||
memory_config,
|
||||
self,
|
||||
extracted_result: Dict[str, Any],
|
||||
memory_config,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
|
||||
|
||||
@@ -580,8 +579,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
|
||||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||||
# Ensure env for connector (e.g., NEO4J_PASSWORD)
|
||||
load_dotenv()
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
|
||||
|
||||
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
@@ -701,10 +698,11 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def analytics_hot_memory_tags(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
limit: int = 10
|
||||
db: Session,
|
||||
current_user: User,
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取热门记忆标签,按数量排序并返回前N个
|
||||
@@ -815,11 +813,11 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
|
||||
source = "log"
|
||||
|
||||
total = (
|
||||
stats.get("chunk_count", 0)
|
||||
+ stats.get("statements_count", 0)
|
||||
+ stats.get("triplet_entities_count", 0)
|
||||
+ stats.get("triplet_relations_count", 0)
|
||||
+ stats.get("temporal_count", 0)
|
||||
stats.get("chunk_count", 0)
|
||||
+ stats.get("statements_count", 0)
|
||||
+ stats.get("triplet_entities_count", 0)
|
||||
+ stats.get("triplet_relations_count", 0)
|
||||
+ stats.get("temporal_count", 0)
|
||||
)
|
||||
|
||||
# 计算"最新一次活动多久前"(仅日志来源时有效)
|
||||
@@ -845,5 +843,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
|
||||
|
||||
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -9,17 +9,15 @@
|
||||
- OpenAI: 支持 URL 和 base64 格式
|
||||
"""
|
||||
import base64
|
||||
import csv
|
||||
import io
|
||||
import uuid
|
||||
import json
|
||||
import zipfile
|
||||
import chardet
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import csv
|
||||
import json
|
||||
|
||||
import PyPDF2
|
||||
import chardet
|
||||
import httpx
|
||||
import magic
|
||||
import openpyxl
|
||||
@@ -35,7 +33,6 @@ from app.models.file_metadata_model import FileMetadata
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
from app.tasks import write_perceptual_memory
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -342,15 +339,12 @@ class MultimodalService:
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
end_user_id: uuid.UUID | str,
|
||||
files: Optional[List[FileInput]],
|
||||
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
files: 文件输入列表
|
||||
|
||||
Returns:
|
||||
@@ -358,8 +352,6 @@ class MultimodalService:
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
if isinstance(end_user_id, uuid.UUID):
|
||||
end_user_id = str(end_user_id)
|
||||
|
||||
# 获取对应的策略
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||
@@ -380,23 +372,15 @@ class MultimodalService:
|
||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||
is_support, content = await self._process_image(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
is_support, content = await self._process_document(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||
is_support, content = await self._process_audio(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||
is_support, content = await self._process_video(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
except Exception as e:
|
||||
@@ -483,17 +467,6 @@ class MultimodalService:
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""写入感知记忆"""
|
||||
if end_user_id and self.api_config:
|
||||
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
||||
|
||||
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||
"""
|
||||
处理图片文件
|
||||
|
||||
@@ -297,9 +297,12 @@ async def run_pilot_extraction(
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
_,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
_,
|
||||
_
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
@@ -1887,7 +1887,8 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
||||
"Chunk": ["content", "created_at"],
|
||||
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
|
||||
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
|
||||
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
|
||||
"MemorySummary": ["summary", "content", "created_at", "caption"], # 添加 content 字段
|
||||
"Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"]
|
||||
}
|
||||
|
||||
# 获取该节点类型的白名单字段
|
||||
|
||||
@@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject
|
||||
from app.db import get_db
|
||||
from app.models import App
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.workflow_repository import (
|
||||
WorkflowConfigRepository,
|
||||
WorkflowExecutionRepository,
|
||||
@@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.workspace_service import get_workspace_storage_type_without_auth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -540,6 +542,25 @@ class WorkflowService:
|
||||
mapped = internal_event
|
||||
return mapped
|
||||
|
||||
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
|
||||
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
|
||||
user_rag_memory_id = ""
|
||||
if storage_type == "rag":
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=self.db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No knowledge base named 'USER_RAG_MEMORY' found, "
|
||||
f"workspace_id: {workspace_id}, will use neo4j storage"
|
||||
)
|
||||
storage_type = 'neo4j'
|
||||
return storage_type, user_rag_memory_id
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
async def run(
|
||||
@@ -607,6 +628,7 @@ class WorkflowService:
|
||||
|
||||
try:
|
||||
files = await self._handle_file_input(payload.files)
|
||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||
input_data["files"] = files
|
||||
message_id = uuid.uuid4()
|
||||
# 更新状态为运行中
|
||||
@@ -631,7 +653,9 @@ class WorkflowService:
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=payload.user_id
|
||||
user_id=payload.user_id,
|
||||
memory_storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
@@ -780,6 +804,7 @@ class WorkflowService:
|
||||
|
||||
try:
|
||||
files = await self._handle_file_input(payload.files)
|
||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||
input_data["files"] = files
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
@@ -801,6 +826,8 @@ class WorkflowService:
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=payload.user_id,
|
||||
memory_storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
if event.get("event") == "workflow_end":
|
||||
status = event.get("data", {}).get("status")
|
||||
|
||||
@@ -863,7 +863,7 @@ def get_workspace_storage_type(
|
||||
def get_workspace_storage_type_without_auth(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
) -> Optional[str]:
|
||||
) -> str:
|
||||
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1073,9 +1073,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
||||
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
language: str = "zh") -> Dict[str, Any]:
|
||||
def write_message_task(
|
||||
self,
|
||||
end_user_id: str,
|
||||
message: list[dict],
|
||||
config_id: str | int,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
language: str = "zh"
|
||||
) -> Dict[str, Any]:
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
Args:
|
||||
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
||||
@@ -1091,7 +1097,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
Raises:
|
||||
Exception on failure
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
||||
f"config_id={config_id} (type: {type(config_id).__name__}), "
|
||||
@@ -1105,14 +1110,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(config_id, db)
|
||||
print(100 * '-')
|
||||
print(actual_config_id)
|
||||
print(100 * '-')
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
|
||||
logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} "
|
||||
f"(type: {type(actual_config_id).__name__})")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.error(
|
||||
f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
|
||||
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} "
|
||||
f"(type: {type(config_id).__name__}), error: {e}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": f"Invalid config_id format: {config_id} - {str(e)}",
|
||||
@@ -1151,8 +1153,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||
logger.info(f"[CELERY WRITE] Task completed successfully "
|
||||
f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||
|
||||
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
|
||||
try:
|
||||
@@ -1167,7 +1169,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"result": result,
|
||||
@@ -2611,57 +2612,6 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_perceptual_memory",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_api_config: dict,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""
|
||||
Write perceptual memory for a user into PostgreSQL and Neo4j.
|
||||
|
||||
This task generates or updates the user's perceptual memory
|
||||
in the backend databases. It is intended to be executed asynchronously
|
||||
via Celery.
|
||||
|
||||
Args:
|
||||
end_user_id (uuid.UUID): The unique identifier of the end user.
|
||||
model_api_config (ModelInfo): API configuration for the model
|
||||
used to generate perceptual memory.
|
||||
file_type (str): The file type
|
||||
file_url (url): The url of file
|
||||
file_message (dict): The file message containing details about the file
|
||||
to be processed.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
|
||||
set_asyncio_event_loop()
|
||||
with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
|
||||
model_info = ModelInfo(**model_api_config)
|
||||
with get_db_context() as db:
|
||||
memory_perceptual_service = MemoryPerceptualService(db)
|
||||
return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
|
||||
end_user_id,
|
||||
model_info,
|
||||
file_type,
|
||||
file_url,
|
||||
file_message,
|
||||
))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 社区聚类补全任务(触发型)
|
||||
# =============================================================================
|
||||
@@ -2672,7 +2622,7 @@ def write_perceptual_memory(
|
||||
ignore_result=False,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=7200, # 2小时硬超时
|
||||
time_limit=7200, # 2小时硬超时
|
||||
soft_time_limit=6900,
|
||||
)
|
||||
def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
@@ -2787,7 +2737,8 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
|
||||
embedding_model_id=embedding_model_id,
|
||||
)
|
||||
|
||||
logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
|
||||
logger.info(
|
||||
f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
|
||||
await engine.full_clustering(end_user_id)
|
||||
initialized += 1
|
||||
logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")
|
||||
@@ -2810,12 +2761,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
loop = set_asyncio_event_loop()
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
|
||||
Reference in New Issue
Block a user