Merge pull request #676 from SuanmoSuanyangTechnology/feature/multimodel_memory
feat(memory, model): update multi-modal memory write and model list API
This commit is contained in:
@@ -118,142 +118,142 @@ async def download_log(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/writer_service", response_model=ApiResponse)
|
# @router.post("/writer_service", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
# @cur_workspace_access_guard()
|
||||||
async def write_server(
|
# async def write_server(
|
||||||
user_input: Write_UserInput,
|
# user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
# db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
# current_user: User = Depends(get_current_user)
|
||||||
):
|
# ):
|
||||||
"""
|
# """
|
||||||
Write service endpoint - processes write operations synchronously
|
# Write service endpoint - processes write operations synchronously
|
||||||
|
#
|
||||||
Args:
|
# Args:
|
||||||
user_input: Write request containing message and end_user_id
|
# user_input: Write request containing message and end_user_id
|
||||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||||
|
#
|
||||||
Returns:
|
# Returns:
|
||||||
Response with write operation status
|
# Response with write operation status
|
||||||
"""
|
# """
|
||||||
# 使用集中化的语言校验
|
# # 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
# language = get_language_from_header(language_type)
|
||||||
|
#
|
||||||
config_id = user_input.config_id
|
# config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
# workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
#
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# # 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
# storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
# db=db,
|
||||||
workspace_id=workspace_id,
|
# workspace_id=workspace_id,
|
||||||
user=current_user
|
# user=current_user
|
||||||
)
|
# )
|
||||||
if storage_type is None: storage_type = 'neo4j'
|
# if storage_type is None: storage_type = 'neo4j'
|
||||||
user_rag_memory_id = ''
|
# user_rag_memory_id = ''
|
||||||
|
#
|
||||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||||
if storage_type == 'rag':
|
# if storage_type == 'rag':
|
||||||
if workspace_id:
|
# if workspace_id:
|
||||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
db=db,
|
# db=db,
|
||||||
name="USER_RAG_MERORY",
|
# name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
# workspace_id=workspace_id
|
||||||
)
|
# )
|
||||||
if knowledge:
|
# if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
# user_rag_memory_id = str(knowledge.id)
|
||||||
else:
|
# else:
|
||||||
api_logger.warning(
|
# api_logger.warning(
|
||||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
# storage_type = 'neo4j'
|
||||||
else:
|
# else:
|
||||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
# storage_type = 'neo4j'
|
||||||
|
#
|
||||||
api_logger.info(
|
# api_logger.info(
|
||||||
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||||
try:
|
# try:
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
result = await memory_agent_service.write_memory(
|
# result = await memory_agent_service.write_memory(
|
||||||
user_input.end_user_id,
|
# user_input.end_user_id,
|
||||||
messages_list,
|
# messages_list,
|
||||||
config_id,
|
# config_id,
|
||||||
db,
|
# db,
|
||||||
storage_type,
|
# storage_type,
|
||||||
user_rag_memory_id,
|
# user_rag_memory_id,
|
||||||
language
|
# language
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
return success(data=result, msg="写入成功")
|
# return success(data=result, msg="写入成功")
|
||||||
except BaseException as e:
|
# except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
if hasattr(e, 'exceptions'):
|
# if hasattr(e, 'exceptions'):
|
||||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||||
detailed_error = "; ".join(error_messages)
|
# detailed_error = "; ".join(error_messages)
|
||||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||||
|
#
|
||||||
|
#
|
||||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
# @router.post("/writer_service_async", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
# @cur_workspace_access_guard()
|
||||||
async def write_server_async(
|
# async def write_server_async(
|
||||||
user_input: Write_UserInput,
|
# user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
# db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
# current_user: User = Depends(get_current_user)
|
||||||
):
|
# ):
|
||||||
"""
|
# """
|
||||||
Async write service endpoint - enqueues write processing to Celery
|
# Async write service endpoint - enqueues write processing to Celery
|
||||||
|
#
|
||||||
Args:
|
# Args:
|
||||||
user_input: Write request containing message and end_user_id
|
# user_input: Write request containing message and end_user_id
|
||||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||||
|
#
|
||||||
Returns:
|
# Returns:
|
||||||
Task ID for tracking async operation
|
# Task ID for tracking async operation
|
||||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
# Use GET /memory/write_result/{task_id} to check task status and get result
|
||||||
"""
|
# """
|
||||||
# 使用集中化的语言校验
|
# # 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
# language = get_language_from_header(language_type)
|
||||||
|
#
|
||||||
config_id = user_input.config_id
|
# config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
# workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(
|
# api_logger.info(
|
||||||
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
#
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# # 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
# storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
# db=db,
|
||||||
workspace_id=workspace_id,
|
# workspace_id=workspace_id,
|
||||||
user=current_user
|
# user=current_user
|
||||||
)
|
# )
|
||||||
if storage_type is None: storage_type = 'neo4j'
|
# if storage_type is None: storage_type = 'neo4j'
|
||||||
user_rag_memory_id = ''
|
# user_rag_memory_id = ''
|
||||||
if workspace_id:
|
# if workspace_id:
|
||||||
|
#
|
||||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
db=db,
|
# db=db,
|
||||||
name="USER_RAG_MERORY",
|
# name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
# workspace_id=workspace_id
|
||||||
)
|
# )
|
||||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
# if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
try:
|
# try:
|
||||||
# 获取标准化的消息列表
|
# # 获取标准化的消息列表
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
|
#
|
||||||
task = celery_app.send_task(
|
# task = celery_app.send_task(
|
||||||
"app.core.memory.agent.write_message",
|
# "app.core.memory.agent.write_message",
|
||||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||||
)
|
# )
|
||||||
api_logger.info(f"Write task queued: {task.id}")
|
# api_logger.info(f"Write task queued: {task.id}")
|
||||||
|
#
|
||||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
# return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
# api_logger.error(f"Async write operation failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/read_service", response_model=ApiResponse)
|
@router.post("/read_service", response_model=ApiResponse)
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/info", response_model=ApiResponse)
|
@router.get("/info", response_model=ApiResponse)
|
||||||
async def get_storage_info(
|
async def get_storage_info(
|
||||||
storage_id: str,
|
storage_id: str,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Example wrapper endpoint - retrieves storage information
|
Example wrapper endpoint - retrieves storage information
|
||||||
@@ -75,17 +75,12 @@ async def get_storage_info(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
|
||||||
def create_config(
|
def create_config(
|
||||||
payload: ConfigParamsCreate,
|
payload: ConfigParamsCreate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
@@ -107,9 +102,11 @@ def create_config(
|
|||||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||||
lang = get_language_from_header(x_language_type)
|
lang = get_language_from_header(x_language_type)
|
||||||
if lang == "en":
|
if lang == "en":
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||||
|
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
else:
|
else:
|
||||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||||
|
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||||
return JSONResponse(status_code=400, content=msg)
|
return JSONResponse(status_code=400, content=msg)
|
||||||
api_logger.error(f"Create config failed: {err_str}")
|
api_logger.error(f"Create config failed: {err_str}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||||
@@ -119,9 +116,11 @@ def create_config(
|
|||||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||||
lang = get_language_from_header(x_language_type)
|
lang = get_language_from_header(x_language_type)
|
||||||
if lang == "en":
|
if lang == "en":
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||||
|
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
else:
|
else:
|
||||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||||
|
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||||
return JSONResponse(status_code=400, content=msg)
|
return JSONResponse(status_code=400, content=msg)
|
||||||
api_logger.error(f"Create config failed: {str(e)}")
|
api_logger.error(f"Create config failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||||
@@ -129,10 +128,10 @@ def create_config(
|
|||||||
|
|
||||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||||
def delete_config(
|
def delete_config(
|
||||||
config_id: UUID|int,
|
config_id: UUID | int,
|
||||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""删除记忆配置(带终端用户保护)
|
"""删除记忆配置(带终端用户保护)
|
||||||
|
|
||||||
@@ -145,7 +144,7 @@ def delete_config(
|
|||||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
config_id=resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||||
@@ -203,9 +202,9 @@ def delete_config(
|
|||||||
|
|
||||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||||
def update_config(
|
def update_config(
|
||||||
payload: ConfigUpdate,
|
payload: ConfigUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
@@ -217,7 +216,8 @@ def update_config(
|
|||||||
# 校验至少有一个字段需要更新
|
# 校验至少有一个字段需要更新
|
||||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
|
||||||
|
"config_name, config_desc, scene_id 均为空")
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||||
try:
|
try:
|
||||||
@@ -231,9 +231,9 @@ def update_config(
|
|||||||
|
|
||||||
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
||||||
def update_config_extracted(
|
def update_config_extracted(
|
||||||
payload: ConfigUpdateExtracted,
|
payload: ConfigUpdateExtracted,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
@@ -256,11 +256,11 @@ def update_config_extracted(
|
|||||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||||
|
|
||||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||||
def read_config_extracted(
|
def read_config_extracted(
|
||||||
config_id: UUID | int,
|
config_id: UUID | int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
@@ -278,10 +278,11 @@ def read_config_extracted(
|
|||||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||||
|
|
||||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
|
||||||
|
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||||
def read_all_config(
|
def read_all_config(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -303,10 +304,10 @@ def read_all_config(
|
|||||||
|
|
||||||
@router.post("/pilot_run", response_model=None)
|
@router.post("/pilot_run", response_model=None)
|
||||||
async def pilot_run(
|
async def pilot_run(
|
||||||
payload: ConfigPilotRun,
|
payload: ConfigPilotRun,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
@@ -333,9 +334,9 @@ async def pilot_run(
|
|||||||
|
|
||||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||||
async def get_kb_type_distribution(
|
async def get_kb_type_distribution(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await kb_type_distribution(end_user_id)
|
result = await kb_type_distribution(end_user_id)
|
||||||
@@ -347,9 +348,9 @@ async def get_kb_type_distribution(
|
|||||||
|
|
||||||
@router.get("/search/dialogue", response_model=ApiResponse)
|
@router.get("/search/dialogue", response_model=ApiResponse)
|
||||||
async def search_dialogues_num(
|
async def search_dialogues_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_dialogue(end_user_id)
|
result = await search_dialogue(end_user_id)
|
||||||
@@ -361,9 +362,9 @@ async def search_dialogues_num(
|
|||||||
|
|
||||||
@router.get("/search/chunk", response_model=ApiResponse)
|
@router.get("/search/chunk", response_model=ApiResponse)
|
||||||
async def search_chunks_num(
|
async def search_chunks_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_chunk(end_user_id)
|
result = await search_chunk(end_user_id)
|
||||||
@@ -375,9 +376,9 @@ async def search_chunks_num(
|
|||||||
|
|
||||||
@router.get("/search/statement", response_model=ApiResponse)
|
@router.get("/search/statement", response_model=ApiResponse)
|
||||||
async def search_statements_num(
|
async def search_statements_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_statement(end_user_id)
|
result = await search_statement(end_user_id)
|
||||||
@@ -389,9 +390,9 @@ async def search_statements_num(
|
|||||||
|
|
||||||
@router.get("/search/entity", response_model=ApiResponse)
|
@router.get("/search/entity", response_model=ApiResponse)
|
||||||
async def search_entities_num(
|
async def search_entities_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_entity(end_user_id)
|
result = await search_entity(end_user_id)
|
||||||
@@ -403,9 +404,9 @@ async def search_entities_num(
|
|||||||
|
|
||||||
@router.get("/search", response_model=ApiResponse)
|
@router.get("/search", response_model=ApiResponse)
|
||||||
async def search_all_num(
|
async def search_all_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_all(end_user_id)
|
result = await search_all(end_user_id)
|
||||||
@@ -417,9 +418,9 @@ async def search_all_num(
|
|||||||
|
|
||||||
@router.get("/search/detials", response_model=ApiResponse)
|
@router.get("/search/detials", response_model=ApiResponse)
|
||||||
async def search_entities_detials(
|
async def search_entities_detials(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_detials(end_user_id)
|
result = await search_detials(end_user_id)
|
||||||
@@ -431,9 +432,9 @@ async def search_entities_detials(
|
|||||||
|
|
||||||
@router.get("/search/edges", response_model=ApiResponse)
|
@router.get("/search/edges", response_model=ApiResponse)
|
||||||
async def search_entity_edges(
|
async def search_entity_edges(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_edges(end_user_id)
|
result = await search_edges(end_user_id)
|
||||||
@@ -443,14 +444,12 @@ async def search_entity_edges(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||||
async def get_hot_memory_tags_api(
|
async def get_hot_memory_tags_api(
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取热门记忆标签(带Redis缓存)
|
获取热门记忆标签(带Redis缓存)
|
||||||
|
|
||||||
@@ -505,8 +504,8 @@ async def get_hot_memory_tags_api(
|
|||||||
|
|
||||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||||
async def clear_hot_memory_tags_cache(
|
async def clear_hot_memory_tags_cache(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
清除热门标签缓存
|
清除热门标签缓存
|
||||||
|
|
||||||
@@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache(
|
|||||||
|
|
||||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||||
async def get_recent_activity_stats_api(
|
async def get_recent_activity_stats_api(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||||
@@ -553,4 +552,3 @@ async def get_recent_activity_stats_api(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ def get_model_strategies():
|
|||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_model_list(
|
def get_model_list(
|
||||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||||
|
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
||||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||||
@@ -74,10 +75,21 @@ def get_model_list(
|
|||||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||||
|
|
||||||
|
capability_list = []
|
||||||
|
if capability is not None:
|
||||||
|
flat_capability = []
|
||||||
|
for item in capability:
|
||||||
|
split_items = [c.strip() for c in item.split(', ') if c.strip()]
|
||||||
|
flat_capability.extend(split_items)
|
||||||
|
|
||||||
|
unique_flat_capability = list(dict.fromkeys(flat_capability))
|
||||||
|
capability_list = unique_flat_capability
|
||||||
|
|
||||||
api_logger.error(f"获取模型type_list: {type_list}")
|
api_logger.error(f"获取模型type_list: {type_list}")
|
||||||
query = model_schema.ModelConfigQuery(
|
query = model_schema.ModelConfigQuery(
|
||||||
type=type_list,
|
type=type_list,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
|
capability=capability_list,
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
is_public=is_public,
|
is_public=is_public,
|
||||||
search=search,
|
search=search,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import datetime
|
import datetime
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from fastapi import APIRouter, Depends,Header
|
from fastapi import APIRouter, Depends, Header
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
@@ -19,7 +19,7 @@ from app.services.user_memory_service import (
|
|||||||
analytics_graph_data,
|
analytics_graph_data,
|
||||||
analytics_community_graph_data,
|
analytics_community_graph_data,
|
||||||
)
|
)
|
||||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||||
from app.repositories.workspace_repository import WorkspaceRepository
|
from app.repositories.workspace_repository import WorkspaceRepository
|
||||||
@@ -45,9 +45,9 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||||
async def get_memory_insight_report_api(
|
async def get_memory_insight_report_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取缓存的记忆洞察报告
|
获取缓存的记忆洞察报告
|
||||||
@@ -73,10 +73,10 @@ async def get_memory_insight_report_api(
|
|||||||
|
|
||||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||||
async def get_user_summary_api(
|
async def get_user_summary_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取缓存的用户摘要
|
获取缓存的用户摘要
|
||||||
@@ -102,7 +102,7 @@ async def get_user_summary_api(
|
|||||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取缓存数据
|
# 调用服务层获取缓存数据
|
||||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
||||||
|
|
||||||
if result["is_cached"]:
|
if result["is_cached"]:
|
||||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||||
@@ -117,10 +117,10 @@ async def get_user_summary_api(
|
|||||||
|
|
||||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||||
async def generate_cache_api(
|
async def generate_cache_api(
|
||||||
request: GenerateCacheRequest,
|
request: GenerateCacheRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
手动触发缓存生成
|
手动触发缓存生成
|
||||||
@@ -155,10 +155,12 @@ async def generate_cache_api(
|
|||||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||||
|
|
||||||
# 生成记忆洞察
|
# 生成记忆洞察
|
||||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
||||||
|
language=language)
|
||||||
|
|
||||||
# 生成用户摘要
|
# 生成用户摘要
|
||||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
||||||
|
language=language)
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
result = {
|
result = {
|
||||||
@@ -209,9 +211,9 @@ async def generate_cache_api(
|
|||||||
|
|
||||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||||
async def get_node_statistics_api(
|
async def get_node_statistics_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -220,7 +222,8 @@ async def get_node_statistics_api(
|
|||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
api_logger.info(
|
||||||
|
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用新的记忆类型统计函数
|
# 调用新的记忆类型统计函数
|
||||||
@@ -228,21 +231,23 @@ async def get_node_statistics_api(
|
|||||||
|
|
||||||
# 计算总数用于日志
|
# 计算总数用于日志
|
||||||
total_count = sum(item["count"] for item in result)
|
total_count = sum(item["count"] for item in result)
|
||||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
api_logger.info(
|
||||||
|
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||||
async def get_graph_data_api(
|
async def get_graph_data_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
node_types: Optional[str] = None,
|
node_types: Optional[str] = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
depth: int = 1,
|
depth: int = 1,
|
||||||
center_node_id: Optional[str] = None,
|
center_node_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -298,9 +303,9 @@ async def get_graph_data_api(
|
|||||||
|
|
||||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||||
async def get_community_graph_data_api(
|
async def get_community_graph_data_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -334,9 +339,9 @@ async def get_community_graph_data_api(
|
|||||||
|
|
||||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||||
async def get_end_user_profile(
|
async def get_end_user_profile(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
workspace_repo = WorkspaceRepository(db)
|
workspace_repo = WorkspaceRepository(db)
|
||||||
@@ -385,9 +390,9 @@ async def get_end_user_profile(
|
|||||||
|
|
||||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
||||||
async def update_end_user_profile(
|
async def update_end_user_profile(
|
||||||
profile_update: EndUserProfileUpdate,
|
profile_update: EndUserProfileUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
更新终端用户的基本信息
|
更新终端用户的基本信息
|
||||||
@@ -427,15 +432,18 @@ async def update_end_user_profile(
|
|||||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
async def memory_space_timeline_of_shared_memories(
|
||||||
current_user: User = Depends(get_current_user),
|
id: str, label: str,
|
||||||
db: Session = Depends(get_db),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
):
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
workspace_id=current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
workspace_repo = WorkspaceRepository(db)
|
workspace_repo = WorkspaceRepository(db)
|
||||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||||
|
|
||||||
@@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
|||||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||||
|
|
||||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||||
async def memory_space_relationship_evolution(id: str, label: str,
|
async def memory_space_relationship_evolution(id: str, label: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||||
|
|
||||||
|
|||||||
@@ -598,8 +598,10 @@ class LangChainAgent:
|
|||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
total_tokens = response_meta.get("token_usage", {}).get(
|
||||||
0) if response_meta else 0
|
"total_tokens",
|
||||||
|
0
|
||||||
|
) if response_meta else 0
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
break
|
break
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
|
|||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
messages: list = None,
|
messages: list = None,
|
||||||
ref_id: str = "wyl_20251027",
|
ref_id: str = "",
|
||||||
config_id: str = None
|
config_id: str = None
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||||
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
role = msg['role']
|
role = msg['role']
|
||||||
content = msg['content']
|
content = msg['content']
|
||||||
|
files = msg.get("file_content", [])
|
||||||
|
|
||||||
if role not in ['user', 'assistant']:
|
if role not in ['user', 'assistant']:
|
||||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||||
|
|
||||||
if content.strip():
|
if content.strip():
|
||||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
||||||
|
|
||||||
if not conversation_messages:
|
if not conversation_messages:
|
||||||
raise ValueError("Message list cannot be empty after filtering")
|
raise ValueError("Message list cannot be empty after filtering")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -13,7 +14,8 @@ from dotenv import load_dotenv
|
|||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||||
|
memory_summary_generation
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.core.memory.utils.log.logging_utils import log_time
|
from app.core.memory.utils.log.logging_utils import log_time
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
@@ -23,18 +25,17 @@ from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo
|
|||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def write(
|
async def write(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
messages: list,
|
messages: list,
|
||||||
ref_id: str = "wyl20251027",
|
ref_id: str = "",
|
||||||
language: str = "zh",
|
language: str = "zh",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Execute the complete knowledge extraction pipeline.
|
Execute the complete knowledge extraction pipeline.
|
||||||
@@ -43,9 +44,11 @@ async def write(
|
|||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
ref_id: Reference ID, defaults to "wyl20251027"
|
ref_id: Reference ID, defaults to ""
|
||||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||||
"""
|
"""
|
||||||
|
if not ref_id:
|
||||||
|
ref_id = uuid.uuid4().hex
|
||||||
# Extract config values
|
# Extract config values
|
||||||
embedding_model_id = str(memory_config.embedding_model_id)
|
embedding_model_id = str(memory_config.embedding_model_id)
|
||||||
chunker_strategy = memory_config.chunker_strategy
|
chunker_strategy = memory_config.chunker_strategy
|
||||||
@@ -135,9 +138,11 @@ async def write(
|
|||||||
all_chunk_nodes,
|
all_chunk_nodes,
|
||||||
all_statement_nodes,
|
all_statement_nodes,
|
||||||
all_entity_nodes,
|
all_entity_nodes,
|
||||||
|
all_perceptual_nodes,
|
||||||
all_statement_chunk_edges,
|
all_statement_chunk_edges,
|
||||||
all_statement_entity_edges,
|
all_statement_entity_edges,
|
||||||
all_entity_entity_edges,
|
all_entity_entity_edges,
|
||||||
|
all_perceptual_edges,
|
||||||
all_dedup_details,
|
all_dedup_details,
|
||||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||||
|
|
||||||
@@ -162,9 +167,11 @@ async def write(
|
|||||||
chunk_nodes=all_chunk_nodes,
|
chunk_nodes=all_chunk_nodes,
|
||||||
statement_nodes=all_statement_nodes,
|
statement_nodes=all_statement_nodes,
|
||||||
entity_nodes=all_entity_nodes,
|
entity_nodes=all_entity_nodes,
|
||||||
|
perceptual_nodes=all_perceptual_nodes,
|
||||||
statement_chunk_edges=all_statement_chunk_edges,
|
statement_chunk_edges=all_statement_chunk_edges,
|
||||||
statement_entity_edges=all_statement_entity_edges,
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
entity_edges=all_entity_entity_edges,
|
entity_edges=all_entity_entity_edges,
|
||||||
|
perceptual_edges=all_perceptual_edges,
|
||||||
connector=neo4j_connector,
|
connector=neo4j_connector,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
@@ -173,7 +180,8 @@ async def write(
|
|||||||
await _trigger_clustering_sync(
|
await _trigger_clustering_sync(
|
||||||
all_entity_nodes,
|
all_entity_nodes,
|
||||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
embedding_model_id=str(
|
||||||
|
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -208,9 +216,8 @@ async def write(
|
|||||||
summaries = await memory_summary_generation(
|
summaries = await memory_summary_generation(
|
||||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||||
)
|
)
|
||||||
|
ms_connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
ms_connector = Neo4jConnector()
|
|
||||||
await add_memory_summary_nodes(summaries, ms_connector)
|
await add_memory_summary_nodes(summaries, ms_connector)
|
||||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from typing import Any, List
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# Fix tokenizer parallelism warning
|
# Fix tokenizer parallelism warning
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -246,6 +246,7 @@ class ChunkerClient:
|
|||||||
"total_sub_chunks": len(sub_chunks),
|
"total_sub_chunks": len(sub_chunks),
|
||||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||||
},
|
},
|
||||||
|
files=msg.files
|
||||||
)
|
)
|
||||||
dialogue.chunks.append(chunk)
|
dialogue.chunks.append(chunk)
|
||||||
else:
|
else:
|
||||||
@@ -258,6 +259,7 @@ class ChunkerClient:
|
|||||||
"message_role": msg.role,
|
"message_role": msg.role,
|
||||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||||
},
|
},
|
||||||
|
files=msg.files
|
||||||
)
|
)
|
||||||
dialogue.chunks.append(chunk)
|
dialogue.chunks.append(chunk)
|
||||||
|
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class Edge(BaseModel):
|
|||||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
|
||||||
|
|
||||||
|
|
||||||
class ChunkEdge(Edge):
|
class ChunkEdge(Edge):
|
||||||
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
|
|||||||
return parse_historical_datetime(v)
|
return parse_historical_datetime(v)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualEdge(Edge):
|
||||||
|
"""Edge connecting perceptual nodes to their source chunks
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Node(BaseModel):
|
class Node(BaseModel):
|
||||||
"""Base class for all graph nodes in the knowledge graph.
|
"""Base class for all graph nodes in the knowledge graph.
|
||||||
|
|
||||||
@@ -206,7 +212,8 @@ class DialogueNode(Node):
|
|||||||
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
||||||
content: str = Field(..., description="Dialogue content")
|
content: str = Field(..., description="Dialogue content")
|
||||||
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this dialogue (integer or string)")
|
||||||
|
|
||||||
|
|
||||||
class StatementNode(Node):
|
class StatementNode(Node):
|
||||||
@@ -281,7 +288,8 @@ class StatementNode(Node):
|
|||||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this statement (integer or string)")
|
||||||
|
|
||||||
# ACT-R Memory Activation Properties
|
# ACT-R Memory Activation Properties
|
||||||
importance_score: float = Field(
|
importance_score: float = Field(
|
||||||
@@ -416,7 +424,8 @@ class ExtractedEntityNode(Node):
|
|||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this entity (integer or string)")
|
||||||
|
|
||||||
# ACT-R Memory Activation Properties
|
# ACT-R Memory Activation Properties
|
||||||
importance_score: float = Field(
|
importance_score: float = Field(
|
||||||
@@ -453,7 +462,7 @@ class ExtractedEntityNode(Node):
|
|||||||
|
|
||||||
@field_validator('aliases', mode='before')
|
@field_validator('aliases', mode='before')
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||||
"""Validate and clean aliases field using utility function.
|
"""Validate and clean aliases field using utility function.
|
||||||
|
|
||||||
This validator ensures that the aliases field is always a valid list of strings.
|
This validator ensures that the aliases field is always a valid list of strings.
|
||||||
@@ -507,7 +516,8 @@ class MemorySummaryNode(Node):
|
|||||||
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
||||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this summary (integer or string)")
|
||||||
|
|
||||||
# ACT-R Forgetting Engine Properties
|
# ACT-R Forgetting Engine Properties
|
||||||
original_statement_id: Optional[str] = Field(
|
original_statement_id: Optional[str] = Field(
|
||||||
@@ -549,3 +559,18 @@ class MemorySummaryNode(Node):
|
|||||||
ge=0,
|
ge=0,
|
||||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualNode(Node):
|
||||||
|
"""Node representing a multimodal message in the knowledge graph.
|
||||||
|
"""
|
||||||
|
perceptual_type: int
|
||||||
|
file_path: str
|
||||||
|
file_name: str
|
||||||
|
file_ext: str
|
||||||
|
summary: str
|
||||||
|
keywords: list[str]
|
||||||
|
topic: str
|
||||||
|
domain: str
|
||||||
|
file_type: str
|
||||||
|
summary_embedding: list[float] | None
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||||
msg: str = Field(..., description="The text content of the message.")
|
msg: str = Field(..., description="The text content of the message.")
|
||||||
|
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||||
|
|
||||||
|
|
||||||
class TemporalValidityRange(BaseModel):
|
class TemporalValidityRange(BaseModel):
|
||||||
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
|
|||||||
content: str = Field(..., description="The content of the chunk as a string.")
|
content: str = Field(..., description="The content of the chunk as a string.")
|
||||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||||
|
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|||||||
|
|
||||||
|
|
||||||
async def dedup_layers_and_merge_and_return(
|
async def dedup_layers_and_merge_and_return(
|
||||||
dialogue_nodes: List[DialogueNode],
|
dialogue_nodes: List[DialogueNode],
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
pipeline_config: ExtractionPipelineConfig,
|
pipeline_config: ExtractionPipelineConfig,
|
||||||
connector: Optional[Neo4jConnector] = None,
|
connector: Optional[Neo4jConnector] = None,
|
||||||
llm_client = None,
|
llm_client=None,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[DialogueNode],
|
List[DialogueNode],
|
||||||
List[ChunkNode],
|
List[ChunkNode],
|
||||||
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
List[StatementChunkEdge],
|
List[StatementChunkEdge],
|
||||||
List[StatementEntityEdge],
|
List[StatementEntityEdge],
|
||||||
List[EntityEntityEdge],
|
List[EntityEntityEdge],
|
||||||
dict, # 新增:返回去重详情
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
执行两层实体去重与融合:
|
执行两层实体去重与融合:
|
||||||
|
|||||||
@@ -32,10 +32,11 @@ from app.core.memory.models.graph_models import (
|
|||||||
StatementChunkEdge,
|
StatementChunkEdge,
|
||||||
StatementEntityEdge,
|
StatementEntityEdge,
|
||||||
StatementNode,
|
StatementNode,
|
||||||
|
PerceptualEdge,
|
||||||
|
PerceptualNode
|
||||||
)
|
)
|
||||||
from app.core.memory.models.message_models import DialogData
|
from app.core.memory.models.message_models import DialogData
|
||||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
|
||||||
from app.core.memory.models.variate_config import (
|
from app.core.memory.models.variate_config import (
|
||||||
ExtractionPipelineConfig,
|
ExtractionPipelineConfig,
|
||||||
)
|
)
|
||||||
@@ -46,7 +47,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb
|
|||||||
embedding_generation,
|
embedding_generation,
|
||||||
generate_entity_embeddings_from_triplets,
|
generate_entity_embeddings_from_triplets,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 导入各个提取模块
|
# 导入各个提取模块
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
|
||||||
StatementExtractor,
|
StatementExtractor,
|
||||||
@@ -90,16 +90,16 @@ class ExtractionOrchestrator:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
embedder_client: OpenAIEmbedderClient,
|
embedder_client: OpenAIEmbedderClient,
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
config: Optional[ExtractionPipelineConfig] = None,
|
config: Optional[ExtractionPipelineConfig] = None,
|
||||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||||
embedding_id: Optional[str] = None,
|
embedding_id: Optional[str] = None,
|
||||||
ontology_types: Optional[OntologyTypeList] = None,
|
ontology_types: Optional[OntologyTypeList] = None,
|
||||||
enable_general_types: bool = True,
|
enable_general_types: bool = True,
|
||||||
language: str = "zh",
|
language: str = "zh",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化流水线编排器
|
初始化流水线编排器
|
||||||
@@ -157,19 +157,27 @@ class ExtractionOrchestrator:
|
|||||||
llm_client=llm_client,
|
llm_client=llm_client,
|
||||||
config=self.config.statement_extraction,
|
config=self.config.statement_extraction,
|
||||||
)
|
)
|
||||||
self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language)
|
self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types,
|
||||||
|
language=language)
|
||||||
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
|
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
|
||||||
|
|
||||||
logger.info("ExtractionOrchestrator 初始化完成")
|
logger.info("ExtractionOrchestrator 初始化完成")
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
is_pilot_run: bool = False,
|
is_pilot_run: bool = False,
|
||||||
) -> Tuple[
|
) -> tuple[
|
||||||
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]],
|
list[DialogueNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[ChunkNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[StatementNode],
|
||||||
|
list[ExtractedEntityNode],
|
||||||
|
list[PerceptualNode],
|
||||||
|
list[StatementChunkEdge],
|
||||||
|
list[StatementEntityEdge],
|
||||||
|
list[EntityEntityEdge],
|
||||||
|
list[PerceptualEdge],
|
||||||
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
运行完整的知识提取流水线(优化版:并行执行)
|
运行完整的知识提取流水线(优化版:并行执行)
|
||||||
@@ -208,7 +216,6 @@ class ExtractionOrchestrator:
|
|||||||
for dialog in dialog_data_list:
|
for dialog in dialog_data_list:
|
||||||
for chunk in dialog.chunks:
|
for chunk in dialog.chunks:
|
||||||
all_statements_list.extend(chunk.statements)
|
all_statements_list.extend(chunk.statements)
|
||||||
len(all_statements_list)
|
|
||||||
|
|
||||||
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
||||||
@@ -230,10 +237,6 @@ class ExtractionOrchestrator:
|
|||||||
all_entities_list.extend(triplet_info.entities)
|
all_entities_list.extend(triplet_info.entities)
|
||||||
all_triplets_list.extend(triplet_info.triplets)
|
all_triplets_list.extend(triplet_info.triplets)
|
||||||
|
|
||||||
len(all_entities_list)
|
|
||||||
len(all_triplets_list)
|
|
||||||
sum(len(temporal_map) for temporal_map in temporal_maps)
|
|
||||||
|
|
||||||
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
||||||
logger.info("步骤 3/6: 生成实体嵌入")
|
logger.info("步骤 3/6: 生成实体嵌入")
|
||||||
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
||||||
@@ -260,9 +263,11 @@ class ExtractionOrchestrator:
|
|||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
|
perceptual_nodes,
|
||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_entity_edges,
|
entity_entity_edges,
|
||||||
|
perceptual_edges
|
||||||
) = await self._create_nodes_and_edges(dialog_data_list)
|
) = await self._create_nodes_and_edges(dialog_data_list)
|
||||||
|
|
||||||
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
|
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
|
||||||
@@ -276,7 +281,16 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
||||||
|
|
||||||
result = await self._run_dedup_and_write_summary(
|
(
|
||||||
|
dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
statement_nodes,
|
||||||
|
entity_nodes,
|
||||||
|
statement_chunk_edges,
|
||||||
|
statement_entity_edges,
|
||||||
|
entity_entity_edges,
|
||||||
|
dialog_data_list,
|
||||||
|
) = await self._run_dedup_and_write_summary(
|
||||||
dialogue_nodes,
|
dialogue_nodes,
|
||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
@@ -287,17 +301,26 @@ class ExtractionOrchestrator:
|
|||||||
dialog_data_list,
|
dialog_data_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||||
return result
|
return (
|
||||||
|
dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
statement_nodes,
|
||||||
|
entity_nodes,
|
||||||
|
perceptual_nodes,
|
||||||
|
statement_chunk_edges,
|
||||||
|
statement_entity_edges,
|
||||||
|
entity_entity_edges,
|
||||||
|
perceptual_edges,
|
||||||
|
dialog_data_list,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
|
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _extract_statements(
|
async def _extract_statements(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""
|
"""
|
||||||
从对话中提取陈述句(流式输出版本:边提取边发送进度)
|
从对话中提取陈述句(流式输出版本:边提取边发送进度)
|
||||||
@@ -395,7 +418,7 @@ class ExtractionOrchestrator:
|
|||||||
return dialog_data_list
|
return dialog_data_list
|
||||||
|
|
||||||
async def _extract_triplets(
|
async def _extract_triplets(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从对话中提取三元组(流式输出版本:边提取边发送进度)
|
从对话中提取三元组(流式输出版本:边提取边发送进度)
|
||||||
@@ -478,7 +501,7 @@ class ExtractionOrchestrator:
|
|||||||
return triplet_maps
|
return triplet_maps
|
||||||
|
|
||||||
async def _extract_temporal(
|
async def _extract_temporal(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从对话中提取时间信息(流式输出版本:边提取边发送进度)
|
从对话中提取时间信息(流式输出版本:边提取边发送进度)
|
||||||
@@ -585,7 +608,7 @@ class ExtractionOrchestrator:
|
|||||||
return temporal_maps
|
return temporal_maps
|
||||||
|
|
||||||
async def _extract_emotions(
|
async def _extract_emotions(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
|
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
|
||||||
@@ -706,7 +729,7 @@ class ExtractionOrchestrator:
|
|||||||
return emotion_maps
|
return emotion_maps
|
||||||
|
|
||||||
async def _parallel_extract_and_embed(
|
async def _parallel_extract_and_embed(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[Dict[str, Any]],
|
List[Dict[str, Any]],
|
||||||
List[Dict[str, Any]],
|
List[Dict[str, Any]],
|
||||||
@@ -777,7 +800,7 @@ class ExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _generate_basic_embeddings(
|
async def _generate_basic_embeddings(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
|
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
|
||||||
"""
|
"""
|
||||||
生成基础嵌入向量(陈述句、分块、对话)
|
生成基础嵌入向量(陈述句、分块、对话)
|
||||||
@@ -836,7 +859,7 @@ class ExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _generate_entity_embeddings(
|
async def _generate_entity_embeddings(
|
||||||
self, triplet_maps: List[Dict[str, Any]]
|
self, triplet_maps: List[Dict[str, Any]]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
生成实体嵌入向量
|
生成实体嵌入向量
|
||||||
@@ -874,17 +897,15 @@ class ExtractionOrchestrator:
|
|||||||
logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
|
logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
|
||||||
return triplet_maps
|
return triplet_maps
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _assign_extracted_data(
|
async def _assign_extracted_data(
|
||||||
self,
|
self,
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
temporal_maps: List[Dict[str, Any]],
|
temporal_maps: List[Dict[str, Any]],
|
||||||
triplet_maps: List[Dict[str, Any]],
|
triplet_maps: List[Dict[str, Any]],
|
||||||
emotion_maps: List[Dict[str, Any]],
|
emotion_maps: List[Dict[str, Any]],
|
||||||
statement_embedding_maps: List[Dict[str, List[float]]],
|
statement_embedding_maps: List[Dict[str, List[float]]],
|
||||||
chunk_embedding_maps: List[Dict[str, List[float]]],
|
chunk_embedding_maps: List[Dict[str, List[float]]],
|
||||||
dialog_embeddings: List[List[float]],
|
dialog_embeddings: List[List[float]],
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""
|
"""
|
||||||
将提取的数据赋值到语句
|
将提取的数据赋值到语句
|
||||||
@@ -906,12 +927,12 @@ class ExtractionOrchestrator:
|
|||||||
# 确保列表长度匹配
|
# 确保列表长度匹配
|
||||||
expected_length = len(dialog_data_list)
|
expected_length = len(dialog_data_list)
|
||||||
if (
|
if (
|
||||||
len(temporal_maps) != expected_length
|
len(temporal_maps) != expected_length
|
||||||
or len(triplet_maps) != expected_length
|
or len(triplet_maps) != expected_length
|
||||||
or len(emotion_maps) != expected_length
|
or len(emotion_maps) != expected_length
|
||||||
or len(statement_embedding_maps) != expected_length
|
or len(statement_embedding_maps) != expected_length
|
||||||
or len(chunk_embedding_maps) != expected_length
|
or len(chunk_embedding_maps) != expected_length
|
||||||
or len(dialog_embeddings) != expected_length
|
or len(dialog_embeddings) != expected_length
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
|
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
|
||||||
@@ -999,15 +1020,17 @@ class ExtractionOrchestrator:
|
|||||||
return dialog_data_list
|
return dialog_data_list
|
||||||
|
|
||||||
async def _create_nodes_and_edges(
|
async def _create_nodes_and_edges(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[DialogueNode],
|
List[DialogueNode],
|
||||||
List[ChunkNode],
|
List[ChunkNode],
|
||||||
List[StatementNode],
|
List[StatementNode],
|
||||||
List[ExtractedEntityNode],
|
List[ExtractedEntityNode],
|
||||||
|
List[PerceptualNode],
|
||||||
List[StatementChunkEdge],
|
List[StatementChunkEdge],
|
||||||
List[StatementEntityEdge],
|
List[StatementEntityEdge],
|
||||||
List[EntityEntityEdge],
|
List[EntityEntityEdge],
|
||||||
|
List[PerceptualEdge]
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
创建图数据库节点和边
|
创建图数据库节点和边
|
||||||
@@ -1031,6 +1054,8 @@ class ExtractionOrchestrator:
|
|||||||
statement_chunk_edges = []
|
statement_chunk_edges = []
|
||||||
statement_entity_edges = []
|
statement_entity_edges = []
|
||||||
entity_entity_edges = []
|
entity_entity_edges = []
|
||||||
|
perceptual_nodes = []
|
||||||
|
perceptual_edges = []
|
||||||
|
|
||||||
# 用于去重的集合
|
# 用于去重的集合
|
||||||
entity_id_set = set()
|
entity_id_set = set()
|
||||||
@@ -1074,6 +1099,46 @@ class ExtractionOrchestrator:
|
|||||||
metadata=chunk.metadata,
|
metadata=chunk.metadata,
|
||||||
)
|
)
|
||||||
chunk_nodes.append(chunk_node)
|
chunk_nodes.append(chunk_node)
|
||||||
|
logger.error(f"chunk file: {chunk.files}")
|
||||||
|
|
||||||
|
for p, file_type in chunk.files:
|
||||||
|
|
||||||
|
meta = p.meta_data or {}
|
||||||
|
content_meta = meta.get("content", {})
|
||||||
|
|
||||||
|
# 生成 summary embedding(如果有 embedder_client)
|
||||||
|
summary_embedding = None
|
||||||
|
if self.embedder_client and p.summary:
|
||||||
|
try:
|
||||||
|
summary_embedding = (await self.embedder_client.response([p.summary]))[0]
|
||||||
|
except Exception as emb_err:
|
||||||
|
print(f"Failed to embed perceptual summary: {emb_err}")
|
||||||
|
|
||||||
|
perceptual = PerceptualNode(
|
||||||
|
name=f"Perceptual_{p.id}",
|
||||||
|
**{
|
||||||
|
"id": str(p.id),
|
||||||
|
"end_user_id": str(p.end_user_id),
|
||||||
|
"perceptual_type": p.perceptual_type,
|
||||||
|
"file_path": p.file_path or "",
|
||||||
|
"file_name": p.file_name or "",
|
||||||
|
"file_ext": p.file_ext or "",
|
||||||
|
"summary": p.summary or "",
|
||||||
|
"keywords": content_meta.get("keywords", []),
|
||||||
|
"topic": content_meta.get("topic", ""),
|
||||||
|
"domain": content_meta.get("domain", ""),
|
||||||
|
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||||
|
"file_type": file_type,
|
||||||
|
"summary_embedding": summary_embedding,
|
||||||
|
})
|
||||||
|
perceptual_nodes.append(perceptual)
|
||||||
|
perceptual_edges.append(PerceptualEdge(
|
||||||
|
source=perceptual.id,
|
||||||
|
target=chunk.id,
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
))
|
||||||
|
|
||||||
# 处理每个陈述句
|
# 处理每个陈述句
|
||||||
for statement in chunk.statements:
|
for statement in chunk.statements:
|
||||||
@@ -1083,15 +1148,19 @@ class ExtractionOrchestrator:
|
|||||||
name=f"Statement_{statement.id}", # 添加必需的 name 字段
|
name=f"Statement_{statement.id}", # 添加必需的 name 字段
|
||||||
chunk_id=chunk.id,
|
chunk_id=chunk.id,
|
||||||
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
||||||
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL),
|
||||||
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
# 添加必需的 temporal_info 字段
|
||||||
|
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong',
|
||||||
|
# 添加必需的 connect_strength 字段
|
||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||||
statement_embedding=statement.statement_embedding,
|
statement_embedding=statement.statement_embedding,
|
||||||
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
valid_at=statement.temporal_validity.valid_at if hasattr(statement,
|
||||||
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
'temporal_validity') and statement.temporal_validity else None,
|
||||||
|
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
|
||||||
|
'temporal_validity') and statement.temporal_validity else None,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
expired_at=dialog_data.expired_at,
|
||||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||||
@@ -1141,7 +1210,8 @@ class ExtractionOrchestrator:
|
|||||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
||||||
|
# 添加必需的 connect_strength 字段
|
||||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||||
name_embedding=getattr(entity, 'name_embedding', None),
|
name_embedding=getattr(entity, 'name_embedding', None),
|
||||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||||
@@ -1248,25 +1318,32 @@ class ExtractionOrchestrator:
|
|||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
|
perceptual_nodes,
|
||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_entity_edges,
|
entity_entity_edges,
|
||||||
|
perceptual_edges
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _run_dedup_and_write_summary(
|
async def _run_dedup_and_write_summary(
|
||||||
self,
|
self,
|
||||||
dialogue_nodes: List[DialogueNode],
|
dialogue_nodes: List[DialogueNode],
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
) -> Tuple[
|
) -> tuple[
|
||||||
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]],
|
list[DialogueNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[ChunkNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[StatementNode],
|
||||||
|
list[ExtractedEntityNode],
|
||||||
|
list[StatementChunkEdge],
|
||||||
|
list[StatementEntityEdge],
|
||||||
|
list[EntityEntityEdge],
|
||||||
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
执行两阶段去重并写入汇总
|
执行两阶段去重并写入汇总
|
||||||
@@ -1415,7 +1492,6 @@ class ExtractionOrchestrator:
|
|||||||
len(entity_entity_edges), len(final_entity_entity_edges)
|
len(entity_entity_edges), len(final_entity_entity_edges)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
||||||
try:
|
try:
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -1436,10 +1512,10 @@ class ExtractionOrchestrator:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _save_dedup_details(
|
def _save_dedup_details(
|
||||||
self,
|
self,
|
||||||
dedup_details: Dict[str, Any],
|
dedup_details: Dict[str, Any],
|
||||||
original_entities: List[ExtractedEntityNode],
|
original_entities: List[ExtractedEntityNode],
|
||||||
final_entities: List[ExtractedEntityNode]
|
final_entities: List[ExtractedEntityNode]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
保存去重消歧的详细记录到实例变量(基于内存数据结构)
|
保存去重消歧的详细记录到实例变量(基于内存数据结构)
|
||||||
@@ -1537,15 +1613,16 @@ class ExtractionOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
|
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
|
||||||
|
|
||||||
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
logger.info(
|
||||||
|
f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
|
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
|
||||||
|
|
||||||
async def _analyze_entity_merges(
|
async def _analyze_entity_merges(
|
||||||
self,
|
self,
|
||||||
original_entities: List[ExtractedEntityNode],
|
original_entities: List[ExtractedEntityNode],
|
||||||
final_entities: List[ExtractedEntityNode]
|
final_entities: List[ExtractedEntityNode]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
|
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
|
||||||
@@ -1585,9 +1662,9 @@ class ExtractionOrchestrator:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def _analyze_entity_disambiguation(
|
async def _analyze_entity_disambiguation(
|
||||||
self,
|
self,
|
||||||
original_entities: List[ExtractedEntityNode],
|
original_entities: List[ExtractedEntityNode],
|
||||||
final_entities: List[ExtractedEntityNode]
|
final_entities: List[ExtractedEntityNode]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
|
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
|
||||||
@@ -1645,9 +1722,9 @@ class ExtractionOrchestrator:
|
|||||||
return type_mapping.get(entity_type, f"{entity_type}实体节点")
|
return type_mapping.get(entity_type, f"{entity_type}实体节点")
|
||||||
|
|
||||||
async def _output_relationship_creation_results(
|
async def _output_relationship_creation_results(
|
||||||
self,
|
self,
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
entity_nodes: List[ExtractedEntityNode]
|
entity_nodes: List[ExtractedEntityNode]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
输出关系创建结果
|
输出关系创建结果
|
||||||
@@ -1681,13 +1758,13 @@ class ExtractionOrchestrator:
|
|||||||
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
|
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
|
||||||
|
|
||||||
async def _send_dedup_progress_callback(
|
async def _send_dedup_progress_callback(
|
||||||
self,
|
self,
|
||||||
original_entities: int,
|
original_entities: int,
|
||||||
final_entities: int,
|
final_entities: int,
|
||||||
original_stmt_edges: int,
|
original_stmt_edges: int,
|
||||||
final_stmt_edges: int,
|
final_stmt_edges: int,
|
||||||
original_ent_edges: int,
|
original_ent_edges: int,
|
||||||
final_ent_edges: int,
|
final_ent_edges: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
|
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
|
||||||
@@ -1715,7 +1792,8 @@ class ExtractionOrchestrator:
|
|||||||
"original_count": original_entities,
|
"original_count": original_entities,
|
||||||
"final_count": final_entities,
|
"final_count": final_entities,
|
||||||
"reduced_count": entities_reduced,
|
"reduced_count": entities_reduced,
|
||||||
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0,
|
"reduction_rate": round(entities_reduced / original_entities * 100,
|
||||||
|
1) if original_entities > 0 else 0,
|
||||||
},
|
},
|
||||||
"statement_entity_edges": {
|
"statement_entity_edges": {
|
||||||
"original_count": original_stmt_edges,
|
"original_count": original_stmt_edges,
|
||||||
@@ -1790,7 +1868,8 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
disamb_examples.append({
|
disamb_examples.append({
|
||||||
"entity1_name": entity_name,
|
"entity1_name": entity_name,
|
||||||
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知",
|
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:",
|
||||||
|
"").strip() if "vs" in disamb_type else "未知",
|
||||||
"entity2_name": entity_name,
|
"entity2_name": entity_name,
|
||||||
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
|
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
|
||||||
"description": f"{entity_name},消歧区分成功"
|
"description": f"{entity_name},消歧区分成功"
|
||||||
@@ -1815,9 +1894,9 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
|
|
||||||
async def get_chunked_dialogs(
|
async def get_chunked_dialogs(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
indices: Optional[List[int]] = None,
|
indices: Optional[List[int]] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""从测试数据生成分块对话
|
"""从测试数据生成分块对话
|
||||||
|
|
||||||
@@ -1924,10 +2003,10 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
|
|
||||||
def preprocess_data(
|
def preprocess_data(
|
||||||
input_path: Optional[str] = None,
|
input_path: Optional[str] = None,
|
||||||
output_path: Optional[str] = None,
|
output_path: Optional[str] = None,
|
||||||
skip_cleaning: bool = True,
|
skip_cleaning: bool = True,
|
||||||
indices: Optional[List[int]] = None
|
indices: Optional[List[int]] = None
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""数据预处理
|
"""数据预处理
|
||||||
|
|
||||||
@@ -1946,7 +2025,8 @@ def preprocess_data(
|
|||||||
)
|
)
|
||||||
preprocessor = DataPreprocessor()
|
preprocessor = DataPreprocessor()
|
||||||
try:
|
try:
|
||||||
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
|
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path,
|
||||||
|
skip_cleaning=skip_cleaning, indices=indices)
|
||||||
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
||||||
return cleaned_data
|
return cleaned_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1955,9 +2035,9 @@ def preprocess_data(
|
|||||||
|
|
||||||
|
|
||||||
async def get_chunked_dialogs_from_preprocessed(
|
async def get_chunked_dialogs_from_preprocessed(
|
||||||
data: List[DialogData],
|
data: List[DialogData],
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
llm_client: Optional[Any] = None,
|
llm_client: Optional[Any] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""从预处理后的数据中生成分块
|
"""从预处理后的数据中生成分块
|
||||||
|
|
||||||
@@ -1988,15 +2068,15 @@ async def get_chunked_dialogs_from_preprocessed(
|
|||||||
|
|
||||||
|
|
||||||
async def get_chunked_dialogs_with_preprocessing(
|
async def get_chunked_dialogs_with_preprocessing(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "default",
|
end_user_id: str = "default",
|
||||||
user_id: str = "default",
|
user_id: str = "default",
|
||||||
apply_id: str = "default",
|
apply_id: str = "default",
|
||||||
indices: Optional[List[int]] = None,
|
indices: Optional[List[int]] = None,
|
||||||
input_data_path: Optional[str] = None,
|
input_data_path: Optional[str] = None,
|
||||||
llm_client: Optional[Any] = None,
|
llm_client: Optional[Any] = None,
|
||||||
skip_cleaning: bool = True,
|
skip_cleaning: bool = True,
|
||||||
pruning_config: Optional[Dict] = None,
|
pruning_config: Optional[Dict] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""包含数据预处理步骤的完整分块流程
|
"""包含数据预处理步骤的完整分块流程
|
||||||
|
|
||||||
@@ -2046,7 +2126,8 @@ async def get_chunked_dialogs_with_preprocessing(
|
|||||||
if pruning_config:
|
if pruning_config:
|
||||||
# 使用传入的配置
|
# 使用传入的配置
|
||||||
config = PruningConfig(**pruning_config)
|
config = PruningConfig(**pruning_config)
|
||||||
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
logger.debug(
|
||||||
|
f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||||
else:
|
else:
|
||||||
# 使用默认配置(关闭剪枝)
|
# 使用默认配置(关闭剪枝)
|
||||||
config = None
|
config = None
|
||||||
|
|||||||
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
|
|||||||
response_model=MemorySummaryResponse,
|
response_model=MemorySummaryResponse,
|
||||||
)
|
)
|
||||||
summary_text = structured.summary.strip()
|
summary_text = structured.summary.strip()
|
||||||
|
|
||||||
# Generate title and type for the summary
|
# Generate title and type for the summary
|
||||||
title = None
|
title = None
|
||||||
episodic_type = None
|
episodic_type = None
|
||||||
|
|||||||
@@ -390,6 +390,8 @@ class GraphBuilder:
|
|||||||
for edge in self.edges:
|
for edge in self.edges:
|
||||||
source = edge.get("source")
|
source = edge.get("source")
|
||||||
target = edge.get("target")
|
target = edge.get("target")
|
||||||
|
if source not in self.reachable_nodes or target not in self.reachable_nodes:
|
||||||
|
continue
|
||||||
condition = edge.get("condition")
|
condition = edge.get("condition")
|
||||||
edge_type = edge.get("type")
|
edge_type = edge.get("type")
|
||||||
|
|
||||||
|
|||||||
@@ -12,14 +12,26 @@ class ExecutionContext(BaseModel):
|
|||||||
execution_id: str
|
execution_id: str
|
||||||
workspace_id: str
|
workspace_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
|
memory_storage_type: str
|
||||||
|
user_rag_memory_id: str
|
||||||
checkpoint_config: RunnableConfig
|
checkpoint_config: RunnableConfig
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, execution_id: str, workspace_id: str, user_id: str):
|
def create(
|
||||||
|
cls,
|
||||||
|
execution_id: str,
|
||||||
|
workspace_id: str,
|
||||||
|
user_id: str,
|
||||||
|
memory_storage_type: str,
|
||||||
|
user_rag_memory_id: str
|
||||||
|
):
|
||||||
return cls(
|
return cls(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
memory_storage_type=memory_storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
|
||||||
checkpoint_config=RunnableConfig(
|
checkpoint_config=RunnableConfig(
|
||||||
configurable={
|
configurable={
|
||||||
"thread_id": uuid.uuid4(),
|
"thread_id": uuid.uuid4(),
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ class WorkflowState(dict):
|
|||||||
"workspace_id",
|
"workspace_id",
|
||||||
"user_id",
|
"user_id",
|
||||||
"activate",
|
"activate",
|
||||||
|
"memory_storage_type",
|
||||||
|
"user_rag_memory_id"
|
||||||
})
|
})
|
||||||
__optional_keys__ = frozenset({
|
__optional_keys__ = frozenset({
|
||||||
"error",
|
"error",
|
||||||
@@ -62,6 +64,9 @@ class WorkflowState(dict):
|
|||||||
# node activate status
|
# node activate status
|
||||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||||
|
|
||||||
|
memory_storage_type: str
|
||||||
|
user_rag_memory_id: str
|
||||||
|
|
||||||
|
|
||||||
class WorkflowStateManager:
|
class WorkflowStateManager:
|
||||||
def create_initial_state(
|
def create_initial_state(
|
||||||
@@ -85,7 +90,9 @@ class WorkflowStateManager:
|
|||||||
looping=0,
|
looping=0,
|
||||||
activate={
|
activate={
|
||||||
start_node_id: True
|
start_node_id: True
|
||||||
}
|
},
|
||||||
|
memory_storage_type=execution_context.memory_storage_type,
|
||||||
|
user_rag_memory_id=execution_context.user_rag_memory_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -373,6 +373,16 @@ class VariablePool:
|
|||||||
def copy(self, pool: 'VariablePool'):
|
def copy(self, pool: 'VariablePool'):
|
||||||
self.variables = deepcopy(pool.variables)
|
self.variables = deepcopy(pool.variables)
|
||||||
|
|
||||||
|
def is_file_variable(self, selector):
|
||||||
|
variable_struct = self.get_instance(selector, default=None, strict=False)
|
||||||
|
if variable_struct is None:
|
||||||
|
return False
|
||||||
|
if isinstance(variable_struct, FileVariable):
|
||||||
|
return True
|
||||||
|
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""导出为字典
|
"""导出为字典
|
||||||
|
|
||||||
|
|||||||
@@ -409,7 +409,9 @@ async def execute_workflow(
|
|||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str,
|
||||||
|
memory_storage_type: str,
|
||||||
|
user_rag_memory_id: str
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Execute a workflow (convenience function, non-streaming).
|
Execute a workflow (convenience function, non-streaming).
|
||||||
@@ -420,6 +422,8 @@ async def execute_workflow(
|
|||||||
execution_id (str): Execution ID.
|
execution_id (str): Execution ID.
|
||||||
workspace_id (str): Workspace ID.
|
workspace_id (str): Workspace ID.
|
||||||
user_id (str): User ID.
|
user_id (str): User ID.
|
||||||
|
user_rag_memory_id: rag knowledge db id
|
||||||
|
memory_storage_type: neo4j / rag
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Workflow execution result.
|
dict: Workflow execution result.
|
||||||
@@ -427,7 +431,9 @@ async def execute_workflow(
|
|||||||
execution_context = ExecutionContext.create(
|
execution_context = ExecutionContext.create(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
|
memory_storage_type=memory_storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
executor = WorkflowExecutor(
|
executor = WorkflowExecutor(
|
||||||
workflow_config=workflow_config,
|
workflow_config=workflow_config,
|
||||||
@@ -441,7 +447,9 @@ async def execute_workflow_stream(
|
|||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str,
|
||||||
|
memory_storage_type: str,
|
||||||
|
user_rag_memory_id: str
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Execute a workflow in streaming mode (convenience function).
|
Execute a workflow in streaming mode (convenience function).
|
||||||
@@ -452,6 +460,8 @@ async def execute_workflow_stream(
|
|||||||
execution_id (str): Execution ID.
|
execution_id (str): Execution ID.
|
||||||
workspace_id (str): Workspace ID.
|
workspace_id (str): Workspace ID.
|
||||||
user_id (str): User ID.
|
user_id (str): User ID.
|
||||||
|
user_rag_memory_id: rag knowledge db id
|
||||||
|
memory_storage_type: neo4j / rag
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
||||||
@@ -459,7 +469,9 @@ async def execute_workflow_stream(
|
|||||||
execution_context = ExecutionContext.create(
|
execution_context = ExecutionContext.create(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
|
memory_storage_type=memory_storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
executor = WorkflowExecutor(
|
executor = WorkflowExecutor(
|
||||||
workflow_config=workflow_config,
|
workflow_config=workflow_config,
|
||||||
|
|||||||
@@ -623,7 +623,6 @@ class BaseNode(ABC):
|
|||||||
async def process_message(
|
async def process_message(
|
||||||
api_config: ModelInfo,
|
api_config: ModelInfo,
|
||||||
content: str | dict | FileObject,
|
content: str | dict | FileObject,
|
||||||
end_user_id: str,
|
|
||||||
enable_file=False
|
enable_file=False
|
||||||
) -> list | str | None:
|
) -> list | str | None:
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
@@ -642,8 +641,8 @@ class BaseNode(ABC):
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
elif isinstance(content, FileObject):
|
elif isinstance(content, FileObject):
|
||||||
if content.content_cache.get(provider):
|
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
|
||||||
return content.content_cache[provider]
|
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||||
file_obj = FileInput(
|
file_obj = FileInput(
|
||||||
@@ -655,12 +654,11 @@ class BaseNode(ABC):
|
|||||||
)
|
)
|
||||||
file_obj.set_content(content.get_content())
|
file_obj.set_content(content.get_content())
|
||||||
message = await multimodel_service.process_files(
|
message = await multimodel_service.process_files(
|
||||||
end_user_id,
|
|
||||||
[file_obj],
|
[file_obj],
|
||||||
)
|
)
|
||||||
content.set_content(file_obj.get_content())
|
content.set_content(file_obj.get_content())
|
||||||
if message:
|
if message:
|
||||||
content.content_cache[provider] = message
|
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
|
||||||
return message
|
return message
|
||||||
return None
|
return None
|
||||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||||
|
|||||||
@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
|
|||||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||||
|
|
||||||
messages_config = self.typed_config.messages
|
messages_config = self.typed_config.messages
|
||||||
|
|
||||||
if messages_config:
|
if messages_config:
|
||||||
# 使用 LangChain 消息格式
|
# 使用 LangChain 消息格式
|
||||||
messages = []
|
messages = []
|
||||||
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
|
|||||||
content_template = msg_config.content
|
content_template = msg_config.content
|
||||||
content_template = self._render_context(content_template, variable_pool)
|
content_template = self._render_context(content_template, variable_pool)
|
||||||
content = self._render_template(content_template, variable_pool)
|
content = self._render_template(content_template, variable_pool)
|
||||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
|
|||||||
"content": await self.process_message(
|
"content": await self.process_message(
|
||||||
model_info,
|
model_info,
|
||||||
content,
|
content,
|
||||||
user_id,
|
|
||||||
self.typed_config.vision,
|
self.typed_config.vision,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
elif role in ["user", "human"]:
|
elif role in ["user", "human"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
elif role in ["ai", "assistant"]:
|
elif role in ["ai", "assistant"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
|
|
||||||
if self.typed_config.vision_input and self.typed_config.vision:
|
if self.typed_config.vision_input and self.typed_config.vision:
|
||||||
file_content = []
|
file_content = []
|
||||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||||
for file in files.value:
|
for file in files.value:
|
||||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
content = await self.process_message(model_info, file.value, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
if messages and messages[-1]["role"] == 'user':
|
if messages and messages[-1]["role"] == 'user':
|
||||||
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
|
|||||||
if isinstance(message["content"], list):
|
if isinstance(message["content"], list):
|
||||||
file_content = []
|
file_content = []
|
||||||
for file in message["content"]:
|
for file in message["content"]:
|
||||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
content = await self.process_message(model_info, file, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
history_message.append(
|
history_message.append(
|
||||||
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
|
|||||||
message["content"] = await self.process_message(
|
message["content"] = await self.process_message(
|
||||||
model_info,
|
model_info,
|
||||||
message["content"],
|
message["content"],
|
||||||
user_id,
|
|
||||||
self.typed_config.vision
|
self.typed_config.vision
|
||||||
)
|
)
|
||||||
history_message.append(message)
|
history_message.append(message)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
@@ -5,7 +6,9 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||||
from app.db import get_db_read
|
from app.db import get_db_read
|
||||||
|
from app.schemas import FileInput
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.tasks import write_message_task
|
from app.tasks import write_message_task
|
||||||
|
|
||||||
@@ -36,8 +39,8 @@ class MemoryReadNode(BaseNode):
|
|||||||
search_switch=self.typed_config.search_switch,
|
search_switch=self.typed_config.search_switch,
|
||||||
history=[],
|
history=[],
|
||||||
db=db,
|
db=db,
|
||||||
storage_type="neo4j",
|
storage_type=state["memory_storage_type"],
|
||||||
user_rag_memory_id=""
|
user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -49,6 +52,19 @@ class MemoryWriteNode(BaseNode):
|
|||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {"output": VariableType.STRING}
|
return {"output": VariableType.STRING}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
|
||||||
|
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||||
|
variable_pattern = re.compile(variable_pattern_string)
|
||||||
|
variables = variable_pattern.findall(content)
|
||||||
|
file_variables = []
|
||||||
|
for variable in variables:
|
||||||
|
if variable_pool.is_file_variable(variable):
|
||||||
|
file_variables.append(variable)
|
||||||
|
for var in file_variables:
|
||||||
|
content = content.replace(var, "")
|
||||||
|
return file_variables, content
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||||
@@ -63,17 +79,42 @@ class MemoryWriteNode(BaseNode):
|
|||||||
})
|
})
|
||||||
|
|
||||||
for message in self.typed_config.messages:
|
for message in self.typed_config.messages:
|
||||||
|
file_variables, content = self._extract_multimodal_memory_variables(
|
||||||
|
message.content,
|
||||||
|
variable_pool
|
||||||
|
)
|
||||||
|
file_info = []
|
||||||
|
for var in file_variables:
|
||||||
|
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
|
||||||
|
if isinstance(instence, FileVariable):
|
||||||
|
file_info.append(FileInput(
|
||||||
|
type=instence.value.type,
|
||||||
|
transfer_method=instence.value.transfer_method,
|
||||||
|
upload_file_id=instence.value.file_id,
|
||||||
|
url=instence.value.url,
|
||||||
|
file_type=instence.value.origin_file_type
|
||||||
|
).model_dump())
|
||||||
|
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
|
||||||
|
for file_instence in instence.value:
|
||||||
|
file_info.append(FileInput(
|
||||||
|
type=file_instence.value.type,
|
||||||
|
transfer_method=file_instence.value.transfer_method,
|
||||||
|
upload_file_id=file_instence.value.file_id,
|
||||||
|
url=file_instence.value.url,
|
||||||
|
file_type=file_instence.value.origin_file_type
|
||||||
|
).model_dump())
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": message.role,
|
"role": message.role,
|
||||||
"content": self._render_template(message.content, variable_pool)
|
"content": self._render_template(content, variable_pool),
|
||||||
|
"files": file_info
|
||||||
})
|
})
|
||||||
|
|
||||||
write_message_task.delay(
|
write_message_task.delay(
|
||||||
end_user_id,
|
end_user_id=end_user_id,
|
||||||
messages,
|
message=messages,
|
||||||
str(self.typed_config.config_id),
|
config_id=str(self.typed_config.config_id),
|
||||||
"neo4j",
|
storage_type=state["memory_storage_type"],
|
||||||
""
|
user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
return "success"
|
return "success"
|
||||||
|
|||||||
@@ -30,6 +30,9 @@ class MemoryConfig(Base):
|
|||||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||||
|
vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
|
||||||
|
audio_id = Column(String, nullable=True, comment="语音模型配置ID")
|
||||||
|
video_id = Column(String, nullable=True, comment="视频模型配置ID")
|
||||||
|
|
||||||
# 记忆萃取引擎配置
|
# 记忆萃取引擎配置
|
||||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ import datetime
|
|||||||
import uuid
|
import uuid
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
|
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,21 +9,22 @@ Classes:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from uuid import UUID
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import desc, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_config_logger, get_db_logger
|
from app.core.logging_config import get_config_logger, get_db_logger
|
||||||
from app.models.memory_config_model import MemoryConfig
|
from app.models.memory_config_model import MemoryConfig
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
from app.schemas.memory_storage_schema import (
|
from app.schemas.memory_storage_schema import (
|
||||||
ConfigKey,
|
|
||||||
ConfigParamsCreate,
|
ConfigParamsCreate,
|
||||||
ConfigUpdate,
|
ConfigUpdate,
|
||||||
ConfigUpdateExtracted,
|
ConfigUpdateExtracted,
|
||||||
ConfigUpdateForget,
|
ConfigUpdateForget,
|
||||||
)
|
)
|
||||||
from sqlalchemy import desc, select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# 获取数据库专用日志器
|
# 获取数据库专用日志器
|
||||||
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
|
|||||||
return memory_config_obj
|
return memory_config_obj
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
|
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
|
||||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -309,57 +310,21 @@ class MemoryConfigRepository:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None
|
Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 没有字段需要更新时抛出
|
|
||||||
"""
|
"""
|
||||||
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
|
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
|
stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id)
|
||||||
|
db_config = db.execute(stmt).scalar_one_or_none()
|
||||||
if not db_config:
|
if not db_config:
|
||||||
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 更新字段映射
|
update_data = update.model_dump(exclude_unset=True)
|
||||||
field_mapping = {
|
update_data.pop("config_id", None)
|
||||||
# 模型选择
|
|
||||||
"llm_id": "llm_id",
|
|
||||||
"embedding_id": "embedding_id",
|
|
||||||
"rerank_id": "rerank_id",
|
|
||||||
# 记忆萃取引擎
|
|
||||||
"enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise",
|
|
||||||
"enable_llm_disambiguation": "enable_llm_disambiguation",
|
|
||||||
"deep_retrieval": "deep_retrieval",
|
|
||||||
"t_type_strict": "t_type_strict",
|
|
||||||
"t_name_strict": "t_name_strict",
|
|
||||||
"t_overall": "t_overall",
|
|
||||||
"state": "state",
|
|
||||||
"chunker_strategy": "chunker_strategy",
|
|
||||||
# 句子提取
|
|
||||||
"statement_granularity": "statement_granularity",
|
|
||||||
"include_dialogue_context": "include_dialogue_context",
|
|
||||||
"max_context": "max_context",
|
|
||||||
# 剪枝配置
|
|
||||||
"pruning_enabled": "pruning_enabled",
|
|
||||||
"pruning_scene": "pruning_scene",
|
|
||||||
"pruning_threshold": "pruning_threshold",
|
|
||||||
# 自我反思配置
|
|
||||||
"enable_self_reflexion": "enable_self_reflexion",
|
|
||||||
"iteration_period": "iteration_period",
|
|
||||||
"reflexion_range": "reflexion_range",
|
|
||||||
"baseline": "baseline",
|
|
||||||
}
|
|
||||||
|
|
||||||
has_update = False
|
for field, value in update_data.items():
|
||||||
for api_field, db_field in field_mapping.items():
|
setattr(db_config, field, value)
|
||||||
value = getattr(update, api_field, None)
|
|
||||||
if value is not None:
|
|
||||||
setattr(db_config, db_field, value)
|
|
||||||
has_update = True
|
|
||||||
|
|
||||||
if not has_update:
|
|
||||||
raise ValueError("No fields to update")
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_config)
|
db.refresh(db_config)
|
||||||
@@ -443,6 +408,9 @@ class MemoryConfigRepository:
|
|||||||
"llm_id": db_config.llm_id,
|
"llm_id": db_config.llm_id,
|
||||||
"embedding_id": db_config.embedding_id,
|
"embedding_id": db_config.embedding_id,
|
||||||
"rerank_id": db_config.rerank_id,
|
"rerank_id": db_config.rerank_id,
|
||||||
|
"vision_id": db_config.vision_id,
|
||||||
|
"audio_id": db_config.audio_id,
|
||||||
|
"video_id": db_config.video_id,
|
||||||
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
|
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
|
||||||
"enable_llm_disambiguation": db_config.enable_llm_disambiguation,
|
"enable_llm_disambiguation": db_config.enable_llm_disambiguation,
|
||||||
"deep_retrieval": db_config.deep_retrieval,
|
"deep_retrieval": db_config.deep_retrieval,
|
||||||
@@ -527,7 +495,10 @@ class MemoryConfigRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
|
def get_config_with_workspace(
|
||||||
|
db: Session,
|
||||||
|
config_id: uuid.UUID | int | str
|
||||||
|
) -> Optional[tuple[MemoryConfig, Workspace]]:
|
||||||
"""Get memory config and its associated workspace information
|
"""Get memory config and its associated workspace information
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -542,8 +513,6 @@ class MemoryConfigRepository:
|
|||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.models.workspace_model import Workspace
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
|
||||||
@@ -630,7 +599,7 @@ class MemoryConfigRepository:
|
|||||||
|
|
||||||
db_logger.debug(
|
db_logger.debug(
|
||||||
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||||
return (config, workspace)
|
return config, workspace
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Re-raise known business exceptions
|
# Re-raise known business exceptions
|
||||||
@@ -775,9 +744,9 @@ class MemoryConfigRepository:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_with_fallback(
|
def get_with_fallback(
|
||||||
db: Session,
|
db: Session,
|
||||||
config_id: Optional[uuid.UUID],
|
config_id: Optional[uuid.UUID],
|
||||||
workspace_id: uuid.UUID
|
workspace_id: uuid.UUID
|
||||||
) -> Optional[MemoryConfig]:
|
) -> Optional[MemoryConfig]:
|
||||||
"""获取记忆配置,支持回退到工作空间默认配置
|
"""获取记忆配置,支持回退到工作空间默认配置
|
||||||
|
|
||||||
@@ -807,4 +776,3 @@ class MemoryConfigRepository:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
|
||||||
from sqlalchemy import and_, or_, func, desc, select
|
|
||||||
from typing import List, Optional, Dict, Any, Tuple
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
|
|
||||||
|
from sqlalchemy import and_, or_, func, desc
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.core.logging_config import get_db_logger
|
||||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
|
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
|
||||||
from app.schemas.model_schema import (
|
from app.schemas.model_schema import (
|
||||||
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||||
ModelConfigQuery, ModelConfigQueryNew
|
ModelConfigQuery, ModelConfigQueryNew
|
||||||
)
|
)
|
||||||
from app.core.logging_config import get_db_logger
|
|
||||||
|
|
||||||
# 获取数据库专用日志器
|
# 获取数据库专用日志器
|
||||||
db_logger = get_db_logger()
|
db_logger = get_db_logger()
|
||||||
@@ -137,6 +138,9 @@ class ModelConfigRepository:
|
|||||||
type_values.append(ModelType.LLM)
|
type_values.append(ModelType.LLM)
|
||||||
filters.append(ModelConfig.type.in_(type_values))
|
filters.append(ModelConfig.type.in_(type_values))
|
||||||
|
|
||||||
|
if query.capability:
|
||||||
|
filters.append(ModelConfig.capability.contains(query.capability))
|
||||||
|
|
||||||
if query.is_active is not None:
|
if query.is_active is not None:
|
||||||
filters.append(ModelConfig.is_active == query.is_active)
|
filters.append(ModelConfig.is_active == query.is_active)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from typing import List, Optional
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
|
||||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||||
|
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
||||||
|
MEMORY_SUMMARY_NODE_SAVE
|
||||||
# 使用新的仓储层
|
# 使用新的仓储层
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
@@ -12,9 +13,10 @@ logger = logging.getLogger(__name__)
|
|||||||
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
||||||
"""Delete all nodes in the database."""
|
"""Delete all nodes in the database."""
|
||||||
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
|
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
|
||||||
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||||
"""Add dialogue nodes to Neo4j database.
|
"""Add dialogue nodes to Neo4j database.
|
||||||
|
|
||||||
@@ -26,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
|||||||
List of created node UUIDs or None if failed
|
List of created node UUIDs or None if failed
|
||||||
"""
|
"""
|
||||||
if not dialogues:
|
if not dialogues:
|
||||||
print("No dialogues to save")
|
logger.info("No dialogues to save")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -51,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
|||||||
)
|
)
|
||||||
|
|
||||||
created_uuids = [record["uuid"] for record in result]
|
created_uuids = [record["uuid"] for record in result]
|
||||||
print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
|
logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
|
||||||
return created_uuids
|
return created_uuids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating dialogue nodes: {e}")
|
logger.error(f"Error creating dialogue nodes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -70,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
|||||||
List of created node UUIDs or None if failed
|
List of created node UUIDs or None if failed
|
||||||
"""
|
"""
|
||||||
if not statements:
|
if not statements:
|
||||||
print("No statements to save")
|
logger.info("No statements to save")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -123,13 +125,14 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
|||||||
)
|
)
|
||||||
|
|
||||||
created_uuids = [record["uuid"] for record in result]
|
created_uuids = [record["uuid"] for record in result]
|
||||||
print(f"Successfully created {len(created_uuids)} statement nodes")
|
logger.info(f"Successfully created {len(created_uuids)} statement nodes")
|
||||||
return created_uuids
|
return created_uuids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating statement nodes: {e}")
|
logger.error(f"Error creating statement nodes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||||
"""Add chunk nodes to Neo4j in batch.
|
"""Add chunk nodes to Neo4j in batch.
|
||||||
|
|
||||||
@@ -141,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
|||||||
List of created chunk UUIDs or None if failed
|
List of created chunk UUIDs or None if failed
|
||||||
"""
|
"""
|
||||||
if not chunks:
|
if not chunks:
|
||||||
print("No chunk nodes to add")
|
logger.info("No chunk nodes to add")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -174,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
|||||||
)
|
)
|
||||||
|
|
||||||
created_uuids = [record["uuid"] for record in result]
|
created_uuids = [record["uuid"] for record in result]
|
||||||
print(f"Successfully created {len(created_uuids)} chunk nodes")
|
logger.info(f"Successfully created {len(created_uuids)} chunk nodes")
|
||||||
return created_uuids
|
return created_uuids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating chunk nodes: {e}")
|
logger.error(f"Error creating chunk nodes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def add_memory_summary_nodes(
|
||||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
summaries: List[MemorySummaryNode],
|
||||||
|
connector: Neo4jConnector
|
||||||
|
) -> Optional[List[str]]:
|
||||||
"""Add memory summary nodes to Neo4j in batch.
|
"""Add memory summary nodes to Neo4j in batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -194,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
List of created summary node ids or None if failed
|
List of created summary node ids or None if failed
|
||||||
"""
|
"""
|
||||||
if not summaries:
|
if not summaries:
|
||||||
print("No memory summary nodes to add")
|
logger.info("No memory summary nodes to add")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -225,5 +230,3 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id,
|
|||||||
RETURN elementId(r) AS uuid
|
RETURN elementId(r) AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Entity Merge Query
|
# Entity Merge Query
|
||||||
MERGE_ENTITIES = """
|
MERGE_ENTITIES = """
|
||||||
MATCH (canonical:ExtractedEntity {id: $canonical_id})
|
MATCH (canonical:ExtractedEntity {id: $canonical_id})
|
||||||
@@ -829,9 +828,8 @@ neo4j_query_all = """
|
|||||||
other as entity2
|
other as entity2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
'''针对当前节点下扩长的句子,实体和总结'''
|
'''针对当前节点下扩长的句子,实体和总结'''
|
||||||
Memory_Timeline_ExtractedEntity="""
|
Memory_Timeline_ExtractedEntity = """
|
||||||
MATCH (n)-[r1]-(e)-[r2]-(ms)
|
MATCH (n)-[r1]-(e)-[r2]-(ms)
|
||||||
WHERE elementId(n) = $id
|
WHERE elementId(n) = $id
|
||||||
AND (ms:ExtractedEntity OR ms:MemorySummary)
|
AND (ms:ExtractedEntity OR ms:MemorySummary)
|
||||||
@@ -869,7 +867,7 @@ RETURN
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Memory_Timeline_MemorySummary="""
|
Memory_Timeline_MemorySummary = """
|
||||||
MATCH (n)-[r1]-(e)-[r2]-(ms)
|
MATCH (n)-[r1]-(e)-[r2]-(ms)
|
||||||
WHERE elementId(n) =$id
|
WHERE elementId(n) =$id
|
||||||
AND (ms:MemorySummary OR ms:ExtractedEntity)
|
AND (ms:MemorySummary OR ms:ExtractedEntity)
|
||||||
@@ -904,7 +902,7 @@ RETURN
|
|||||||
}
|
}
|
||||||
) AS statement;
|
) AS statement;
|
||||||
"""
|
"""
|
||||||
Memory_Timeline_Statement="""
|
Memory_Timeline_Statement = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WHERE elementId(n) = $id
|
WHERE elementId(n) = $id
|
||||||
|
|
||||||
@@ -947,7 +945,7 @@ RETURN
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
'''针对当前节点,主要获取更加完整的句子节点'''
|
'''针对当前节点,主要获取更加完整的句子节点'''
|
||||||
Memory_Space_Emotion_Statement="""
|
Memory_Space_Emotion_Statement = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WHERE elementId(n) = $id
|
WHERE elementId(n) = $id
|
||||||
RETURN
|
RETURN
|
||||||
@@ -957,7 +955,7 @@ RETURN
|
|||||||
n.statement AS statement;
|
n.statement AS statement;
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Memory_Space_Emotion_MemorySummary="""
|
Memory_Space_Emotion_MemorySummary = """
|
||||||
MATCH (n)-[]-(e)
|
MATCH (n)-[]-(e)
|
||||||
WHERE elementId(n) = $id
|
WHERE elementId(n) = $id
|
||||||
AND EXISTS {
|
AND EXISTS {
|
||||||
@@ -970,7 +968,7 @@ RETURN DISTINCT
|
|||||||
e.emotion_type AS emotion_type,
|
e.emotion_type AS emotion_type,
|
||||||
e.statement AS statement;
|
e.statement AS statement;
|
||||||
"""
|
"""
|
||||||
Memory_Space_Emotion_ExtractedEntity="""
|
Memory_Space_Emotion_ExtractedEntity = """
|
||||||
MATCH (n)-[]-(e)
|
MATCH (n)-[]-(e)
|
||||||
WHERE elementId(n) = $id
|
WHERE elementId(n) = $id
|
||||||
AND EXISTS {
|
AND EXISTS {
|
||||||
@@ -985,18 +983,18 @@ RETURN DISTINCT
|
|||||||
|
|
||||||
'''获取实体'''
|
'''获取实体'''
|
||||||
|
|
||||||
Memory_Space_User="""
|
Memory_Space_User = """
|
||||||
MATCH (n)-[r]->(m)
|
MATCH (n)-[r]->(m)
|
||||||
WHERE n.end_user_id = $end_user_id AND m.name="用户"
|
WHERE n.end_user_id = $end_user_id AND m.name="用户"
|
||||||
return DISTINCT elementId(m) as id
|
return DISTINCT elementId(m) as id
|
||||||
"""
|
"""
|
||||||
Memory_Space_Entity="""
|
Memory_Space_Entity = """
|
||||||
MATCH (n)-[]-(m)
|
MATCH (n)-[]-(m)
|
||||||
WHERE elementId(m) = $id AND m.entity_type = "Person"
|
WHERE elementId(m) = $id AND m.entity_type = "Person"
|
||||||
RETURN
|
RETURN
|
||||||
DISTINCT m.name as name,m.end_user_id as end_user_id
|
DISTINCT m.name as name,m.end_user_id as end_user_id
|
||||||
"""
|
"""
|
||||||
Memory_Space_Associative="""
|
Memory_Space_Associative = """
|
||||||
MATCH (u)-[]-(x)-[]-(h)
|
MATCH (u)-[]-(x)-[]-(h)
|
||||||
WHERE elementId(u) = $user_id
|
WHERE elementId(u) = $user_id
|
||||||
AND elementId(h) = $id
|
AND elementId(h) = $id
|
||||||
@@ -1005,61 +1003,69 @@ RETURN DISTINCT
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
Graph_Node_query = """
|
Graph_Node_query = """
|
||||||
MATCH (n:MemorySummary)
|
MATCH (n:MemorySummary)
|
||||||
WHERE n.end_user_id = $end_user_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
RETURN
|
RETURN
|
||||||
elementId(n) AS id,
|
elementId(n) AS id,
|
||||||
labels(n) AS labels,
|
labels(n) AS labels,
|
||||||
properties(n) AS properties,
|
properties(n) AS properties,
|
||||||
0 AS priority
|
0 AS priority
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
|
|
||||||
UNION ALL
|
UNION ALL
|
||||||
|
|
||||||
MATCH (n:Dialogue)
|
MATCH (n:Dialogue)
|
||||||
WHERE n.end_user_id = $end_user_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
RETURN
|
RETURN
|
||||||
elementId(n) AS id,
|
elementId(n) AS id,
|
||||||
labels(n) AS labels,
|
labels(n) AS labels,
|
||||||
properties(n) AS properties,
|
properties(n) AS properties,
|
||||||
1 AS priority
|
1 AS priority
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
|
|
||||||
UNION ALL
|
UNION ALL
|
||||||
|
|
||||||
MATCH (n:Statement)
|
MATCH (n:Statement)
|
||||||
WHERE n.end_user_id = $end_user_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
RETURN
|
RETURN
|
||||||
elementId(n) AS id,
|
elementId(n) AS id,
|
||||||
labels(n) AS labels,
|
labels(n) AS labels,
|
||||||
properties(n) AS properties,
|
properties(n) AS properties,
|
||||||
1 AS priority
|
1 AS priority
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
|
|
||||||
UNION ALL
|
UNION ALL
|
||||||
|
|
||||||
MATCH (n:ExtractedEntity)
|
MATCH (n:ExtractedEntity)
|
||||||
WHERE n.end_user_id = $end_user_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
RETURN
|
RETURN
|
||||||
elementId(n) AS id,
|
elementId(n) AS id,
|
||||||
labels(n) AS labels,
|
labels(n) AS labels,
|
||||||
properties(n) AS properties,
|
properties(n) AS properties,
|
||||||
2 AS priority
|
2 AS priority
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
|
|
||||||
UNION ALL
|
UNION ALL
|
||||||
|
|
||||||
MATCH (n:Chunk)
|
MATCH (n:Chunk)
|
||||||
WHERE n.end_user_id = $end_user_id
|
WHERE n.end_user_id = $end_user_id
|
||||||
RETURN
|
RETURN
|
||||||
elementId(n) AS id,
|
elementId(n) AS id,
|
||||||
labels(n) AS labels,
|
labels(n) AS labels,
|
||||||
properties(n) AS properties,
|
properties(n) AS properties,
|
||||||
3 AS priority
|
3 AS priority
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
|
|
||||||
"""
|
UNION ALL
|
||||||
|
MATCH (n:Perceptual)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
4 AS priority
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# Community 节点 & BELONGS_TO_COMMUNITY 边
|
# Community 节点 & BELONGS_TO_COMMUNITY 边
|
||||||
@@ -1363,3 +1369,36 @@ RETURN s.statement AS statement,
|
|||||||
ORDER BY COALESCE(s.activation_value, 0) DESC
|
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 感知记忆节点保存
|
||||||
|
PERCEPTUAL_NODE_SAVE = """
|
||||||
|
UNWIND $perceptuals AS p
|
||||||
|
MERGE (n:Perceptual {id: p.id})
|
||||||
|
SET n += {
|
||||||
|
id: p.id,
|
||||||
|
end_user_id: p.end_user_id,
|
||||||
|
perceptual_type: p.perceptual_type,
|
||||||
|
file_path: p.file_path,
|
||||||
|
file_name: p.file_name,
|
||||||
|
file_ext: p.file_ext,
|
||||||
|
summary: p.summary,
|
||||||
|
keywords: p.keywords,
|
||||||
|
topic: p.topic,
|
||||||
|
domain: p.domain,
|
||||||
|
created_at: p.created_at,
|
||||||
|
file_type: p.file_type,
|
||||||
|
summary_embedding: p.summary_embedding
|
||||||
|
}
|
||||||
|
RETURN n.id AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 感知记忆与对话的关联边
|
||||||
|
PERCEPTUAL_CHUNK_EDGE_SAVE = """
|
||||||
|
UNWIND $edges AS edge
|
||||||
|
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
|
||||||
|
MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id})
|
||||||
|
MERGE (c)-[r:HAS_PERCEPTUAL]->(p)
|
||||||
|
ON CREATE SET r.end_user_id = edge.end_user_id,
|
||||||
|
r.created_at = edge.created_at
|
||||||
|
RETURN elementId(r) AS uuid
|
||||||
|
"""
|
||||||
|
|||||||
@@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import (
|
|||||||
StatementNode,
|
StatementNode,
|
||||||
ExtractedEntityNode,
|
ExtractedEntityNode,
|
||||||
EntityEntityEdge,
|
EntityEntityEdge,
|
||||||
|
PerceptualNode,
|
||||||
|
PerceptualEdge,
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def save_entities_and_relationships(
|
async def save_entities_and_relationships(
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
):
|
):
|
||||||
"""Save entities and their relationships using graph models"""
|
"""Save entities and their relationships using graph models"""
|
||||||
all_entities = [entity.model_dump() for entity in entity_nodes]
|
all_entities = [entity.model_dump() for entity in entity_nodes]
|
||||||
@@ -73,8 +78,8 @@ async def save_entities_and_relationships(
|
|||||||
|
|
||||||
|
|
||||||
async def save_chunk_nodes(
|
async def save_chunk_nodes(
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
):
|
):
|
||||||
"""Save chunk nodes using graph models"""
|
"""Save chunk nodes using graph models"""
|
||||||
if not chunk_nodes:
|
if not chunk_nodes:
|
||||||
@@ -89,8 +94,8 @@ async def save_chunk_nodes(
|
|||||||
|
|
||||||
|
|
||||||
async def save_statement_chunk_edges(
|
async def save_statement_chunk_edges(
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
):
|
):
|
||||||
"""Save statement-chunk edges using graph models"""
|
"""Save statement-chunk edges using graph models"""
|
||||||
if not statement_chunk_edges:
|
if not statement_chunk_edges:
|
||||||
@@ -118,8 +123,8 @@ async def save_statement_chunk_edges(
|
|||||||
|
|
||||||
|
|
||||||
async def save_statement_entity_edges(
|
async def save_statement_entity_edges(
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
):
|
):
|
||||||
"""Save statement-entity edges using graph models"""
|
"""Save statement-entity edges using graph models"""
|
||||||
if not statement_entity_edges:
|
if not statement_entity_edges:
|
||||||
@@ -154,9 +159,11 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
|
perceptual_nodes: List[PerceptualNode],
|
||||||
entity_edges: List[EntityEntityEdge],
|
entity_edges: List[EntityEntityEdge],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
|
perceptual_edges: List[PerceptualEdge],
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||||
@@ -169,9 +176,11 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
chunk_nodes: List of ChunkNode objects to save
|
chunk_nodes: List of ChunkNode objects to save
|
||||||
statement_nodes: List of StatementNode objects to save
|
statement_nodes: List of StatementNode objects to save
|
||||||
entity_nodes: List of ExtractedEntityNode objects to save
|
entity_nodes: List of ExtractedEntityNode objects to save
|
||||||
|
perceptual_nodes: List of PerceptualNode objects to save
|
||||||
entity_edges: List of EntityEntityEdge objects to save
|
entity_edges: List of EntityEntityEdge objects to save
|
||||||
statement_chunk_edges: List of StatementChunkEdge objects to save
|
statement_chunk_edges: List of StatementChunkEdge objects to save
|
||||||
statement_entity_edges: List of StatementEntityEdge objects to save
|
statement_entity_edges: List of StatementEntityEdge objects to save
|
||||||
|
perceptual_edges: List of PerceptualEdge objects to save
|
||||||
connector: Neo4j connector instance
|
connector: Neo4j connector instance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
||||||
dialogue_uuids = [record["uuid"] async for record in result]
|
dialogue_uuids = [record["uuid"] async for record in result]
|
||||||
results['dialogues'] = dialogue_uuids
|
results['dialogues'] = dialogue_uuids
|
||||||
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||||
|
|
||||||
# 2. Save all chunk nodes in batch
|
# 2. Save all chunk nodes in batch
|
||||||
if chunk_nodes:
|
if chunk_nodes:
|
||||||
@@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
results['chunks'] = chunk_uuids
|
results['chunks'] = chunk_uuids
|
||||||
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
||||||
|
|
||||||
|
if perceptual_nodes:
|
||||||
|
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE
|
||||||
|
perceptual_data = [node.model_dump() for node in perceptual_nodes]
|
||||||
|
result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data)
|
||||||
|
perceptual_uuids = [record["uuid"] async for record in result]
|
||||||
|
results["perceptuals"] = perceptual_uuids
|
||||||
|
logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j")
|
||||||
|
|
||||||
# 3. Save all statement nodes in batch
|
# 3. Save all statement nodes in batch
|
||||||
if statement_nodes:
|
if statement_nodes:
|
||||||
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
||||||
@@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
results['statement_entity_edges'] = se_uuids
|
results['statement_entity_edges'] = se_uuids
|
||||||
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
||||||
|
|
||||||
|
if perceptual_edges:
|
||||||
|
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE
|
||||||
|
perceptual_edge_data = []
|
||||||
|
for edge in perceptual_edges:
|
||||||
|
print(edge.source, edge.target)
|
||||||
|
perceptual_edge_data.append({
|
||||||
|
"perceptual_id": edge.source,
|
||||||
|
"chunk_id": edge.target,
|
||||||
|
"end_user_id": edge.end_user_id,
|
||||||
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
|
})
|
||||||
|
result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data)
|
||||||
|
perceptual_edges_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['perceptual_chunk_edges'] = perceptual_edges_uuids
|
||||||
|
logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -304,9 +337,9 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
|
|
||||||
|
|
||||||
async def _trigger_clustering_sync(
|
async def _trigger_clustering_sync(
|
||||||
entity_nodes: List,
|
entity_nodes: List,
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
embedding_model_id: Optional[str] = None,
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
同步等待聚类完成,避免与其他 LLM 任务并发冲突。
|
同步等待聚类完成,避免与其他 LLM 任务并发冲突。
|
||||||
@@ -322,14 +355,15 @@ async def _trigger_clustering_sync(
|
|||||||
end_user_id = entity_nodes[0].end_user_id
|
end_user_id = entity_nodes[0].end_user_id
|
||||||
new_entity_ids = [e.id for e in entity_nodes]
|
new_entity_ids = [e.id for e in entity_nodes]
|
||||||
logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||||
await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
|
await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
|
||||||
|
embedding_model_id=embedding_model_id)
|
||||||
|
|
||||||
|
|
||||||
async def _trigger_clustering(
|
async def _trigger_clustering(
|
||||||
new_entity_ids: List[str],
|
new_entity_ids: List[str],
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
embedding_model_id: Optional[str] = None,
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
聚类触发函数,自动判断全量初始化还是增量更新。
|
聚类触发函数,自动判断全量初始化还是增量更新。
|
||||||
|
|||||||
@@ -387,6 +387,12 @@ class MemoryConfig:
|
|||||||
|
|
||||||
rerank_model_id: Optional[UUID] = None
|
rerank_model_id: Optional[UUID] = None
|
||||||
rerank_model_name: Optional[str] = None
|
rerank_model_name: Optional[str] = None
|
||||||
|
video_model_id: Optional[UUID] = None
|
||||||
|
video_model_name: Optional[str] = None
|
||||||
|
vision_model_id: Optional[UUID] = None
|
||||||
|
vision_model_name: Optional[str] = None
|
||||||
|
audio_model_id: Optional[UUID] = None
|
||||||
|
audio_model_name: Optional[str] = None
|
||||||
|
|
||||||
llm_params: Dict[str, Any] = field(default_factory=dict)
|
llm_params: Dict[str, Any] = field(default_factory=dict)
|
||||||
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|||||||
@@ -8,9 +8,6 @@ import uuid
|
|||||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
|
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# 从 json_schema.py 迁移的 Schema
|
# 从 json_schema.py 迁移的 Schema
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel):
|
|||||||
|
|
||||||
class ConflictResultSchema(BaseModel):
|
class ConflictResultSchema(BaseModel):
|
||||||
"""Schema for the conflict result data in the reflexion_data.json file."""
|
"""Schema for the conflict result data in the reflexion_data.json file."""
|
||||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.")
|
data: List[BaseDataSchema] = Field(...,
|
||||||
|
description="The conflict memory data. Only contains conflicting records when conflict is True.")
|
||||||
conflict: bool = Field(..., description="Whether the memory is in conflict.")
|
conflict: bool = Field(..., description="Whether the memory is in conflict.")
|
||||||
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
|
quality_assessment: Optional[QualityAssessmentSchema] = Field(None,
|
||||||
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
|
description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
|
||||||
|
memory_verify: Optional[MemoryVerifySchema] = Field(None,
|
||||||
|
description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
def _normalize_data(cls, v):
|
def _normalize_data(cls, v):
|
||||||
@@ -105,12 +105,15 @@ class ChangeRecordSchema(BaseModel):
|
|||||||
description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
|
description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResolvedSchema(BaseModel):
|
class ResolvedSchema(BaseModel):
|
||||||
"""Schema for the resolved memory data in the reflexion_data"""
|
"""Schema for the resolved memory data in the reflexion_data"""
|
||||||
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
|
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
|
||||||
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
|
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
|
||||||
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
|
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None,
|
||||||
change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.")
|
description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
|
||||||
|
change: Optional[List[ChangeRecordSchema]] = Field(None,
|
||||||
|
description="List of detailed change records with IDs and field information.")
|
||||||
|
|
||||||
|
|
||||||
class SingleReflexionResultSchema(BaseModel):
|
class SingleReflexionResultSchema(BaseModel):
|
||||||
@@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel):
|
|||||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
|
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
|
||||||
type: str = Field("reflexion_result", description="The type identifier.")
|
type: str = Field("reflexion_result", description="The type identifier.")
|
||||||
|
|
||||||
|
|
||||||
class ReflexionResultSchema(BaseModel):
|
class ReflexionResultSchema(BaseModel):
|
||||||
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
||||||
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
|
results: List[SingleReflexionResultSchema] = Field(...,
|
||||||
|
description="List of individual conflict resolution results, grouped by conflict type.")
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
def _normalize_resolved(cls, v):
|
def _normalize_resolved(cls, v):
|
||||||
@@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel):
|
|||||||
# Composite key identifying a config row
|
# Composite key identifying a config row
|
||||||
class ConfigKey(BaseModel): # 配置参数键模型
|
class ConfigKey(BaseModel): # 配置参数键模型
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||||
user_id: str = Field("user_id", description="用户标识(字符串)")
|
user_id: str | None = Field(default=None, description="用户标识(字符串)")
|
||||||
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
|
apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)")
|
||||||
|
|
||||||
|
|
||||||
# Allowed chunking strategies (extendable later)
|
# Allowed chunking strategies (extendable later)
|
||||||
@@ -241,10 +246,12 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
|||||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||||
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
||||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
||||||
|
|
||||||
|
|
||||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
||||||
|
|
||||||
|
|
||||||
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||||
@@ -255,8 +262,11 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用
|
|||||||
|
|
||||||
|
|
||||||
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||||
config_id:Union[uuid.UUID, int, str] = None
|
config_id: Union[uuid.UUID, int, str] = None
|
||||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||||
|
audio_id: Optional[str] = Field(None, description="语音模型ID")
|
||||||
|
vision_id: Optional[str] = Field(None, description="视觉模型ID")
|
||||||
|
video_id: Optional[str] = Field(None, description="视频模型ID")
|
||||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||||
enable_llm_dedup_blockwise: Optional[bool] = None
|
enable_llm_dedup_blockwise: Optional[bool] = None
|
||||||
@@ -322,14 +332,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
|
|||||||
|
|
||||||
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
|
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
|
||||||
# 遗忘引擎配置参数更新模型
|
# 遗忘引擎配置参数更新模型
|
||||||
config_id:Union[uuid.UUID, int, str] = None
|
config_id: Union[uuid.UUID, int, str] = None
|
||||||
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
|
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
|
||||||
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
|
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
|
||||||
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
|
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
|
||||||
|
|
||||||
|
|
||||||
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
||||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
|
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
|
||||||
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
||||||
custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行")
|
custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行")
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
@@ -364,11 +374,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None)
|
|||||||
|
|
||||||
|
|
||||||
def fail(
|
def fail(
|
||||||
msg: str,
|
msg: str,
|
||||||
error_code: str = "ERROR",
|
error_code: str = "ERROR",
|
||||||
data: Optional[Any] = None,
|
data: Optional[Any] = None,
|
||||||
time: Optional[int] = None,
|
time: Optional[int] = None,
|
||||||
query_preview: Optional[str] = None,
|
query_preview: Optional[str] = None,
|
||||||
) -> ApiResponse:
|
) -> ApiResponse:
|
||||||
payload = data
|
payload = data
|
||||||
if query_preview is not None:
|
if query_preview is not None:
|
||||||
@@ -387,6 +397,7 @@ def fail(
|
|||||||
time=time or _now_ms(),
|
time=time or _now_ms(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GenerateCacheRequest(BaseModel):
|
class GenerateCacheRequest(BaseModel):
|
||||||
"""缓存生成请求模型"""
|
"""缓存生成请求模型"""
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
@@ -432,7 +443,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
|
|||||||
"""遗忘引擎配置更新请求模型"""
|
"""遗忘引擎配置更新请求模型"""
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
|
||||||
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)")
|
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||||
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
||||||
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
||||||
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
||||||
@@ -472,7 +483,8 @@ class ForgettingStatsResponse(BaseModel):
|
|||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
|
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
|
||||||
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
||||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
|
||||||
|
description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||||
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
|
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
|
||||||
timestamp: int = Field(..., description="统计时间(时间戳)")
|
timestamp: int = Field(..., description="统计时间(时间戳)")
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase):
|
|||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
api_keys: List["ModelApiKey"] = []
|
api_keys: List["ModelApiKey"] = []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str:
|
||||||
|
if not key or len(key) <= prefix + suffix:
|
||||||
|
return "*" * len(key)
|
||||||
|
return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[-suffix:]
|
||||||
|
|
||||||
@field_validator("api_keys", mode="after")
|
@field_validator("api_keys", mode="after")
|
||||||
@classmethod
|
@classmethod
|
||||||
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
|
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
|
||||||
@@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase):
|
|||||||
def _serialize_created_at(self, dt: datetime.datetime | None):
|
def _serialize_created_at(self, dt: datetime.datetime | None):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("api_keys", when_used="json")
|
||||||
|
def _serialize_api_keys(self, api_keys: List["ModelApiKey"]):
|
||||||
|
result = []
|
||||||
|
for api_key in api_keys:
|
||||||
|
data = api_key.model_dump()
|
||||||
|
data["api_key"] = self.mask_api_key(api_key.api_key)
|
||||||
|
result.append(data)
|
||||||
|
return result
|
||||||
|
|
||||||
@field_serializer("updated_at", when_used="json")
|
@field_serializer("updated_at", when_used="json")
|
||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
@@ -166,8 +181,8 @@ class ModelApiKey(ModelApiKeyBase):
|
|||||||
self.model_config_ids = [
|
self.model_config_ids = [
|
||||||
mc.id for mc in self.model_configs
|
mc.id for mc in self.model_configs
|
||||||
if hasattr(mc, 'id')
|
if hasattr(mc, 'id')
|
||||||
and not getattr(mc, 'is_composite', False)
|
and not getattr(mc, 'is_composite', False)
|
||||||
and getattr(mc, 'name', None) == self.model_name
|
and getattr(mc, 'name', None) == self.model_name
|
||||||
]
|
]
|
||||||
# 情况2:字典列表
|
# 情况2:字典列表
|
||||||
elif isinstance(self.model_configs, list):
|
elif isinstance(self.model_configs, list):
|
||||||
@@ -193,7 +208,6 @@ class ModelApiKey(ModelApiKeyBase):
|
|||||||
validate_assignment=True # 确保赋值触发校验
|
validate_assignment=True # 确保赋值触发校验
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
@@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel):
|
|||||||
"""模型配置查询Schema"""
|
"""模型配置查询Schema"""
|
||||||
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
||||||
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
|
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
|
||||||
|
capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)")
|
||||||
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
||||||
is_public: Optional[bool] = Field(None, description="公开状态筛选")
|
is_public: Optional[bool] = Field(None, description="公开状态筛选")
|
||||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||||
@@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel):
|
|||||||
is_composite: Optional[bool] = Field(None, description="组合模型筛选")
|
is_composite: Optional[bool] = Field(None, description="组合模型筛选")
|
||||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||||
|
|
||||||
|
|
||||||
class ModelMarketplace(BaseModel):
|
class ModelMarketplace(BaseModel):
|
||||||
"""模型广场响应Schema"""
|
"""模型广场响应Schema"""
|
||||||
llm_models: List[ModelConfig] = []
|
llm_models: List[ModelConfig] = []
|
||||||
@@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel):
|
|||||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
"""模型信息Schema"""
|
"""模型信息Schema"""
|
||||||
model_name: str = Field(..., description="模型名称")
|
model_name: str = Field(..., description="模型名称")
|
||||||
@@ -336,4 +353,3 @@ class ModelInfo(BaseModel):
|
|||||||
is_omni: bool = Field(default=False, description="是否为omni模型")
|
is_omni: bool = Field(default=False, description="是否为omni模型")
|
||||||
model_type: ModelType = Field(..., description="模型类型")
|
model_type: ModelType = Field(..., description="模型类型")
|
||||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||||
|
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class AppChatService:
|
|||||||
processed_files = None
|
processed_files = None
|
||||||
if files:
|
if files:
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 调用 Agent(支持多模态)
|
# 调用 Agent(支持多模态)
|
||||||
@@ -343,7 +343,7 @@ class AppChatService:
|
|||||||
processed_files = None
|
processed_files = None
|
||||||
if files:
|
if files:
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||||
|
|||||||
@@ -603,7 +603,7 @@ class AgentRunService:
|
|||||||
# 获取 provider 信息
|
# 获取 provider 信息
|
||||||
provider = api_key_config.get("provider", "openai")
|
provider = api_key_config.get("provider", "openai")
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||||
|
|
||||||
# 7. 知识库检索
|
# 7. 知识库检索
|
||||||
@@ -846,7 +846,7 @@ class AgentRunService:
|
|||||||
# 获取 provider 信息
|
# 获取 provider 信息
|
||||||
provider = api_key_config.get("provider", "openai")
|
provider = api_key_config.get("provider", "openai")
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||||
|
|
||||||
# 7. 知识库检索
|
# 7. 知识库检索
|
||||||
|
|||||||
@@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.cache import InterestMemoryCache
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging_config import get_config_logger, get_logger
|
from app.core.logging_config import get_config_logger, get_logger
|
||||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
|
||||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||||
from app.core.memory.agent.utils.messages_tools import (
|
from app.core.memory.agent.utils.messages_tools import (
|
||||||
merge_multiple_search_results,
|
merge_multiple_search_results,
|
||||||
reorder_output_results,
|
reorder_output_results,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
|
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
from app.schemas import FileInput
|
||||||
from app.schemas.memory_agent_schema import Write_UserInput
|
from app.schemas.memory_agent_schema import Write_UserInput
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_konwledges_server import (
|
from app.services.memory_konwledges_server import (
|
||||||
write_rag,
|
write_rag,
|
||||||
)
|
)
|
||||||
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||||
@@ -267,8 +270,16 @@ class MemoryAgentService:
|
|||||||
logger.info("Log streaming completed, cleaning up resources")
|
logger.info("Log streaming completed, cleaning up resources")
|
||||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||||
|
|
||||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
|
async def write_memory(
|
||||||
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
messages: list[dict],
|
||||||
|
config_id: Optional[uuid.UUID] | int,
|
||||||
|
db: Session,
|
||||||
|
storage_type: str,
|
||||||
|
user_rag_memory_id: str,
|
||||||
|
language: str = "zh"
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Process write operation with config_id
|
Process write operation with config_id
|
||||||
|
|
||||||
@@ -297,8 +308,8 @@ class MemoryAgentService:
|
|||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||||
if config_id is None and workspace_id is None:
|
if config_id is None and workspace_id is None:
|
||||||
raise ValueError(
|
raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
|
||||||
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
f"Please ensure the user has a connected memory configuration.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "No memory configuration found" in str(e):
|
if "No memory configuration found" in str(e):
|
||||||
raise # Re-raise our specific error
|
raise # Re-raise our specific error
|
||||||
@@ -334,45 +345,57 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
perceptual_serivce = MemoryPerceptualService(db)
|
||||||
|
for message in messages:
|
||||||
|
message["file_content"] = []
|
||||||
|
for file in message["files"]:
|
||||||
|
file_object = await perceptual_serivce.generate_perceptual_memory(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
memory_config=memory_config,
|
||||||
|
file=FileInput(**file)
|
||||||
|
)
|
||||||
|
if file_object is None:
|
||||||
|
continue
|
||||||
|
message["file_content"].append((file_object, file["type"]))
|
||||||
|
|
||||||
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
try:
|
try:
|
||||||
if storage_type == "rag":
|
if storage_type == "rag":
|
||||||
# For RAG storage, convert messages to single string
|
# For RAG storage, convert messages to single string
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||||
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
return "success"
|
||||||
return result
|
|
||||||
else:
|
else:
|
||||||
async with make_write_graph() as graph:
|
await write_neo4j(
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
end_user_id=end_user_id,
|
||||||
# Convert structured messages to LangChain messages
|
messages=messages,
|
||||||
langchain_messages = []
|
memory_config=memory_config,
|
||||||
for msg in messages:
|
ref_id='',
|
||||||
if msg['role'] == 'user':
|
language=language
|
||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
)
|
||||||
elif msg['role'] == 'assistant':
|
for lang in ["zh", "en"]:
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||||
# 初始状态 - 包含所有必要字段
|
end_user_id, lang
|
||||||
initial_state = {
|
)
|
||||||
"messages": langchain_messages,
|
if deleted:
|
||||||
"end_user_id": end_user_id,
|
logger.info(
|
||||||
"memory_config": memory_config,
|
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||||
"language": language
|
for message in messages:
|
||||||
|
message["file_content"] = [
|
||||||
|
perceptual[0].file_path for perceptual in message["file_content"]
|
||||||
|
]
|
||||||
|
return self.writer_messages_deal(
|
||||||
|
"success",
|
||||||
|
start_time,
|
||||||
|
end_user_id,
|
||||||
|
config_id,
|
||||||
|
message_text,
|
||||||
|
{
|
||||||
|
"status": "success",
|
||||||
|
"data": messages,
|
||||||
|
"config_id": memory_config.config_id,
|
||||||
|
"config_name": memory_config.config_name
|
||||||
}
|
}
|
||||||
|
)
|
||||||
# 获取节点更新信息
|
|
||||||
async for update_event in graph.astream(
|
|
||||||
initial_state,
|
|
||||||
stream_mode="updates",
|
|
||||||
config=config
|
|
||||||
):
|
|
||||||
for node_name, node_data in update_event.items():
|
|
||||||
if 'save_neo4j' == node_name:
|
|
||||||
massages = node_data
|
|
||||||
massagesstatus = massages.get('write_result')['status']
|
|
||||||
contents = massages.get('write_result')
|
|
||||||
# Convert messages back to string for logging
|
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
|
||||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
|
||||||
contents)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Ensure proper error handling and logging
|
# Ensure proper error handling and logging
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
error_msg = f"Write operation failed: {str(e)}"
|
||||||
|
|||||||
@@ -38,9 +38,9 @@ class MemoryAPIService:
|
|||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def validate_end_user(
|
def validate_end_user(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
workspace_id: uuid.UUID
|
workspace_id: uuid.UUID
|
||||||
) -> EndUser:
|
) -> EndUser:
|
||||||
"""Validate that end_user exists and belongs to the workspace.
|
"""Validate that end_user exists and belongs to the workspace.
|
||||||
|
|
||||||
@@ -125,13 +125,13 @@ class MemoryAPIService:
|
|||||||
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
||||||
|
|
||||||
async def write_memory(
|
async def write_memory(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Write memory with validation.
|
"""Write memory with validation.
|
||||||
|
|
||||||
@@ -171,7 +171,7 @@ class MemoryAPIService:
|
|||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
db=self.db,
|
db=self.db,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id or ""
|
user_rag_memory_id=user_rag_memory_id or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
||||||
@@ -206,14 +206,14 @@ class MemoryAPIService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
search_switch: str = "0",
|
search_switch: str = "0",
|
||||||
config_id: str = "",
|
config_id: str = "",
|
||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Read memory with validation.
|
"""Read memory with validation.
|
||||||
|
|
||||||
@@ -244,7 +244,6 @@ class MemoryAPIService:
|
|||||||
# Update end user's memory_config_id
|
# Update end user's memory_config_id
|
||||||
self._update_end_user_config(end_user_id, config_id)
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
# Delegate to MemoryAgentService
|
||||||
result = await MemoryAgentService().read_memory(
|
result = await MemoryAgentService().read_memory(
|
||||||
@@ -282,8 +281,8 @@ class MemoryAPIService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def list_memory_configs(
|
def list_memory_configs(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""List all memory configs for a workspace.
|
"""List all memory configs for a workspace.
|
||||||
|
|
||||||
|
|||||||
@@ -154,10 +154,10 @@ class MemoryConfigService:
|
|||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def load_memory_config(
|
def load_memory_config(
|
||||||
self,
|
self,
|
||||||
config_id: Optional[UUID] = None,
|
config_id: Optional[UUID] = None,
|
||||||
workspace_id: Optional[UUID] = None,
|
workspace_id: Optional[UUID] = None,
|
||||||
service_name: str = "MemoryConfigService",
|
service_name: str = "MemoryConfigService",
|
||||||
) -> MemoryConfig:
|
) -> MemoryConfig:
|
||||||
"""
|
"""
|
||||||
Load memory configuration from database with optional fallback.
|
Load memory configuration from database with optional fallback.
|
||||||
@@ -194,8 +194,8 @@ class MemoryConfigService:
|
|||||||
try:
|
try:
|
||||||
# Use get_config_with_fallback if workspace_id is provided
|
# Use get_config_with_fallback if workspace_id is provided
|
||||||
memory_config = None
|
memory_config = None
|
||||||
|
validated_config_id = None
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
validated_config_id = None
|
|
||||||
if config_id:
|
if config_id:
|
||||||
try:
|
try:
|
||||||
validated_config_id = _validate_config_id(config_id, self.db)
|
validated_config_id = _validate_config_id(config_id, self.db)
|
||||||
@@ -243,10 +243,10 @@ class MemoryConfigService:
|
|||||||
|
|
||||||
# Helper function to validate model with workspace fallback
|
# Helper function to validate model with workspace fallback
|
||||||
def _validate_model_with_fallback(
|
def _validate_model_with_fallback(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
workspace_default: str,
|
workspace_default: str,
|
||||||
required: bool = False
|
required: bool = False
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""Validate model ID, falling back to workspace default if invalid.
|
"""Validate model ID, falling back to workspace default if invalid.
|
||||||
|
|
||||||
@@ -343,6 +343,35 @@ class MemoryConfigService:
|
|||||||
if memory_config.rerank_id or workspace.rerank:
|
if memory_config.rerank_id or workspace.rerank:
|
||||||
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||||
|
|
||||||
|
vision_uuid, vision_name = validate_and_resolve_model_id(
|
||||||
|
memory_config.vision_id,
|
||||||
|
"llm",
|
||||||
|
self.db,
|
||||||
|
workspace.tenant_id,
|
||||||
|
required=False,
|
||||||
|
config_id=validated_config_id,
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_uuid, audio_name = validate_and_resolve_model_id(
|
||||||
|
memory_config.audio_id,
|
||||||
|
"llm",
|
||||||
|
self.db,
|
||||||
|
workspace.tenant_id,
|
||||||
|
required=False,
|
||||||
|
config_id=validated_config_id,
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
video_uuid, video_name = validate_and_resolve_model_id(
|
||||||
|
memory_config.video_id,
|
||||||
|
"llm",
|
||||||
|
self.db,
|
||||||
|
workspace.tenant_id,
|
||||||
|
required=False,
|
||||||
|
config_id=validated_config_id,
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
)
|
||||||
# Create immutable MemoryConfig object
|
# Create immutable MemoryConfig object
|
||||||
config = MemoryConfig(
|
config = MemoryConfig(
|
||||||
config_id=memory_config.config_id,
|
config_id=memory_config.config_id,
|
||||||
@@ -356,6 +385,12 @@ class MemoryConfigService:
|
|||||||
embedding_model_name=embedding_name,
|
embedding_model_name=embedding_name,
|
||||||
rerank_model_id=rerank_uuid,
|
rerank_model_id=rerank_uuid,
|
||||||
rerank_model_name=rerank_name,
|
rerank_model_name=rerank_name,
|
||||||
|
video_model_id=video_uuid,
|
||||||
|
video_model_name=video_name,
|
||||||
|
vision_model_id=vision_uuid,
|
||||||
|
vision_model_name=vision_name,
|
||||||
|
audio_model_id=audio_uuid,
|
||||||
|
audio_model_name=audio_name,
|
||||||
storage_type=workspace.storage_type or "neo4j",
|
storage_type=workspace.storage_type or "neo4j",
|
||||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||||
@@ -364,24 +399,31 @@ class MemoryConfigService:
|
|||||||
reflexion_baseline=memory_config.baseline or "Time",
|
reflexion_baseline=memory_config.baseline or "Time",
|
||||||
loaded_at=datetime.now(),
|
loaded_at=datetime.now(),
|
||||||
# Pipeline config: Deduplication
|
# Pipeline config: Deduplication
|
||||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
enable_llm_dedup_blockwise=bool(
|
||||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||||
|
enable_llm_disambiguation=bool(
|
||||||
|
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||||
# Pipeline config: Statement extraction
|
# Pipeline config: Statement extraction
|
||||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
statement_granularity=int(
|
||||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
include_dialogue_context=bool(
|
||||||
|
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||||
|
max_dialogue_context_chars=int(
|
||||||
|
memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||||
# Pipeline config: Forgetting engine
|
# Pipeline config: Forgetting engine
|
||||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||||
# Pipeline config: Pruning
|
# Pipeline config: Pruning
|
||||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
pruning_enabled=bool(
|
||||||
|
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||||
pruning_scene=memory_config.pruning_scene or "education",
|
pruning_scene=memory_config.pruning_scene or "education",
|
||||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
pruning_threshold=float(
|
||||||
|
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||||
# Ontology scene association
|
# Ontology scene association
|
||||||
scene_id=memory_config.scene_id,
|
scene_id=memory_config.scene_id,
|
||||||
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||||
@@ -598,8 +640,8 @@ class MemoryConfigService:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_workspace_default_config(
|
def get_workspace_default_config(
|
||||||
self,
|
self,
|
||||||
workspace_id: UUID
|
workspace_id: UUID
|
||||||
) -> Optional["MemoryConfigModel"]:
|
) -> Optional["MemoryConfigModel"]:
|
||||||
"""Get workspace default memory config.
|
"""Get workspace default memory config.
|
||||||
|
|
||||||
@@ -623,9 +665,9 @@ class MemoryConfigService:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
def get_config_with_fallback(
|
def get_config_with_fallback(
|
||||||
self,
|
self,
|
||||||
memory_config_id: Optional[UUID],
|
memory_config_id: Optional[UUID],
|
||||||
workspace_id: UUID
|
workspace_id: UUID
|
||||||
) -> Optional["MemoryConfigModel"]:
|
) -> Optional["MemoryConfigModel"]:
|
||||||
"""Get memory config with fallback to workspace default.
|
"""Get memory config with fallback to workspace default.
|
||||||
|
|
||||||
@@ -663,9 +705,9 @@ class MemoryConfigService:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
def delete_config(
|
def delete_config(
|
||||||
self,
|
self,
|
||||||
config_id: UUID | int,
|
config_id: UUID | int,
|
||||||
force: bool = False
|
force: bool = False
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Delete memory config with protection against in-use configs.
|
"""Delete memory config with protection against in-use configs.
|
||||||
|
|
||||||
@@ -800,9 +842,9 @@ class MemoryConfigService:
|
|||||||
# ==================== 记忆配置提取方法 ====================
|
# ==================== 记忆配置提取方法 ====================
|
||||||
|
|
||||||
def extract_memory_config_id(
|
def extract_memory_config_id(
|
||||||
self,
|
self,
|
||||||
app_type: str,
|
app_type: str,
|
||||||
config: dict
|
config: dict
|
||||||
) -> tuple[Optional[uuid.UUID], bool]:
|
) -> tuple[Optional[uuid.UUID], bool]:
|
||||||
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
||||||
|
|
||||||
@@ -828,8 +870,8 @@ class MemoryConfigService:
|
|||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
def _extract_memory_config_id_from_agent(
|
def _extract_memory_config_id_from_agent(
|
||||||
self,
|
self,
|
||||||
config: dict
|
config: dict
|
||||||
) -> tuple[Optional[uuid.UUID], bool]:
|
) -> tuple[Optional[uuid.UUID], bool]:
|
||||||
"""从 Agent 应用配置中提取 memory_config_id
|
"""从 Agent 应用配置中提取 memory_config_id
|
||||||
|
|
||||||
@@ -888,8 +930,8 @@ class MemoryConfigService:
|
|||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
def _extract_memory_config_id_from_workflow(
|
def _extract_memory_config_id_from_workflow(
|
||||||
self,
|
self,
|
||||||
config: dict
|
config: dict
|
||||||
) -> tuple[Optional[uuid.UUID], bool]:
|
) -> tuple[Optional[uuid.UUID], bool]:
|
||||||
"""从 Workflow 应用配置中提取 memory_config_id
|
"""从 Workflow 应用配置中提取 memory_config_id
|
||||||
|
|
||||||
|
|||||||
@@ -341,7 +341,7 @@ async def memory_konwledges_up(
|
|||||||
)
|
)
|
||||||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||||||
|
|
||||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
return db_document
|
||||||
|
|
||||||
|
|
||||||
async def create_document_chunk(
|
async def create_document_chunk(
|
||||||
@@ -350,7 +350,7 @@ async def create_document_chunk(
|
|||||||
create_data: ChunkCreate,
|
create_data: ChunkCreate,
|
||||||
db: Session,
|
db: Session,
|
||||||
current_user: User
|
current_user: User
|
||||||
):
|
) -> DocumentChunk:
|
||||||
"""
|
"""
|
||||||
创建文档块
|
创建文档块
|
||||||
|
|
||||||
@@ -439,10 +439,10 @@ async def create_document_chunk(
|
|||||||
db_document.chunk_num += 1
|
db_document.chunk_num += 1
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return success(data=chunk, msg="文档块创建成功")
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
async def write_rag(end_user_id, message, user_rag_memory_id):
|
async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk:
|
||||||
"""
|
"""
|
||||||
将消息写入 RAG 知识库
|
将消息写入 RAG 知识库
|
||||||
|
|
||||||
@@ -482,11 +482,11 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
|||||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
||||||
print('======', document)
|
print('======', document)
|
||||||
api_logger.info(f"查找文档结果: document_id={document}")
|
api_logger.info(f"查找文档结果: document_id={document}")
|
||||||
|
create_chunks = ChunkCreate(content=message)
|
||||||
if document is not None:
|
if document is not None:
|
||||||
# 文档已存在,直接添加新块
|
# 文档已存在,直接添加新块
|
||||||
api_logger.info(f"文档已存在,添加新块: document_id={document}")
|
api_logger.info(f"文档已存在,添加新块: document_id={document}")
|
||||||
|
|
||||||
create_chunks = ChunkCreate(content=message)
|
|
||||||
result = await create_document_chunk(
|
result = await create_document_chunk(
|
||||||
kb_id=kb_uuid,
|
kb_id=kb_uuid,
|
||||||
document_id=uuid.UUID(document),
|
document_id=uuid.UUID(document),
|
||||||
@@ -498,13 +498,20 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
|||||||
else:
|
else:
|
||||||
# 文档不存在,创建新文档
|
# 文档不存在,创建新文档
|
||||||
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
|
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
|
||||||
result = await memory_konwledges_up(
|
document = await memory_konwledges_up(
|
||||||
kb_id=user_rag_memory_id,
|
kb_id=user_rag_memory_id,
|
||||||
parent_id=user_rag_memory_id,
|
parent_id=user_rag_memory_id,
|
||||||
create_data=create_data,
|
create_data=create_data,
|
||||||
db=db,
|
db=db,
|
||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
|
result = await create_document_chunk(
|
||||||
|
kb_id=kb_uuid,
|
||||||
|
document_id=document.id,
|
||||||
|
create_data=create_chunks,
|
||||||
|
db=db,
|
||||||
|
current_user=current_user
|
||||||
|
)
|
||||||
# 重新查询刚创建的文档ID
|
# 重新查询刚创建的文档ID
|
||||||
new_document_id = find_document_id_by_kb_and_filename(
|
new_document_id = find_document_id_by_kb_and_filename(
|
||||||
db=db,
|
db=db,
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ from app.core.error_codes import BizCode
|
|||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models import FileMetadata
|
from app.models import FileMetadata, ModelApiKey, ModelType
|
||||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||||
from app.models.prompt_optimizer_model import RoleType
|
from app.models.prompt_optimizer_model import RoleType
|
||||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||||
from app.schemas import FileType
|
from app.schemas import FileType, FileInput
|
||||||
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
from app.schemas.memory_perceptual_schema import (
|
from app.schemas.memory_perceptual_schema import (
|
||||||
PerceptualQuerySchema,
|
PerceptualQuerySchema,
|
||||||
PerceptualTimelineResponse,
|
PerceptualTimelineResponse,
|
||||||
@@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import (
|
|||||||
AudioModal, Content, VideoModal, TextModal
|
AudioModal, Content, VideoModal, TextModal
|
||||||
)
|
)
|
||||||
from app.schemas.model_schema import ModelInfo
|
from app.schemas.model_schema import ModelInfo
|
||||||
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
@@ -195,21 +198,58 @@ class MemoryPerceptualService:
|
|||||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||||
|
|
||||||
|
def _get_mutlimodal_client(
|
||||||
|
self,
|
||||||
|
file_type: FileType,
|
||||||
|
config: MemoryConfig
|
||||||
|
) -> tuple[RedBearLLM | None, ModelApiKey | None]:
|
||||||
|
model_config = None
|
||||||
|
if file_type == FileType.AUDIO:
|
||||||
|
model_config = ModelApiKeyService.get_available_api_key(
|
||||||
|
self.db,
|
||||||
|
config.audio_model_id
|
||||||
|
)
|
||||||
|
elif file_type == FileType.VIDEO:
|
||||||
|
model_config = ModelApiKeyService.get_available_api_key(
|
||||||
|
self.db,
|
||||||
|
config.video_model_id
|
||||||
|
)
|
||||||
|
elif file_type == FileType.DOCUMENT:
|
||||||
|
model_config = ModelApiKeyService.get_available_api_key(
|
||||||
|
self.db,
|
||||||
|
config.llm_model_id
|
||||||
|
)
|
||||||
|
elif file_type == FileType.IMAGE:
|
||||||
|
model_config = ModelApiKeyService.get_available_api_key(
|
||||||
|
self.db,
|
||||||
|
config.vision_model_id
|
||||||
|
)
|
||||||
|
llm = None
|
||||||
|
if model_config:
|
||||||
|
llm = RedBearLLM(
|
||||||
|
RedBearModelConfig(
|
||||||
|
model_name=model_config.model_name,
|
||||||
|
provider=model_config.provider,
|
||||||
|
api_key=model_config.api_key,
|
||||||
|
base_url=model_config.api_base,
|
||||||
|
is_omni=model_config.is_omni
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return llm, model_config
|
||||||
|
|
||||||
async def generate_perceptual_memory(
|
async def generate_perceptual_memory(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
model_config: ModelInfo,
|
memory_config: MemoryConfig,
|
||||||
file_type: str,
|
file: FileInput
|
||||||
file_url: str,
|
|
||||||
file_message: dict,
|
|
||||||
):
|
):
|
||||||
memories = self.repository.get_by_url(file_url)
|
memories = self.repository.get_by_url(file.url)
|
||||||
if memories:
|
if memories:
|
||||||
business_logger.info(f"Perceptual memory already exists: {file_url}")
|
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
||||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||||
memory_cache = memories[0]
|
memory_cache = memories[0]
|
||||||
self.repository.create_perceptual_memory(
|
memory = self.repository.create_perceptual_memory(
|
||||||
end_user_id=uuid.UUID(end_user_id),
|
end_user_id=uuid.UUID(end_user_id),
|
||||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||||
file_path=memory_cache.file_path,
|
file_path=memory_cache.file_path,
|
||||||
@@ -219,20 +259,33 @@ class MemoryPerceptualService:
|
|||||||
meta_data=memory_cache.meta_data
|
meta_data=memory_cache.meta_data
|
||||||
)
|
)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
return memory
|
||||||
return
|
else:
|
||||||
llm = RedBearLLM(RedBearModelConfig(
|
for memory in memories:
|
||||||
|
if memory.end_user_id == uuid.UUID(end_user_id):
|
||||||
|
return memory
|
||||||
|
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||||
|
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||||
model_name=model_config.model_name,
|
model_name=model_config.model_name,
|
||||||
provider=model_config.provider,
|
provider=model_config.provider,
|
||||||
api_key=model_config.api_key,
|
api_key=model_config.api_key,
|
||||||
base_url=model_config.api_base,
|
api_base=model_config.api_base,
|
||||||
is_omni=model_config.is_omni
|
is_omni=model_config.is_omni,
|
||||||
), type=model_config.model_type)
|
capability=model_config.capability,
|
||||||
|
model_type=ModelType.LLM
|
||||||
|
))
|
||||||
|
file_message = await multimodel_service.process_files(
|
||||||
|
files=[file]
|
||||||
|
)
|
||||||
|
if not file_message:
|
||||||
|
logger.warning(f"Unsupport file type {file}, model capability: {model_config.capability}")
|
||||||
|
return None
|
||||||
|
file_message = file_message[0]
|
||||||
try:
|
try:
|
||||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||||
opt_system_prompt = f.read()
|
opt_system_prompt = f.read()
|
||||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
|
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||||
messages = [
|
messages = [
|
||||||
@@ -242,8 +295,22 @@ class MemoryPerceptualService:
|
|||||||
]}
|
]}
|
||||||
]
|
]
|
||||||
result = await llm.ainvoke(messages)
|
result = await llm.ainvoke(messages)
|
||||||
content = json_repair.repair_json(result.content, return_objects=True)
|
content = result.content
|
||||||
path = urlparse(file_url).path
|
final_output = ""
|
||||||
|
if isinstance(content, list):
|
||||||
|
for msg in content:
|
||||||
|
if isinstance(msg, dict):
|
||||||
|
final_output += msg.get("text", "")
|
||||||
|
elif isinstance(msg, str):
|
||||||
|
final_output += msg
|
||||||
|
elif isinstance(content, dict):
|
||||||
|
final_output += content.get("text", "")
|
||||||
|
elif isinstance(content, str):
|
||||||
|
final_output = content
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexcept Model Output Type: {result.content}")
|
||||||
|
content = json_repair.repair_json(final_output, return_objects=True)
|
||||||
|
path = urlparse(file.url).path
|
||||||
filename = os.path.basename(path)
|
filename = os.path.basename(path)
|
||||||
filename = unquote(filename)
|
filename = unquote(filename)
|
||||||
file_ext = os.path.splitext(filename)[1]
|
file_ext = os.path.splitext(filename)[1]
|
||||||
@@ -252,21 +319,21 @@ class MemoryPerceptualService:
|
|||||||
stmt = select(FileMetadata).where(
|
stmt = select(FileMetadata).where(
|
||||||
FileMetadata.id == file_id
|
FileMetadata.id == file_id
|
||||||
)
|
)
|
||||||
file = self.db.execute(stmt).scalar_one_or_none()
|
file_obj = self.db.execute(stmt).scalar_one_or_none()
|
||||||
|
|
||||||
if file:
|
if file_obj:
|
||||||
filename = file.file_name
|
filename = file_obj.file_name
|
||||||
file_ext = file.file_ext
|
file_ext = file_obj.file_ext
|
||||||
except ValueError:
|
except ValueError:
|
||||||
business_logger.debug(f"Remote file, file_id={filename}")
|
business_logger.debug(f"Remote file, file_id={filename}")
|
||||||
if not file_ext:
|
if not file_ext:
|
||||||
if file_type == FileType.AUDIO:
|
if file.type == FileType.AUDIO:
|
||||||
file_ext = ".mp3"
|
file_ext = ".mp3"
|
||||||
elif file_type == FileType.VIDEO:
|
elif file.type == FileType.VIDEO:
|
||||||
file_ext = ".mp4"
|
file_ext = ".mp4"
|
||||||
elif file_type == FileType.DOCUMENT:
|
elif file.type == FileType.DOCUMENT:
|
||||||
file_ext = ".txt"
|
file_ext = ".txt"
|
||||||
elif file_type == FileType.IMAGE:
|
elif file.type == FileType.IMAGE:
|
||||||
file_ext = ".jpg"
|
file_ext = ".jpg"
|
||||||
filename += file_ext
|
filename += file_ext
|
||||||
file_content = {
|
file_content = {
|
||||||
@@ -274,11 +341,11 @@ class MemoryPerceptualService:
|
|||||||
"topic": content.get("topic"),
|
"topic": content.get("topic"),
|
||||||
"domain": content.get("domain")
|
"domain": content.get("domain")
|
||||||
}
|
}
|
||||||
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
if file.type in [FileType.IMAGE, FileType.VIDEO]:
|
||||||
file_modalities = {
|
file_modalities = {
|
||||||
"scene": content.get("scene", [])
|
"scene": content.get("scene", [])
|
||||||
}
|
}
|
||||||
elif file_type in [FileType.DOCUMENT]:
|
elif file.type in [FileType.DOCUMENT]:
|
||||||
file_modalities = {
|
file_modalities = {
|
||||||
"section_count": content.get("section_count", 0),
|
"section_count": content.get("section_count", 0),
|
||||||
"title": content.get("title", ""),
|
"title": content.get("title", ""),
|
||||||
@@ -288,10 +355,10 @@ class MemoryPerceptualService:
|
|||||||
file_modalities = {
|
file_modalities = {
|
||||||
"speaker_count": content.get("speaker_count", 0)
|
"speaker_count": content.get("speaker_count", 0)
|
||||||
}
|
}
|
||||||
self.repository.create_perceptual_memory(
|
memory = self.repository.create_perceptual_memory(
|
||||||
end_user_id=uuid.UUID(end_user_id),
|
end_user_id=uuid.UUID(end_user_id),
|
||||||
perceptual_type=PerceptualType.trans_from_file_type(file_type),
|
perceptual_type=PerceptualType.trans_from_file_type(file.type),
|
||||||
file_path=file_url,
|
file_path=file.url,
|
||||||
file_name=filename,
|
file_name=filename,
|
||||||
file_ext=file_ext,
|
file_ext=file_ext,
|
||||||
summary=content.get('summary', ""),
|
summary=content.get('summary', ""),
|
||||||
@@ -301,3 +368,4 @@ class MemoryPerceptualService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
return memory
|
||||||
|
|||||||
@@ -11,9 +11,11 @@ import time
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_config_logger, get_logger
|
from app.core.logging_config import get_config_logger, get_logger
|
||||||
from app.core.memory.analytics.hot_memory_tags import (
|
from app.core.memory.analytics.hot_memory_tags import (
|
||||||
get_hot_memory_tags,
|
|
||||||
get_raw_tags_from_db,
|
get_raw_tags_from_db,
|
||||||
filter_tags_with_llm,
|
filter_tags_with_llm,
|
||||||
)
|
)
|
||||||
@@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import (
|
|||||||
)
|
)
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.utils.sse_utils import format_sse_message
|
from app.utils.sse_utils import format_sse_message
|
||||||
from dotenv import load_dotenv
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
config_logger = get_config_logger()
|
config_logger = get_config_logger()
|
||||||
@@ -69,7 +69,7 @@ class MemoryStorageService:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class DataConfigService: # 数据配置服务类(PostgreSQL)
|
class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||||
"""Service layer for config params CRUD.
|
"""Service layer for config params CRUD.
|
||||||
|
|
||||||
使用 SQLAlchemy ORM 进行数据库操作。
|
使用 SQLAlchemy ORM 进行数据库操作。
|
||||||
@@ -114,7 +114,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
# --- Create ---
|
# --- Create ---
|
||||||
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
||||||
# 业务层检查同一工作空间下是否已存在同名配置
|
# 业务层检查同一工作空间下是否已存在同名配置
|
||||||
if params.workspace_id and params.config_name:
|
if params.workspace_id and params.config_name:
|
||||||
from app.models.memory_config_model import MemoryConfig
|
from app.models.memory_config_model import MemoryConfig
|
||||||
@@ -183,20 +183,20 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# --- Delete ---
|
# --- Delete ---
|
||||||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
|
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
|
||||||
success = MemoryConfigRepository.delete(self.db, key.config_id)
|
success = MemoryConfigRepository.delete(self.db, key.config_id)
|
||||||
if not success:
|
if not success:
|
||||||
raise ValueError("未找到配置")
|
raise ValueError("未找到配置")
|
||||||
return {"affected": 1}
|
return {"affected": 1}
|
||||||
|
|
||||||
# --- Update ---
|
# --- Update ---
|
||||||
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
|
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
|
||||||
config = MemoryConfigRepository.update(self.db, update)
|
config = MemoryConfigRepository.update(self.db, update)
|
||||||
if not config:
|
if not config:
|
||||||
raise ValueError("未找到配置")
|
raise ValueError("未找到配置")
|
||||||
return {"affected": 1}
|
return {"affected": 1}
|
||||||
|
|
||||||
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
|
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
|
||||||
config = MemoryConfigRepository.update_extracted(self.db, update)
|
config = MemoryConfigRepository.update_extracted(self.db, update)
|
||||||
if not config:
|
if not config:
|
||||||
raise ValueError("未找到配置")
|
raise ValueError("未找到配置")
|
||||||
@@ -207,14 +207,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
# 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config()
|
# 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config()
|
||||||
|
|
||||||
# --- Read ---
|
# --- Read ---
|
||||||
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
|
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
|
||||||
result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id)
|
result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id)
|
||||||
if not result:
|
if not result:
|
||||||
raise ValueError("未找到配置")
|
raise ValueError("未找到配置")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# --- Read All ---
|
# --- Read All ---
|
||||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||||
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||||||
|
|
||||||
# 检查并修正 pruning_scene 与 scene_name 不一致的记录
|
# 检查并修正 pruning_scene 与 scene_name 不一致的记录
|
||||||
@@ -241,11 +241,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
config_id_old = None
|
config_id_old = None
|
||||||
|
|
||||||
|
|
||||||
if config_id_old:
|
if config_id_old:
|
||||||
memory_config=config_id_old
|
memory_config = config_id_old
|
||||||
else:
|
else:
|
||||||
memory_config=config.config_id
|
memory_config = config.config_id
|
||||||
config_dict = {
|
config_dict = {
|
||||||
"config_id": memory_config,
|
"config_id": memory_config,
|
||||||
"config_name": config.config_name,
|
"config_name": config.config_name,
|
||||||
@@ -289,7 +288,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
||||||
return self._convert_timestamps_to_format(data_list)
|
return self._convert_timestamps_to_format(data_list)
|
||||||
|
|
||||||
|
|
||||||
async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]:
|
async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
流式执行试运行,产生 SSE 格式的进度事件
|
流式执行试运行,产生 SSE 格式的进度事件
|
||||||
@@ -344,11 +342,13 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
# 关联了本体场景,优先使用 custom_text
|
# 关联了本体场景,优先使用 custom_text
|
||||||
if hasattr(payload, 'custom_text') and payload.custom_text:
|
if hasattr(payload, 'custom_text') and payload.custom_text:
|
||||||
dialogue_text = payload.custom_text.strip()
|
dialogue_text = payload.custom_text.strip()
|
||||||
logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
|
logger.info(
|
||||||
|
f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
|
||||||
else:
|
else:
|
||||||
# 如果没有提供 custom_text,回退到 dialogue_text
|
# 如果没有提供 custom_text,回退到 dialogue_text
|
||||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||||
logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
|
logger.info(
|
||||||
|
f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
|
||||||
else:
|
else:
|
||||||
# 没有关联本体场景,使用 dialogue_text
|
# 没有关联本体场景,使用 dialogue_text
|
||||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||||
@@ -360,7 +360,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
|
|
||||||
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
|
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
|
||||||
|
|
||||||
|
|
||||||
# 步骤 2: 创建进度回调函数捕获管线进度
|
# 步骤 2: 创建进度回调函数捕获管线进度
|
||||||
# 使用队列在回调和生成器之间传递进度事件
|
# 使用队列在回调和生成器之间传递进度事件
|
||||||
progress_queue: asyncio.Queue = asyncio.Queue()
|
progress_queue: asyncio.Queue = asyncio.Queue()
|
||||||
@@ -382,7 +381,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
try:
|
try:
|
||||||
from app.services.pilot_run_service import run_pilot_extraction
|
from app.services.pilot_run_service import run_pilot_extraction
|
||||||
|
|
||||||
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
|
logger.info(
|
||||||
|
f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
|
||||||
await run_pilot_extraction(
|
await run_pilot_extraction(
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
dialogue_text=dialogue_text,
|
dialogue_text=dialogue_text,
|
||||||
@@ -483,11 +483,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
"time": int(time.time() * 1000)
|
"time": int(time.time() * 1000)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
async def _compute_ontology_coverage(
|
async def _compute_ontology_coverage(
|
||||||
self,
|
self,
|
||||||
extracted_result: Dict[str, Any],
|
extracted_result: Dict[str, Any],
|
||||||
memory_config,
|
memory_config,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
|
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
|
||||||
|
|
||||||
@@ -580,8 +579,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
|
|
||||||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||||||
# Ensure env for connector (e.g., NEO4J_PASSWORD)
|
# Ensure env for connector (e.g., NEO4J_PASSWORD)
|
||||||
load_dotenv()
|
|
||||||
_neo4j_connector = Neo4jConnector()
|
|
||||||
|
|
||||||
|
|
||||||
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
@@ -701,10 +698,11 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def analytics_hot_memory_tags(
|
async def analytics_hot_memory_tags(
|
||||||
db: Session,
|
db: Session,
|
||||||
current_user: User,
|
current_user: User,
|
||||||
limit: int = 10
|
limit: int = 10
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取热门记忆标签,按数量排序并返回前N个
|
获取热门记忆标签,按数量排序并返回前N个
|
||||||
@@ -815,11 +813,11 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
|
|||||||
source = "log"
|
source = "log"
|
||||||
|
|
||||||
total = (
|
total = (
|
||||||
stats.get("chunk_count", 0)
|
stats.get("chunk_count", 0)
|
||||||
+ stats.get("statements_count", 0)
|
+ stats.get("statements_count", 0)
|
||||||
+ stats.get("triplet_entities_count", 0)
|
+ stats.get("triplet_entities_count", 0)
|
||||||
+ stats.get("triplet_relations_count", 0)
|
+ stats.get("triplet_relations_count", 0)
|
||||||
+ stats.get("temporal_count", 0)
|
+ stats.get("temporal_count", 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 计算"最新一次活动多久前"(仅日志来源时有效)
|
# 计算"最新一次活动多久前"(仅日志来源时有效)
|
||||||
@@ -845,5 +843,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
|
|||||||
|
|
||||||
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
|
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,17 +9,15 @@
|
|||||||
- OpenAI: 支持 URL 和 base64 格式
|
- OpenAI: 支持 URL 和 base64 格式
|
||||||
"""
|
"""
|
||||||
import base64
|
import base64
|
||||||
|
import csv
|
||||||
import io
|
import io
|
||||||
import uuid
|
import json
|
||||||
import zipfile
|
import zipfile
|
||||||
import chardet
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
import csv
|
|
||||||
import json
|
|
||||||
|
|
||||||
import PyPDF2
|
import PyPDF2
|
||||||
|
import chardet
|
||||||
import httpx
|
import httpx
|
||||||
import magic
|
import magic
|
||||||
import openpyxl
|
import openpyxl
|
||||||
@@ -35,7 +33,6 @@ from app.models.file_metadata_model import FileMetadata
|
|||||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||||
from app.schemas.model_schema import ModelInfo
|
from app.schemas.model_schema import ModelInfo
|
||||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||||
from app.tasks import write_perceptual_memory
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -342,15 +339,12 @@ class MultimodalService:
|
|||||||
|
|
||||||
async def process_files(
|
async def process_files(
|
||||||
self,
|
self,
|
||||||
end_user_id: uuid.UUID | str,
|
|
||||||
files: Optional[List[FileInput]],
|
files: Optional[List[FileInput]],
|
||||||
|
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理文件列表,返回 LLM 可用的格式
|
处理文件列表,返回 LLM 可用的格式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 用户ID
|
|
||||||
files: 文件输入列表
|
files: 文件输入列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -358,8 +352,6 @@ class MultimodalService:
|
|||||||
"""
|
"""
|
||||||
if not files:
|
if not files:
|
||||||
return []
|
return []
|
||||||
if isinstance(end_user_id, uuid.UUID):
|
|
||||||
end_user_id = str(end_user_id)
|
|
||||||
|
|
||||||
# 获取对应的策略
|
# 获取对应的策略
|
||||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||||
@@ -380,23 +372,15 @@ class MultimodalService:
|
|||||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||||
is_support, content = await self._process_image(file, strategy)
|
is_support, content = await self._process_image(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
elif file.type == FileType.DOCUMENT:
|
elif file.type == FileType.DOCUMENT:
|
||||||
is_support, content = await self._process_document(file, strategy)
|
is_support, content = await self._process_document(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||||
is_support, content = await self._process_audio(file, strategy)
|
is_support, content = await self._process_audio(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||||
is_support, content = await self._process_video(file, strategy)
|
is_support, content = await self._process_video(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"不支持的文件类型: {file.type}")
|
logger.warning(f"不支持的文件类型: {file.type}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -483,17 +467,6 @@ class MultimodalService:
|
|||||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def write_perceptual_memory(
|
|
||||||
self,
|
|
||||||
end_user_id: str,
|
|
||||||
file_type: str,
|
|
||||||
file_url: str,
|
|
||||||
file_message: dict
|
|
||||||
):
|
|
||||||
"""写入感知记忆"""
|
|
||||||
if end_user_id and self.api_config:
|
|
||||||
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
|
||||||
|
|
||||||
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理图片文件
|
处理图片文件
|
||||||
|
|||||||
@@ -297,9 +297,12 @@ async def run_pilot_extraction(
|
|||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
|
_,
|
||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_edges,
|
entity_edges,
|
||||||
|
_,
|
||||||
|
_
|
||||||
) = extraction_result
|
) = extraction_result
|
||||||
|
|
||||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||||
|
|||||||
@@ -1887,7 +1887,8 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
|||||||
"Chunk": ["content", "created_at"],
|
"Chunk": ["content", "created_at"],
|
||||||
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
|
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
|
||||||
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
|
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
|
||||||
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
|
"MemorySummary": ["summary", "content", "created_at", "caption"], # 添加 content 字段
|
||||||
|
"Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# 获取该节点类型的白名单字段
|
# 获取该节点类型的白名单字段
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.models import App
|
from app.models import App
|
||||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||||
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.workflow_repository import (
|
from app.repositories.workflow_repository import (
|
||||||
WorkflowConfigRepository,
|
WorkflowConfigRepository,
|
||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
@@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput
|
|||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
from app.services.workspace_service import get_workspace_storage_type_without_auth
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -540,6 +542,25 @@ class WorkflowService:
|
|||||||
mapped = internal_event
|
mapped = internal_event
|
||||||
return mapped
|
return mapped
|
||||||
|
|
||||||
|
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
|
||||||
|
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
|
||||||
|
user_rag_memory_id = ""
|
||||||
|
if storage_type == "rag":
|
||||||
|
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
|
db=self.db,
|
||||||
|
name="USER_RAG_MERORY",
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
if knowledge:
|
||||||
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"No knowledge base named 'USER_RAG_MEMORY' found, "
|
||||||
|
f"workspace_id: {workspace_id}, will use neo4j storage"
|
||||||
|
)
|
||||||
|
storage_type = 'neo4j'
|
||||||
|
return storage_type, user_rag_memory_id
|
||||||
|
|
||||||
# ==================== 工作流执行 ====================
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
@@ -607,6 +628,7 @@ class WorkflowService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
files = await self._handle_file_input(payload.files)
|
files = await self._handle_file_input(payload.files)
|
||||||
|
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||||
input_data["files"] = files
|
input_data["files"] = files
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
# 更新状态为运行中
|
# 更新状态为运行中
|
||||||
@@ -631,7 +653,9 @@ class WorkflowService:
|
|||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id=str(workspace_id),
|
workspace_id=str(workspace_id),
|
||||||
user_id=payload.user_id
|
user_id=payload.user_id,
|
||||||
|
memory_storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
# 更新执行结果
|
# 更新执行结果
|
||||||
if result.get("status") == "completed":
|
if result.get("status") == "completed":
|
||||||
@@ -780,6 +804,7 @@ class WorkflowService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
files = await self._handle_file_input(payload.files)
|
files = await self._handle_file_input(payload.files)
|
||||||
|
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||||
input_data["files"] = files
|
input_data["files"] = files
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||||
@@ -801,6 +826,8 @@ class WorkflowService:
|
|||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id=str(workspace_id),
|
workspace_id=str(workspace_id),
|
||||||
user_id=payload.user_id,
|
user_id=payload.user_id,
|
||||||
|
memory_storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
if event.get("event") == "workflow_end":
|
if event.get("event") == "workflow_end":
|
||||||
status = event.get("data", {}).get("status")
|
status = event.get("data", {}).get("status")
|
||||||
|
|||||||
@@ -863,7 +863,7 @@ def get_workspace_storage_type(
|
|||||||
def get_workspace_storage_type_without_auth(
|
def get_workspace_storage_type_without_auth(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
) -> Optional[str]:
|
) -> str:
|
||||||
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
|
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1073,9 +1073,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
|||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
||||||
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str,
|
def write_message_task(
|
||||||
user_rag_memory_id: str,
|
self,
|
||||||
language: str = "zh") -> Dict[str, Any]:
|
end_user_id: str,
|
||||||
|
message: list[dict],
|
||||||
|
config_id: str | int,
|
||||||
|
storage_type: str,
|
||||||
|
user_rag_memory_id: str,
|
||||||
|
language: str = "zh"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Celery task to process a write message via MemoryAgentService.
|
"""Celery task to process a write message via MemoryAgentService.
|
||||||
Args:
|
Args:
|
||||||
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
||||||
@@ -1091,7 +1097,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
|||||||
Raises:
|
Raises:
|
||||||
Exception on failure
|
Exception on failure
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
||||||
f"config_id={config_id} (type: {type(config_id).__name__}), "
|
f"config_id={config_id} (type: {type(config_id).__name__}), "
|
||||||
@@ -1105,14 +1110,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
|||||||
try:
|
try:
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
actual_config_id = resolve_config_id(config_id, db)
|
actual_config_id = resolve_config_id(config_id, db)
|
||||||
print(100 * '-')
|
logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} "
|
||||||
print(actual_config_id)
|
f"(type: {type(actual_config_id).__name__})")
|
||||||
print(100 * '-')
|
|
||||||
logger.info(
|
|
||||||
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
|
|
||||||
except (ValueError, AttributeError) as e:
|
except (ValueError, AttributeError) as e:
|
||||||
logger.error(
|
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} "
|
||||||
f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
|
f"(type: {type(config_id).__name__}), error: {e}")
|
||||||
return {
|
return {
|
||||||
"status": "FAILURE",
|
"status": "FAILURE",
|
||||||
"error": f"Invalid config_id format: {config_id} - {str(e)}",
|
"error": f"Invalid config_id format: {config_id} - {str(e)}",
|
||||||
@@ -1151,8 +1153,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
|||||||
result = loop.run_until_complete(_run())
|
result = loop.run_until_complete(_run())
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"[CELERY WRITE] Task completed successfully "
|
||||||
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||||
|
|
||||||
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
|
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
|
||||||
try:
|
try:
|
||||||
@@ -1167,7 +1169,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
|||||||
)
|
)
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
|
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"result": result,
|
"result": result,
|
||||||
@@ -2611,57 +2612,6 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
|
||||||
name="app.tasks.write_perceptual_memory",
|
|
||||||
bind=True,
|
|
||||||
ignore_result=True,
|
|
||||||
max_retries=0,
|
|
||||||
acks_late=False,
|
|
||||||
time_limit=3600,
|
|
||||||
soft_time_limit=3300,
|
|
||||||
)
|
|
||||||
def write_perceptual_memory(
|
|
||||||
self,
|
|
||||||
end_user_id: str,
|
|
||||||
model_api_config: dict,
|
|
||||||
file_type: str,
|
|
||||||
file_url: str,
|
|
||||||
file_message: dict
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Write perceptual memory for a user into PostgreSQL and Neo4j.
|
|
||||||
|
|
||||||
This task generates or updates the user's perceptual memory
|
|
||||||
in the backend databases. It is intended to be executed asynchronously
|
|
||||||
via Celery.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id (uuid.UUID): The unique identifier of the end user.
|
|
||||||
model_api_config (ModelInfo): API configuration for the model
|
|
||||||
used to generate perceptual memory.
|
|
||||||
file_type (str): The file type
|
|
||||||
file_url (url): The url of file
|
|
||||||
file_message (dict): The file message containing details about the file
|
|
||||||
to be processed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
|
|
||||||
set_asyncio_event_loop()
|
|
||||||
with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
|
|
||||||
model_info = ModelInfo(**model_api_config)
|
|
||||||
with get_db_context() as db:
|
|
||||||
memory_perceptual_service = MemoryPerceptualService(db)
|
|
||||||
return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
|
|
||||||
end_user_id,
|
|
||||||
model_info,
|
|
||||||
file_type,
|
|
||||||
file_url,
|
|
||||||
file_message,
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 社区聚类补全任务(触发型)
|
# 社区聚类补全任务(触发型)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -2672,7 +2622,7 @@ def write_perceptual_memory(
|
|||||||
ignore_result=False,
|
ignore_result=False,
|
||||||
max_retries=0,
|
max_retries=0,
|
||||||
acks_late=False,
|
acks_late=False,
|
||||||
time_limit=7200, # 2小时硬超时
|
time_limit=7200, # 2小时硬超时
|
||||||
soft_time_limit=6900,
|
soft_time_limit=6900,
|
||||||
)
|
)
|
||||||
def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]:
|
def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
@@ -2787,7 +2737,8 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
|
|||||||
embedding_model_id=embedding_model_id,
|
embedding_model_id=embedding_model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
|
logger.info(
|
||||||
|
f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
|
||||||
await engine.full_clustering(end_user_id)
|
await engine.full_clustering(end_user_id)
|
||||||
initialized += 1
|
initialized += 1
|
||||||
logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")
|
logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")
|
||||||
@@ -2810,12 +2761,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
|
||||||
import nest_asyncio
|
|
||||||
nest_asyncio.apply()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
loop = set_asyncio_event_loop()
|
loop = set_asyncio_event_loop()
|
||||||
result = loop.run_until_complete(_run())
|
result = loop.run_until_complete(_run())
|
||||||
result["elapsed_time"] = time.time() - start_time
|
result["elapsed_time"] = time.time() - start_time
|
||||||
|
|||||||
Reference in New Issue
Block a user