Merge pull request #676 from SuanmoSuanyangTechnology/feature/multimodel_memory

feat(memory, model): update multi-modal memory write and model list API
This commit is contained in:
Ke Sun
2026-03-24 15:26:38 +08:00
committed by GitHub
45 changed files with 1607 additions and 1220 deletions

View File

@@ -118,142 +118,142 @@ async def download_log(
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
@router.post("/writer_service", response_model=ApiResponse) # @router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard() # @cur_workspace_access_guard()
async def write_server( # async def write_server(
user_input: Write_UserInput, # user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"), # language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db), # db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # current_user: User = Depends(get_current_user)
): # ):
""" # """
Write service endpoint - processes write operations synchronously # Write service endpoint - processes write operations synchronously
#
Args: # Args:
user_input: Write request containing message and end_user_id # user_input: Write request containing message and end_user_id
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 # language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
#
Returns: # Returns:
Response with write operation status # Response with write operation status
""" # """
# 使用集中化的语言校验 # # 使用集中化的语言校验
language = get_language_from_header(language_type) # language = get_language_from_header(language_type)
#
config_id = user_input.config_id # config_id = user_input.config_id
workspace_id = current_user.current_workspace_id # workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") # api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
#
# 获取 storage_type如果为 None 则使用默认值 # # 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type( # storage_type = workspace_service.get_workspace_storage_type(
db=db, # db=db,
workspace_id=workspace_id, # workspace_id=workspace_id,
user=current_user # user=current_user
) # )
if storage_type is None: storage_type = 'neo4j' # if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = '' # user_rag_memory_id = ''
#
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id # # 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag': # if storage_type == 'rag':
if workspace_id: # if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name( # knowledge = knowledge_repository.get_knowledge_by_name(
db=db, # db=db,
name="USER_RAG_MERORY", # name="USER_RAG_MERORY",
workspace_id=workspace_id # workspace_id=workspace_id
) # )
if knowledge: # if knowledge:
user_rag_memory_id = str(knowledge.id) # user_rag_memory_id = str(knowledge.id)
else: # else:
api_logger.warning( # api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储") # f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j' # storage_type = 'neo4j'
else: # else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") # api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j' # storage_type = 'neo4j'
#
api_logger.info( # api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") # f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try: # try:
messages_list = memory_agent_service.get_messages_list(user_input) # messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory( # result = await memory_agent_service.write_memory(
user_input.end_user_id, # user_input.end_user_id,
messages_list, # messages_list,
config_id, # config_id,
db, # db,
storage_type, # storage_type,
user_rag_memory_id, # user_rag_memory_id,
language # language
) # )
#
return success(data=result, msg="写入成功") # return success(data=result, msg="写入成功")
except BaseException as e: # except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'): # if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] # error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages) # detailed_error = "; ".join(error_messages)
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) # api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) # return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
api_logger.error(f"Write operation error: {str(e)}", exc_info=True) # api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) # return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
#
#
@router.post("/writer_service_async", response_model=ApiResponse) # @router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard() # @cur_workspace_access_guard()
async def write_server_async( # async def write_server_async(
user_input: Write_UserInput, # user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"), # language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db), # db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # current_user: User = Depends(get_current_user)
): # ):
""" # """
Async write service endpoint - enqueues write processing to Celery # Async write service endpoint - enqueues write processing to Celery
#
Args: # Args:
user_input: Write request containing message and end_user_id # user_input: Write request containing message and end_user_id
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 # language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
#
Returns: # Returns:
Task ID for tracking async operation # Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result # Use GET /memory/write_result/{task_id} to check task status and get result
""" # """
# 使用集中化的语言校验 # # 使用集中化的语言校验
language = get_language_from_header(language_type) # language = get_language_from_header(language_type)
#
config_id = user_input.config_id # config_id = user_input.config_id
workspace_id = current_user.current_workspace_id # workspace_id = current_user.current_workspace_id
api_logger.info( # api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") # f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
#
# 获取 storage_type如果为 None 则使用默认值 # # 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type( # storage_type = workspace_service.get_workspace_storage_type(
db=db, # db=db,
workspace_id=workspace_id, # workspace_id=workspace_id,
user=current_user # user=current_user
) # )
if storage_type is None: storage_type = 'neo4j' # if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = '' # user_rag_memory_id = ''
if workspace_id: # if workspace_id:
#
knowledge = knowledge_repository.get_knowledge_by_name( # knowledge = knowledge_repository.get_knowledge_by_name(
db=db, # db=db,
name="USER_RAG_MERORY", # name="USER_RAG_MERORY",
workspace_id=workspace_id # workspace_id=workspace_id
) # )
if knowledge: user_rag_memory_id = str(knowledge.id) # if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") # api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try: # try:
# 获取标准化的消息列表 # # 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input) # messages_list = memory_agent_service.get_messages_list(user_input)
#
task = celery_app.send_task( # task = celery_app.send_task(
"app.core.memory.agent.write_message", # "app.core.memory.agent.write_message",
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] # args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
) # )
api_logger.info(f"Write task queued: {task.id}") # api_logger.info(f"Write task queued: {task.id}")
#
return success(data={"task_id": task.id}, msg="写入任务已提交") # return success(data={"task_id": task.id}, msg="写入任务已提交")
except Exception as e: # except Exception as e:
api_logger.error(f"Async write operation failed: {str(e)}") # api_logger.error(f"Async write operation failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) # return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@router.post("/read_service", response_model=ApiResponse) @router.post("/read_service", response_model=ApiResponse)

View File

@@ -54,8 +54,8 @@ router = APIRouter(
@router.get("/info", response_model=ApiResponse) @router.get("/info", response_model=ApiResponse)
async def get_storage_info( async def get_storage_info(
storage_id: str, storage_id: str,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Example wrapper endpoint - retrieves storage information Example wrapper endpoint - retrieves storage information
@@ -75,17 +75,12 @@ async def get_storage_info(
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
def create_config( def create_config(
payload: ConfigParamsCreate, payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
@@ -107,9 +102,11 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type) lang = get_language_from_header(x_language_type)
if lang == "en": if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else: else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg) return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}") api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
@@ -119,9 +116,11 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type) lang = get_language_from_header(x_language_type)
if lang == "en": if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else: else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg) return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}") api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@@ -129,10 +128,10 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config( def delete_config(
config_id: UUID|int, config_id: UUID | int,
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""删除记忆配置(带终端用户保护) """删除记忆配置(带终端用户保护)
@@ -145,7 +144,7 @@ def delete_config(
force: 设置为 true 可强制删除(即使有终端用户正在使用) force: 设置为 true 可强制删除(即使有终端用户正在使用)
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
config_id=resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
@@ -203,9 +202,9 @@ def delete_config(
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc @router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
def update_config( def update_config(
payload: ConfigUpdate, payload: ConfigUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db) payload.config_id = resolve_config_id(payload.config_id, db)
@@ -217,7 +216,8 @@ def update_config(
# 校验至少有一个字段需要更新 # 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
"config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try: try:
@@ -231,9 +231,9 @@ def update_config(
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选 @router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
def update_config_extracted( def update_config_extracted(
payload: ConfigUpdateExtracted, payload: ConfigUpdateExtracted,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db) payload.config_id = resolve_config_id(payload.config_id, db)
@@ -256,11 +256,11 @@ def update_config_extracted(
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py # 遗忘引擎配置接口已迁移到 memory_forget_controller.py
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config # 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 @router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted( def read_config_extracted(
config_id: UUID | int, config_id: UUID | int,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
config_id = resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
@@ -278,10 +278,11 @@ def read_config_extracted(
api_logger.error(f"Read config extracted failed: {str(e)}") api_logger.error(f"Read config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
def read_all_config( def read_all_config(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -303,10 +304,10 @@ def read_all_config(
@router.post("/pilot_run", response_model=None) @router.post("/pilot_run", response_model=None)
async def pilot_run( async def pilot_run(
payload: ConfigPilotRun, payload: ConfigPilotRun,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> StreamingResponse: ) -> StreamingResponse:
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
@@ -333,9 +334,9 @@ async def pilot_run(
@router.get("/search/kb_type_distribution", response_model=ApiResponse) @router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution( async def get_kb_type_distribution(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}") api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
try: try:
result = await kb_type_distribution(end_user_id) result = await kb_type_distribution(end_user_id)
@@ -347,9 +348,9 @@ async def get_kb_type_distribution(
@router.get("/search/dialogue", response_model=ApiResponse) @router.get("/search/dialogue", response_model=ApiResponse)
async def search_dialogues_num( async def search_dialogues_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}") api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
try: try:
result = await search_dialogue(end_user_id) result = await search_dialogue(end_user_id)
@@ -361,9 +362,9 @@ async def search_dialogues_num(
@router.get("/search/chunk", response_model=ApiResponse) @router.get("/search/chunk", response_model=ApiResponse)
async def search_chunks_num( async def search_chunks_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}") api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
try: try:
result = await search_chunk(end_user_id) result = await search_chunk(end_user_id)
@@ -375,9 +376,9 @@ async def search_chunks_num(
@router.get("/search/statement", response_model=ApiResponse) @router.get("/search/statement", response_model=ApiResponse)
async def search_statements_num( async def search_statements_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}") api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
try: try:
result = await search_statement(end_user_id) result = await search_statement(end_user_id)
@@ -389,9 +390,9 @@ async def search_statements_num(
@router.get("/search/entity", response_model=ApiResponse) @router.get("/search/entity", response_model=ApiResponse)
async def search_entities_num( async def search_entities_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}") api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
try: try:
result = await search_entity(end_user_id) result = await search_entity(end_user_id)
@@ -403,9 +404,9 @@ async def search_entities_num(
@router.get("/search", response_model=ApiResponse) @router.get("/search", response_model=ApiResponse)
async def search_all_num( async def search_all_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search all requested for end_user_id: {end_user_id}") api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
try: try:
result = await search_all(end_user_id) result = await search_all(end_user_id)
@@ -417,9 +418,9 @@ async def search_all_num(
@router.get("/search/detials", response_model=ApiResponse) @router.get("/search/detials", response_model=ApiResponse)
async def search_entities_detials( async def search_entities_detials(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search details requested for end_user_id: {end_user_id}") api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
try: try:
result = await search_detials(end_user_id) result = await search_detials(end_user_id)
@@ -431,9 +432,9 @@ async def search_entities_detials(
@router.get("/search/edges", response_model=ApiResponse) @router.get("/search/edges", response_model=ApiResponse)
async def search_entity_edges( async def search_entity_edges(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}") api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
try: try:
result = await search_edges(end_user_id) result = await search_edges(end_user_id)
@@ -443,14 +444,12 @@ async def search_entity_edges(
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse) @router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
async def get_hot_memory_tags_api( async def get_hot_memory_tags_api(
limit: int = 10, limit: int = 10,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
""" """
获取热门记忆标签带Redis缓存 获取热门记忆标签带Redis缓存
@@ -505,8 +504,8 @@ async def get_hot_memory_tags_api(
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse) @router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
async def clear_hot_memory_tags_cache( async def clear_hot_memory_tags_cache(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
""" """
清除热门标签缓存 清除热门标签缓存
@@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache(
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse) @router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api( async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}") api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
@@ -553,4 +552,3 @@ async def get_recent_activity_stats_api(
except Exception as e: except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}") api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))

View File

@@ -42,6 +42,7 @@ def get_model_strategies():
@router.get("", response_model=ApiResponse) @router.get("", response_model=ApiResponse)
def get_model_list( def get_model_list(
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"), type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"),
@@ -74,10 +75,21 @@ def get_model_list(
unique_flat_type = list(dict.fromkeys(flat_type)) unique_flat_type = list(dict.fromkeys(flat_type))
type_list = [ModelType(t.lower()) for t in unique_flat_type] type_list = [ModelType(t.lower()) for t in unique_flat_type]
capability_list = []
if capability is not None:
flat_capability = []
for item in capability:
split_items = [c.strip() for c in item.split(', ') if c.strip()]
flat_capability.extend(split_items)
unique_flat_capability = list(dict.fromkeys(flat_capability))
capability_list = unique_flat_capability
api_logger.error(f"获取模型type_list: {type_list}") api_logger.error(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQuery( query = model_schema.ModelConfigQuery(
type=type_list, type=type_list,
provider=provider, provider=provider,
capability=capability_list,
is_active=is_active, is_active=is_active,
is_public=is_public, is_public=is_public,
search=search, search=search,

View File

@@ -5,7 +5,7 @@
from typing import Optional from typing import Optional
import datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends,Header from fastapi import APIRouter, Depends, Header
from app.db import get_db from app.db import get_db
from app.core.language_utils import get_language_from_header from app.core.language_utils import get_language_from_header
@@ -19,7 +19,7 @@ from app.services.user_memory_service import (
analytics_graph_data, analytics_graph_data,
analytics_community_graph_data, analytics_community_graph_data,
) )
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.repositories.workspace_repository import WorkspaceRepository from app.repositories.workspace_repository import WorkspaceRepository
@@ -45,9 +45,9 @@ router = APIRouter(
@router.get("/analytics/memory_insight/report", response_model=ApiResponse) @router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api( async def get_memory_insight_report_api(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
获取缓存的记忆洞察报告 获取缓存的记忆洞察报告
@@ -73,10 +73,10 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse) @router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api( async def get_user_summary_api(
end_user_id: str, end_user_id: str,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
获取缓存的用户摘要 获取缓存的用户摘要
@@ -102,7 +102,7 @@ async def get_user_summary_api(
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
try: try:
# 调用服务层获取缓存数据 # 调用服务层获取缓存数据
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language) result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
if result["is_cached"]: if result["is_cached"]:
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
@@ -117,10 +117,10 @@ async def get_user_summary_api(
@router.post("/analytics/generate_cache", response_model=ApiResponse) @router.post("/analytics/generate_cache", response_model=ApiResponse)
async def generate_cache_api( async def generate_cache_api(
request: GenerateCacheRequest, request: GenerateCacheRequest,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
手动触发缓存生成 手动触发缓存生成
@@ -155,10 +155,12 @@ async def generate_cache_api(
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}") api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
# 生成记忆洞察 # 生成记忆洞察
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language) insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
language=language)
# 生成用户摘要 # 生成用户摘要
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language) summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
language=language)
# 构建响应 # 构建响应
result = { result = {
@@ -209,9 +211,9 @@ async def generate_cache_api(
@router.get("/analytics/node_statistics", response_model=ApiResponse) @router.get("/analytics/node_statistics", response_model=ApiResponse)
async def get_node_statistics_api( async def get_node_statistics_api(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -220,7 +222,8 @@ async def get_node_statistics_api(
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") api_logger.info(
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
try: try:
# 调用新的记忆类型统计函数 # 调用新的记忆类型统计函数
@@ -228,21 +231,23 @@ async def get_node_statistics_api(
# 计算总数用于日志 # 计算总数用于日志
total_count = sum(item["count"] for item in result) total_count = sum(item["count"] for item in result)
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") api_logger.info(
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
return success(data=result, msg="查询成功") return success(data=result, msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}") api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
@router.get("/analytics/graph_data", response_model=ApiResponse) @router.get("/analytics/graph_data", response_model=ApiResponse)
async def get_graph_data_api( async def get_graph_data_api(
end_user_id: str, end_user_id: str,
node_types: Optional[str] = None, node_types: Optional[str] = None,
limit: int = 100, limit: int = 100,
depth: int = 1, depth: int = 1,
center_node_id: Optional[str] = None, center_node_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -298,9 +303,9 @@ async def get_graph_data_api(
@router.get("/analytics/community_graph", response_model=ApiResponse) @router.get("/analytics/community_graph", response_model=ApiResponse)
async def get_community_graph_data_api( async def get_community_graph_data_api(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -334,9 +339,9 @@ async def get_community_graph_data_api(
@router.get("/read_end_user/profile", response_model=ApiResponse) @router.get("/read_end_user/profile", response_model=ApiResponse)
async def get_end_user_profile( async def get_end_user_profile(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db) workspace_repo = WorkspaceRepository(db)
@@ -385,9 +390,9 @@ async def get_end_user_profile(
@router.post("/updated_end_user/profile", response_model=ApiResponse) @router.post("/updated_end_user/profile", response_model=ApiResponse)
async def update_end_user_profile( async def update_end_user_profile(
profile_update: EndUserProfileUpdate, profile_update: EndUserProfileUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
更新终端用户的基本信息 更新终端用户的基本信息
@@ -427,15 +432,18 @@ async def update_end_user_profile(
# 只有未预期的错误才使用 INTERNAL_ERROR # 只有未预期的错误才使用 INTERNAL_ERROR
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg) return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
@router.get("/memory_space/timeline_memories", response_model=ApiResponse) @router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"), async def memory_space_timeline_of_shared_memories(
current_user: User = Depends(get_current_user), id: str, label: str,
db: Session = Depends(get_db), language_type: str = Header(default=None, alias="X-Language-Type"),
): current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
workspace_id=current_user.current_workspace_id workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db) workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
@@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language) timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
return success(data=timeline_memories_result, msg="共同记忆时间线") return success(data=timeline_memories_result, msg="共同记忆时间线")
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse) @router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(id: str, label: str, async def memory_space_relationship_evolution(id: str, label: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
try: try:
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}") api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")

View File

@@ -598,8 +598,10 @@ class LangChainAgent:
for msg in reversed(output_messages): for msg in reversed(output_messages):
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", total_tokens = response_meta.get("token_usage", {}).get(
0) if response_meta else 0 "total_tokens",
0
) if response_meta else 0
yield total_tokens yield total_tokens
break break
if memory_flag: if memory_flag:

View File

@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "group_1", end_user_id: str = "group_1",
messages: list = None, messages: list = None,
ref_id: str = "wyl_20251027", ref_id: str = "",
config_id: str = None config_id: str = None
) -> List[DialogData]: ) -> List[DialogData]:
"""Generate chunks from structured messages using the specified chunker strategy. """Generate chunks from structured messages using the specified chunker strategy.
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
role = msg['role'] role = msg['role']
content = msg['content'] content = msg['content']
files = msg.get("file_content", [])
if role not in ['user', 'assistant']: if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip(): if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
if not conversation_messages: if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering") raise ValueError("Message list cannot be empty after filtering")

View File

@@ -6,6 +6,7 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
""" """
import asyncio import asyncio
import time import time
import uuid
from datetime import datetime from datetime import datetime
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -13,7 +14,8 @@ from dotenv import load_dotenv
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
memory_summary_generation
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.logging_utils import log_time from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context from app.db import get_db_context
@@ -23,18 +25,17 @@ from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
load_dotenv() load_dotenv()
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
async def write( async def write(
end_user_id: str, end_user_id: str,
memory_config: MemoryConfig, memory_config: MemoryConfig,
messages: list, messages: list,
ref_id: str = "wyl20251027", ref_id: str = "",
language: str = "zh", language: str = "zh",
) -> None: ) -> None:
""" """
Execute the complete knowledge extraction pipeline. Execute the complete knowledge extraction pipeline.
@@ -43,9 +44,11 @@ async def write(
end_user_id: Group identifier end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...] messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027" ref_id: Reference ID, defaults to ""
language: 语言类型 ("zh" 中文, "en" 英文),默认中文 language: 语言类型 ("zh" 中文, "en" 英文),默认中文
""" """
if not ref_id:
ref_id = uuid.uuid4().hex
# Extract config values # Extract config values
embedding_model_id = str(memory_config.embedding_model_id) embedding_model_id = str(memory_config.embedding_model_id)
chunker_strategy = memory_config.chunker_strategy chunker_strategy = memory_config.chunker_strategy
@@ -135,9 +138,11 @@ async def write(
all_chunk_nodes, all_chunk_nodes,
all_statement_nodes, all_statement_nodes,
all_entity_nodes, all_entity_nodes,
all_perceptual_nodes,
all_statement_chunk_edges, all_statement_chunk_edges,
all_statement_entity_edges, all_statement_entity_edges,
all_entity_entity_edges, all_entity_entity_edges,
all_perceptual_edges,
all_dedup_details, all_dedup_details,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False) ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
@@ -162,9 +167,11 @@ async def write(
chunk_nodes=all_chunk_nodes, chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes, statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes, entity_nodes=all_entity_nodes,
perceptual_nodes=all_perceptual_nodes,
statement_chunk_edges=all_statement_chunk_edges, statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges, statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges, entity_edges=all_entity_entity_edges,
perceptual_edges=all_perceptual_edges,
connector=neo4j_connector, connector=neo4j_connector,
) )
if success: if success:
@@ -173,7 +180,8 @@ async def write(
await _trigger_clustering_sync( await _trigger_clustering_sync(
all_entity_nodes, all_entity_nodes,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, embedding_model_id=str(
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
) )
break break
else: else:
@@ -208,9 +216,8 @@ async def write(
summaries = await memory_summary_generation( summaries = await memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
) )
ms_connector = Neo4jConnector()
try: try:
ms_connector = Neo4jConnector()
await add_memory_summary_nodes(summaries, ms_connector) await add_memory_summary_nodes(summaries, ms_connector)
await add_memory_summary_statement_edges(summaries, ms_connector) await add_memory_summary_statement_edges(summaries, ms_connector)
finally: finally:

View File

@@ -1,10 +1,10 @@
from typing import Any, List
import re
import os
import asyncio import asyncio
import json import json
import numpy as np
import logging import logging
import os
from typing import Any, List
import numpy as np
# Fix tokenizer parallelism warning # Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -246,6 +246,7 @@ class ChunkerClient:
"total_sub_chunks": len(sub_chunks), "total_sub_chunks": len(sub_chunks),
"chunker_strategy": self.chunker_config.chunker_strategy, "chunker_strategy": self.chunker_config.chunker_strategy,
}, },
files=msg.files
) )
dialogue.chunks.append(chunk) dialogue.chunks.append(chunk)
else: else:
@@ -258,6 +259,7 @@ class ChunkerClient:
"message_role": msg.role, "message_role": msg.role,
"chunker_strategy": self.chunker_config.chunker_strategy, "chunker_strategy": self.chunker_config.chunker_strategy,
}, },
files=msg.files
) )
dialogue.chunks.append(chunk) dialogue.chunks.append(chunk)

View File

@@ -114,7 +114,7 @@ class Edge(BaseModel):
end_user_id: str = Field(..., description="The end user ID of the edge.") end_user_id: str = Field(..., description="The end user ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
class ChunkEdge(Edge): class ChunkEdge(Edge):
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
return parse_historical_datetime(v) return parse_historical_datetime(v)
class PerceptualEdge(Edge):
"""Edge connecting perceptual nodes to their source chunks
"""
pass
class Node(BaseModel): class Node(BaseModel):
"""Base class for all graph nodes in the knowledge graph. """Base class for all graph nodes in the knowledge graph.
@@ -206,7 +212,8 @@ class DialogueNode(Node):
ref_id: str = Field(..., description="Reference identifier of the dialog") ref_id: str = Field(..., description="Reference identifier of the dialog")
content: str = Field(..., description="Dialogue content") content: str = Field(..., description="Dialogue content")
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector") dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this dialogue (integer or string)")
class StatementNode(Node): class StatementNode(Node):
@@ -281,7 +288,8 @@ class StatementNode(Node):
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this statement (integer or string)")
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
@@ -416,7 +424,8 @@ class ExtractedEntityNode(Node):
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity") # fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this entity (integer or string)")
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
@@ -453,7 +462,7 @@ class ExtractedEntityNode(Node):
@field_validator('aliases', mode='before') @field_validator('aliases', mode='before')
@classmethod @classmethod
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
"""Validate and clean aliases field using utility function. """Validate and clean aliases field using utility function.
This validator ensures that the aliases field is always a valid list of strings. This validator ensures that the aliases field is always a valid list of strings.
@@ -507,7 +516,8 @@ class MemorySummaryNode(Node):
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory") memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this summary (integer or string)")
# ACT-R Forgetting Engine Properties # ACT-R Forgetting Engine Properties
original_statement_id: Optional[str] = Field( original_statement_id: Optional[str] = Field(
@@ -549,3 +559,18 @@ class MemorySummaryNode(Node):
ge=0, ge=0,
description="Total number of times this node has been accessed (reset to 1 on creation)" description="Total number of times this node has been accessed (reset to 1 on creation)"
) )
class PerceptualNode(Node):
"""Node representing a multimodal message in the knowledge graph.
"""
perceptual_type: int
file_path: str
file_name: str
file_ext: str
summary: str
keywords: list[str]
topic: str
domain: str
file_type: str
summary_embedding: list[float] | None

View File

@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
""" """
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
msg: str = Field(..., description="The text content of the message.") msg: str = Field(..., description="The text content of the message.")
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
class TemporalValidityRange(BaseModel): class TemporalValidityRange(BaseModel):
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
content: str = Field(..., description="The content of the chunk as a string.") content: str = Field(..., description="The content of the chunk as a string.")
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).") speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
@classmethod @classmethod

View File

@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def dedup_layers_and_merge_and_return( async def dedup_layers_and_merge_and_return(
dialogue_nodes: List[DialogueNode], dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
pipeline_config: ExtractionPipelineConfig, pipeline_config: ExtractionPipelineConfig,
connector: Optional[Neo4jConnector] = None, connector: Optional[Neo4jConnector] = None,
llm_client = None, llm_client=None,
) -> Tuple[ ) -> Tuple[
List[DialogueNode], List[DialogueNode],
List[ChunkNode], List[ChunkNode],
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
List[StatementChunkEdge], List[StatementChunkEdge],
List[StatementEntityEdge], List[StatementEntityEdge],
List[EntityEntityEdge], List[EntityEntityEdge],
dict, # 新增:返回去重详情 dict
]: ]:
""" """
执行两层实体去重与融合: 执行两层实体去重与融合:

View File

@@ -32,10 +32,11 @@ from app.core.memory.models.graph_models import (
StatementChunkEdge, StatementChunkEdge,
StatementEntityEdge, StatementEntityEdge,
StatementNode, StatementNode,
PerceptualEdge,
PerceptualNode
) )
from app.core.memory.models.message_models import DialogData from app.core.memory.models.message_models import DialogData
from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.models.variate_config import ( from app.core.memory.models.variate_config import (
ExtractionPipelineConfig, ExtractionPipelineConfig,
) )
@@ -46,7 +47,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb
embedding_generation, embedding_generation,
generate_entity_embeddings_from_triplets, generate_entity_embeddings_from_triplets,
) )
# 导入各个提取模块 # 导入各个提取模块
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import ( from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
StatementExtractor, StatementExtractor,
@@ -90,16 +90,16 @@ class ExtractionOrchestrator:
""" """
def __init__( def __init__(
self, self,
llm_client: LLMClient, llm_client: LLMClient,
embedder_client: OpenAIEmbedderClient, embedder_client: OpenAIEmbedderClient,
connector: Neo4jConnector, connector: Neo4jConnector,
config: Optional[ExtractionPipelineConfig] = None, config: Optional[ExtractionPipelineConfig] = None,
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
embedding_id: Optional[str] = None, embedding_id: Optional[str] = None,
ontology_types: Optional[OntologyTypeList] = None, ontology_types: Optional[OntologyTypeList] = None,
enable_general_types: bool = True, enable_general_types: bool = True,
language: str = "zh", language: str = "zh",
): ):
""" """
初始化流水线编排器 初始化流水线编排器
@@ -157,19 +157,27 @@ class ExtractionOrchestrator:
llm_client=llm_client, llm_client=llm_client,
config=self.config.statement_extraction, config=self.config.statement_extraction,
) )
self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language) self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types,
language=language)
self.temporal_extractor = TemporalExtractor(llm_client=llm_client) self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
logger.info("ExtractionOrchestrator 初始化完成") logger.info("ExtractionOrchestrator 初始化完成")
async def run( async def run(
self, self,
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
is_pilot_run: bool = False, is_pilot_run: bool = False,
) -> Tuple[ ) -> tuple[
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], list[DialogueNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[ChunkNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[StatementNode],
list[ExtractedEntityNode],
list[PerceptualNode],
list[StatementChunkEdge],
list[StatementEntityEdge],
list[EntityEntityEdge],
list[PerceptualEdge],
dict
]: ]:
""" """
运行完整的知识提取流水线(优化版:并行执行) 运行完整的知识提取流水线(优化版:并行执行)
@@ -208,7 +216,6 @@ class ExtractionOrchestrator:
for dialog in dialog_data_list: for dialog in dialog_data_list:
for chunk in dialog.chunks: for chunk in dialog.chunks:
all_statements_list.extend(chunk.statements) all_statements_list.extend(chunk.statements)
len(all_statements_list)
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
@@ -230,10 +237,6 @@ class ExtractionOrchestrator:
all_entities_list.extend(triplet_info.entities) all_entities_list.extend(triplet_info.entities)
all_triplets_list.extend(triplet_info.triplets) all_triplets_list.extend(triplet_info.triplets)
len(all_entities_list)
len(all_triplets_list)
sum(len(temporal_map) for temporal_map in temporal_maps)
# 步骤 3: 生成实体嵌入(依赖三元组提取结果) # 步骤 3: 生成实体嵌入(依赖三元组提取结果)
logger.info("步骤 3/6: 生成实体嵌入") logger.info("步骤 3/6: 生成实体嵌入")
triplet_maps = await self._generate_entity_embeddings(triplet_maps) triplet_maps = await self._generate_entity_embeddings(triplet_maps)
@@ -260,9 +263,11 @@ class ExtractionOrchestrator:
chunk_nodes, chunk_nodes,
statement_nodes, statement_nodes,
entity_nodes, entity_nodes,
perceptual_nodes,
statement_chunk_edges, statement_chunk_edges,
statement_entity_edges, statement_entity_edges,
entity_entity_edges, entity_entity_edges,
perceptual_edges
) = await self._create_nodes_and_edges(dialog_data_list) ) = await self._create_nodes_and_edges(dialog_data_list)
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总) # 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
@@ -276,7 +281,16 @@ class ExtractionOrchestrator:
# 注意deduplication 消息已在创建节点和边完成后立即发送 # 注意deduplication 消息已在创建节点和边完成后立即发送
result = await self._run_dedup_and_write_summary( (
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
dialog_data_list,
) = await self._run_dedup_and_write_summary(
dialogue_nodes, dialogue_nodes,
chunk_nodes, chunk_nodes,
statement_nodes, statement_nodes,
@@ -287,17 +301,26 @@ class ExtractionOrchestrator:
dialog_data_list, dialog_data_list,
) )
logger.info(f"知识提取流水线运行完成({mode_str}") logger.info(f"知识提取流水线运行完成({mode_str}")
return result return (
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
perceptual_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
perceptual_edges,
dialog_data_list,
)
except Exception as e: except Exception as e:
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True) logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
raise raise
async def _extract_statements( async def _extract_statements(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[DialogData]: ) -> List[DialogData]:
""" """
从对话中提取陈述句(流式输出版本:边提取边发送进度) 从对话中提取陈述句(流式输出版本:边提取边发送进度)
@@ -395,7 +418,7 @@ class ExtractionOrchestrator:
return dialog_data_list return dialog_data_list
async def _extract_triplets( async def _extract_triplets(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
从对话中提取三元组(流式输出版本:边提取边发送进度) 从对话中提取三元组(流式输出版本:边提取边发送进度)
@@ -478,7 +501,7 @@ class ExtractionOrchestrator:
return triplet_maps return triplet_maps
async def _extract_temporal( async def _extract_temporal(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
从对话中提取时间信息(流式输出版本:边提取边发送进度) 从对话中提取时间信息(流式输出版本:边提取边发送进度)
@@ -585,7 +608,7 @@ class ExtractionOrchestrator:
return temporal_maps return temporal_maps
async def _extract_emotions( async def _extract_emotions(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行) 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
@@ -706,7 +729,7 @@ class ExtractionOrchestrator:
return emotion_maps return emotion_maps
async def _parallel_extract_and_embed( async def _parallel_extract_and_embed(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> Tuple[ ) -> Tuple[
List[Dict[str, Any]], List[Dict[str, Any]],
List[Dict[str, Any]], List[Dict[str, Any]],
@@ -777,7 +800,7 @@ class ExtractionOrchestrator:
) )
async def _generate_basic_embeddings( async def _generate_basic_embeddings(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]: ) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
""" """
生成基础嵌入向量(陈述句、分块、对话) 生成基础嵌入向量(陈述句、分块、对话)
@@ -836,7 +859,7 @@ class ExtractionOrchestrator:
) )
async def _generate_entity_embeddings( async def _generate_entity_embeddings(
self, triplet_maps: List[Dict[str, Any]] self, triplet_maps: List[Dict[str, Any]]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
生成实体嵌入向量 生成实体嵌入向量
@@ -874,17 +897,15 @@ class ExtractionOrchestrator:
logger.error(f"实体嵌入生成失败: {e}", exc_info=True) logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
return triplet_maps return triplet_maps
async def _assign_extracted_data( async def _assign_extracted_data(
self, self,
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
temporal_maps: List[Dict[str, Any]], temporal_maps: List[Dict[str, Any]],
triplet_maps: List[Dict[str, Any]], triplet_maps: List[Dict[str, Any]],
emotion_maps: List[Dict[str, Any]], emotion_maps: List[Dict[str, Any]],
statement_embedding_maps: List[Dict[str, List[float]]], statement_embedding_maps: List[Dict[str, List[float]]],
chunk_embedding_maps: List[Dict[str, List[float]]], chunk_embedding_maps: List[Dict[str, List[float]]],
dialog_embeddings: List[List[float]], dialog_embeddings: List[List[float]],
) -> List[DialogData]: ) -> List[DialogData]:
""" """
将提取的数据赋值到语句 将提取的数据赋值到语句
@@ -906,12 +927,12 @@ class ExtractionOrchestrator:
# 确保列表长度匹配 # 确保列表长度匹配
expected_length = len(dialog_data_list) expected_length = len(dialog_data_list)
if ( if (
len(temporal_maps) != expected_length len(temporal_maps) != expected_length
or len(triplet_maps) != expected_length or len(triplet_maps) != expected_length
or len(emotion_maps) != expected_length or len(emotion_maps) != expected_length
or len(statement_embedding_maps) != expected_length or len(statement_embedding_maps) != expected_length
or len(chunk_embedding_maps) != expected_length or len(chunk_embedding_maps) != expected_length
or len(dialog_embeddings) != expected_length or len(dialog_embeddings) != expected_length
): ):
logger.warning( logger.warning(
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
@@ -999,15 +1020,17 @@ class ExtractionOrchestrator:
return dialog_data_list return dialog_data_list
async def _create_nodes_and_edges( async def _create_nodes_and_edges(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> Tuple[ ) -> Tuple[
List[DialogueNode], List[DialogueNode],
List[ChunkNode], List[ChunkNode],
List[StatementNode], List[StatementNode],
List[ExtractedEntityNode], List[ExtractedEntityNode],
List[PerceptualNode],
List[StatementChunkEdge], List[StatementChunkEdge],
List[StatementEntityEdge], List[StatementEntityEdge],
List[EntityEntityEdge], List[EntityEntityEdge],
List[PerceptualEdge]
]: ]:
""" """
创建图数据库节点和边 创建图数据库节点和边
@@ -1031,6 +1054,8 @@ class ExtractionOrchestrator:
statement_chunk_edges = [] statement_chunk_edges = []
statement_entity_edges = [] statement_entity_edges = []
entity_entity_edges = [] entity_entity_edges = []
perceptual_nodes = []
perceptual_edges = []
# 用于去重的集合 # 用于去重的集合
entity_id_set = set() entity_id_set = set()
@@ -1074,6 +1099,46 @@ class ExtractionOrchestrator:
metadata=chunk.metadata, metadata=chunk.metadata,
) )
chunk_nodes.append(chunk_node) chunk_nodes.append(chunk_node)
logger.error(f"chunk file: {chunk.files}")
for p, file_type in chunk.files:
meta = p.meta_data or {}
content_meta = meta.get("content", {})
# 生成 summary embedding如果有 embedder_client
summary_embedding = None
if self.embedder_client and p.summary:
try:
summary_embedding = (await self.embedder_client.response([p.summary]))[0]
except Exception as emb_err:
print(f"Failed to embed perceptual summary: {emb_err}")
perceptual = PerceptualNode(
name=f"Perceptual_{p.id}",
**{
"id": str(p.id),
"end_user_id": str(p.end_user_id),
"perceptual_type": p.perceptual_type,
"file_path": p.file_path or "",
"file_name": p.file_name or "",
"file_ext": p.file_ext or "",
"summary": p.summary or "",
"keywords": content_meta.get("keywords", []),
"topic": content_meta.get("topic", ""),
"domain": content_meta.get("domain", ""),
"created_at": p.created_time.isoformat() if p.created_time else None,
"file_type": file_type,
"summary_embedding": summary_embedding,
})
perceptual_nodes.append(perceptual)
perceptual_edges.append(PerceptualEdge(
source=perceptual.id,
target=chunk.id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id,
created_at=dialog_data.created_at,
))
# 处理每个陈述句 # 处理每个陈述句
for statement in chunk.statements: for statement in chunk.statements:
@@ -1083,15 +1148,19 @@ class ExtractionOrchestrator:
name=f"Statement_{statement.id}", # 添加必需的 name 字段 name=f"Statement_{statement.id}", # 添加必需的 name 字段
chunk_id=chunk.id, chunk_id=chunk.id,
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL),
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 # 添加必需的 temporal_info 字段
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong',
# 添加必需的 connect_strength 字段
end_user_id=dialog_data.end_user_id, end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement, statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
statement_embedding=statement.statement_embedding, statement_embedding=statement.statement_embedding,
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, valid_at=statement.temporal_validity.valid_at if hasattr(statement,
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, 'temporal_validity') and statement.temporal_validity else None,
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
'temporal_validity') and statement.temporal_validity else None,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at, expired_at=dialog_data.expired_at,
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
@@ -1141,7 +1210,8 @@ class ExtractionOrchestrator:
example=getattr(entity, 'example', ''), # 新增:传递示例字段 example=getattr(entity, 'example', ''), # 新增:传递示例字段
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
# 添加必需的 connect_strength 字段
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None), name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
@@ -1248,25 +1318,32 @@ class ExtractionOrchestrator:
chunk_nodes, chunk_nodes,
statement_nodes, statement_nodes,
entity_nodes, entity_nodes,
perceptual_nodes,
statement_chunk_edges, statement_chunk_edges,
statement_entity_edges, statement_entity_edges,
entity_entity_edges, entity_entity_edges,
perceptual_edges
) )
async def _run_dedup_and_write_summary( async def _run_dedup_and_write_summary(
self, self,
dialogue_nodes: List[DialogueNode], dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
) -> Tuple[ ) -> tuple[
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], list[DialogueNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[ChunkNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[StatementNode],
list[ExtractedEntityNode],
list[StatementChunkEdge],
list[StatementEntityEdge],
list[EntityEntityEdge],
dict
]: ]:
""" """
执行两阶段去重并写入汇总 执行两阶段去重并写入汇总
@@ -1415,7 +1492,6 @@ class ExtractionOrchestrator:
len(entity_entity_edges), len(final_entity_entity_edges) len(entity_entity_edges), len(final_entity_entity_edges)
) )
# 写入提取结果汇总(试运行和正式模式都需要生成) # 写入提取结果汇总(试运行和正式模式都需要生成)
try: try:
from app.core.config import settings from app.core.config import settings
@@ -1436,10 +1512,10 @@ class ExtractionOrchestrator:
raise raise
def _save_dedup_details( def _save_dedup_details(
self, self,
dedup_details: Dict[str, Any], dedup_details: Dict[str, Any],
original_entities: List[ExtractedEntityNode], original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode] final_entities: List[ExtractedEntityNode]
): ):
""" """
保存去重消歧的详细记录到实例变量(基于内存数据结构) 保存去重消歧的详细记录到实例变量(基于内存数据结构)
@@ -1537,15 +1613,16 @@ class ExtractionOrchestrator:
except Exception as e: except Exception as e:
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}") logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") logger.info(
f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
except Exception as e: except Exception as e:
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True) logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
async def _analyze_entity_merges( async def _analyze_entity_merges(
self, self,
original_entities: List[ExtractedEntityNode], original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode] final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件) 分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
@@ -1585,9 +1662,9 @@ class ExtractionOrchestrator:
return [] return []
async def _analyze_entity_disambiguation( async def _analyze_entity_disambiguation(
self, self,
original_entities: List[ExtractedEntityNode], original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode] final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件) 分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
@@ -1645,9 +1722,9 @@ class ExtractionOrchestrator:
return type_mapping.get(entity_type, f"{entity_type}实体节点") return type_mapping.get(entity_type, f"{entity_type}实体节点")
async def _output_relationship_creation_results( async def _output_relationship_creation_results(
self, self,
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
entity_nodes: List[ExtractedEntityNode] entity_nodes: List[ExtractedEntityNode]
): ):
""" """
输出关系创建结果 输出关系创建结果
@@ -1681,13 +1758,13 @@ class ExtractionOrchestrator:
logger.error(f"输出关系创建结果失败: {e}", exc_info=True) logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
async def _send_dedup_progress_callback( async def _send_dedup_progress_callback(
self, self,
original_entities: int, original_entities: int,
final_entities: int, final_entities: int,
original_stmt_edges: int, original_stmt_edges: int,
final_stmt_edges: int, final_stmt_edges: int,
original_ent_edges: int, original_ent_edges: int,
final_ent_edges: int, final_ent_edges: int,
): ):
""" """
发送去重消歧完成的进度回调,传递具体的去重和消歧效果 发送去重消歧完成的进度回调,传递具体的去重和消歧效果
@@ -1715,7 +1792,8 @@ class ExtractionOrchestrator:
"original_count": original_entities, "original_count": original_entities,
"final_count": final_entities, "final_count": final_entities,
"reduced_count": entities_reduced, "reduced_count": entities_reduced,
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0, "reduction_rate": round(entities_reduced / original_entities * 100,
1) if original_entities > 0 else 0,
}, },
"statement_entity_edges": { "statement_entity_edges": {
"original_count": original_stmt_edges, "original_count": original_stmt_edges,
@@ -1790,7 +1868,8 @@ class ExtractionOrchestrator:
disamb_examples.append({ disamb_examples.append({
"entity1_name": entity_name, "entity1_name": entity_name,
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知", "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:",
"").strip() if "vs" in disamb_type else "未知",
"entity2_name": entity_name, "entity2_name": entity_name,
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知", "entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
"description": f"{entity_name},消歧区分成功" "description": f"{entity_name},消歧区分成功"
@@ -1815,9 +1894,9 @@ class ExtractionOrchestrator:
async def get_chunked_dialogs( async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "group_1", end_user_id: str = "group_1",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""从测试数据生成分块对话 """从测试数据生成分块对话
@@ -1924,10 +2003,10 @@ async def get_chunked_dialogs(
def preprocess_data( def preprocess_data(
input_path: Optional[str] = None, input_path: Optional[str] = None,
output_path: Optional[str] = None, output_path: Optional[str] = None,
skip_cleaning: bool = True, skip_cleaning: bool = True,
indices: Optional[List[int]] = None indices: Optional[List[int]] = None
) -> List[DialogData]: ) -> List[DialogData]:
"""数据预处理 """数据预处理
@@ -1946,7 +2025,8 @@ def preprocess_data(
) )
preprocessor = DataPreprocessor() preprocessor = DataPreprocessor()
try: try:
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices) cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path,
skip_cleaning=skip_cleaning, indices=indices)
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
return cleaned_data return cleaned_data
except Exception as e: except Exception as e:
@@ -1955,9 +2035,9 @@ def preprocess_data(
async def get_chunked_dialogs_from_preprocessed( async def get_chunked_dialogs_from_preprocessed(
data: List[DialogData], data: List[DialogData],
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
llm_client: Optional[Any] = None, llm_client: Optional[Any] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""从预处理后的数据中生成分块 """从预处理后的数据中生成分块
@@ -1988,15 +2068,15 @@ async def get_chunked_dialogs_from_preprocessed(
async def get_chunked_dialogs_with_preprocessing( async def get_chunked_dialogs_with_preprocessing(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "default", end_user_id: str = "default",
user_id: str = "default", user_id: str = "default",
apply_id: str = "default", apply_id: str = "default",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
input_data_path: Optional[str] = None, input_data_path: Optional[str] = None,
llm_client: Optional[Any] = None, llm_client: Optional[Any] = None,
skip_cleaning: bool = True, skip_cleaning: bool = True,
pruning_config: Optional[Dict] = None, pruning_config: Optional[Dict] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""包含数据预处理步骤的完整分块流程 """包含数据预处理步骤的完整分块流程
@@ -2046,7 +2126,8 @@ async def get_chunked_dialogs_with_preprocessing(
if pruning_config: if pruning_config:
# 使用传入的配置 # 使用传入的配置
config = PruningConfig(**pruning_config) config = PruningConfig(**pruning_config)
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") logger.debug(
f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
else: else:
# 使用默认配置(关闭剪枝) # 使用默认配置(关闭剪枝)
config = None config = None

View File

@@ -188,7 +188,6 @@ async def _process_chunk_summary(
response_model=MemorySummaryResponse, response_model=MemorySummaryResponse,
) )
summary_text = structured.summary.strip() summary_text = structured.summary.strip()
# Generate title and type for the summary # Generate title and type for the summary
title = None title = None
episodic_type = None episodic_type = None

View File

@@ -390,6 +390,8 @@ class GraphBuilder:
for edge in self.edges: for edge in self.edges:
source = edge.get("source") source = edge.get("source")
target = edge.get("target") target = edge.get("target")
if source not in self.reachable_nodes or target not in self.reachable_nodes:
continue
condition = edge.get("condition") condition = edge.get("condition")
edge_type = edge.get("type") edge_type = edge.get("type")

View File

@@ -12,14 +12,26 @@ class ExecutionContext(BaseModel):
execution_id: str execution_id: str
workspace_id: str workspace_id: str
user_id: str user_id: str
memory_storage_type: str
user_rag_memory_id: str
checkpoint_config: RunnableConfig checkpoint_config: RunnableConfig
@classmethod @classmethod
def create(cls, execution_id: str, workspace_id: str, user_id: str): def create(
cls,
execution_id: str,
workspace_id: str,
user_id: str,
memory_storage_type: str,
user_rag_memory_id: str
):
return cls( return cls(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id,
checkpoint_config=RunnableConfig( checkpoint_config=RunnableConfig(
configurable={ configurable={
"thread_id": uuid.uuid4(), "thread_id": uuid.uuid4(),

View File

@@ -33,6 +33,8 @@ class WorkflowState(dict):
"workspace_id", "workspace_id",
"user_id", "user_id",
"activate", "activate",
"memory_storage_type",
"user_rag_memory_id"
}) })
__optional_keys__ = frozenset({ __optional_keys__ = frozenset({
"error", "error",
@@ -62,6 +64,9 @@ class WorkflowState(dict):
# node activate status # node activate status
activate: Annotated[dict[str, bool], merge_activate_state] activate: Annotated[dict[str, bool], merge_activate_state]
memory_storage_type: str
user_rag_memory_id: str
class WorkflowStateManager: class WorkflowStateManager:
def create_initial_state( def create_initial_state(
@@ -85,7 +90,9 @@ class WorkflowStateManager:
looping=0, looping=0,
activate={ activate={
start_node_id: True start_node_id: True
} },
memory_storage_type=execution_context.memory_storage_type,
user_rag_memory_id=execution_context.user_rag_memory_id
) )
@staticmethod @staticmethod

View File

@@ -13,7 +13,7 @@ from pydantic import BaseModel
from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.variable.variable_objects import T, create_variable_instance from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -373,6 +373,16 @@ class VariablePool:
def copy(self, pool: 'VariablePool'): def copy(self, pool: 'VariablePool'):
self.variables = deepcopy(pool.variables) self.variables = deepcopy(pool.variables)
def is_file_variable(self, selector):
variable_struct = self.get_instance(selector, default=None, strict=False)
if variable_struct is None:
return False
if isinstance(variable_struct, FileVariable):
return True
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
return True
return False
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""导出为字典 """导出为字典

View File

@@ -409,7 +409,9 @@ async def execute_workflow(
input_data: dict[str, Any], input_data: dict[str, Any],
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str user_id: str,
memory_storage_type: str,
user_rag_memory_id: str
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Execute a workflow (convenience function, non-streaming). Execute a workflow (convenience function, non-streaming).
@@ -420,6 +422,8 @@ async def execute_workflow(
execution_id (str): Execution ID. execution_id (str): Execution ID.
workspace_id (str): Workspace ID. workspace_id (str): Workspace ID.
user_id (str): User ID. user_id (str): User ID.
user_rag_memory_id: rag knowledge db id
memory_storage_type: neo4j / rag
Returns: Returns:
dict: Workflow execution result. dict: Workflow execution result.
@@ -427,7 +431,9 @@ async def execute_workflow(
execution_context = ExecutionContext.create( execution_context = ExecutionContext.create(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id user_id=user_id,
memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id
) )
executor = WorkflowExecutor( executor = WorkflowExecutor(
workflow_config=workflow_config, workflow_config=workflow_config,
@@ -441,7 +447,9 @@ async def execute_workflow_stream(
input_data: dict[str, Any], input_data: dict[str, Any],
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str user_id: str,
memory_storage_type: str,
user_rag_memory_id: str
): ):
""" """
Execute a workflow in streaming mode (convenience function). Execute a workflow in streaming mode (convenience function).
@@ -452,6 +460,8 @@ async def execute_workflow_stream(
execution_id (str): Execution ID. execution_id (str): Execution ID.
workspace_id (str): Workspace ID. workspace_id (str): Workspace ID.
user_id (str): User ID. user_id (str): User ID.
user_rag_memory_id: rag knowledge db id
memory_storage_type: neo4j / rag
Yields: Yields:
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end. dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
@@ -459,7 +469,9 @@ async def execute_workflow_stream(
execution_context = ExecutionContext.create( execution_context = ExecutionContext.create(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id user_id=user_id,
memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id
) )
executor = WorkflowExecutor( executor = WorkflowExecutor(
workflow_config=workflow_config, workflow_config=workflow_config,

View File

@@ -623,7 +623,6 @@ class BaseNode(ABC):
async def process_message( async def process_message(
api_config: ModelInfo, api_config: ModelInfo,
content: str | dict | FileObject, content: str | dict | FileObject,
end_user_id: str,
enable_file=False enable_file=False
) -> list | str | None: ) -> list | str | None:
provider = api_config.provider provider = api_config.provider
@@ -642,8 +641,8 @@ class BaseNode(ABC):
return content return content
elif isinstance(content, FileObject): elif isinstance(content, FileObject):
if content.content_cache.get(provider): if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
return content.content_cache[provider] return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
with get_db_read() as db: with get_db_read() as db:
multimodel_service = MultimodalService(db, api_config=api_config) multimodel_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput( file_obj = FileInput(
@@ -655,12 +654,11 @@ class BaseNode(ABC):
) )
file_obj.set_content(content.get_content()) file_obj.set_content(content.get_content())
message = await multimodel_service.process_files( message = await multimodel_service.process_files(
end_user_id,
[file_obj], [file_obj],
) )
content.set_content(file_obj.get_content()) content.set_content(file_obj.get_content())
if message: if message:
content.content_cache[provider] = message content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
return message return message
return None return None
raise TypeError(f'Unexpect input value type - {type(content)}') raise TypeError(f'Unexpect input value type - {type(content)}')

View File

@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}") f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
messages_config = self.typed_config.messages messages_config = self.typed_config.messages
if messages_config: if messages_config:
# 使用 LangChain 消息格式 # 使用 LangChain 消息格式
messages = [] messages = []
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
content_template = msg_config.content content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool) content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool) content = self._render_template(content_template, variable_pool)
user_id = self.get_variable("sys.user_id", variable_pool)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
if role == "system": if role == "system":
messages.append({ messages.append({
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
"content": await self.process_message( "content": await self.process_message(
model_info, model_info,
content, content,
user_id,
self.typed_config.vision, self.typed_config.vision,
) )
}) })
elif role in ["user", "human"]: elif role in ["user", "human"]:
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision) "content": await self.process_message(model_info, content, self.typed_config.vision)
}) })
elif role in ["ai", "assistant"]: elif role in ["ai", "assistant"]:
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision) "content": await self.process_message(model_info, content, self.typed_config.vision)
}) })
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision) "content": await self.process_message(model_info, content, self.typed_config.vision)
}) })
if self.typed_config.vision_input and self.typed_config.vision: if self.typed_config.vision_input and self.typed_config.vision:
file_content = [] file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input) files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value: for file in files.value:
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision) content = await self.process_message(model_info, file.value, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
if messages and messages[-1]["role"] == 'user': if messages and messages[-1]["role"] == 'user':
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list): if isinstance(message["content"], list):
file_content = [] file_content = []
for file in message["content"]: for file in message["content"]:
content = await self.process_message(model_info, file, user_id, self.typed_config.vision) content = await self.process_message(model_info, file, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
history_message.append( history_message.append(
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
message["content"] = await self.process_message( message["content"] = await self.process_message(
model_info, model_info,
message["content"], message["content"],
user_id,
self.typed_config.vision self.typed_config.vision
) )
history_message.append(message) history_message.append(message)

View File

@@ -1,3 +1,4 @@
import re
from typing import Any from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
@@ -5,7 +6,9 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.db import get_db_read from app.db import get_db_read
from app.schemas import FileInput
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task from app.tasks import write_message_task
@@ -36,8 +39,8 @@ class MemoryReadNode(BaseNode):
search_switch=self.typed_config.search_switch, search_switch=self.typed_config.search_switch,
history=[], history=[],
db=db, db=db,
storage_type="neo4j", storage_type=state["memory_storage_type"],
user_rag_memory_id="" user_rag_memory_id=state["user_rag_memory_id"]
) )
@@ -49,6 +52,19 @@ class MemoryWriteNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING}
@staticmethod
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
variable_pattern = re.compile(variable_pattern_string)
variables = variable_pattern.findall(content)
file_variables = []
for variable in variables:
if variable_pool.is_file_variable(variable):
file_variables.append(variable)
for var in file_variables:
content = content.replace(var, "")
return file_variables, content
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = MemoryWriteNodeConfig(**self.config) self.typed_config = MemoryWriteNodeConfig(**self.config)
end_user_id = self.get_variable("sys.user_id", variable_pool) end_user_id = self.get_variable("sys.user_id", variable_pool)
@@ -63,17 +79,42 @@ class MemoryWriteNode(BaseNode):
}) })
for message in self.typed_config.messages: for message in self.typed_config.messages:
file_variables, content = self._extract_multimodal_memory_variables(
message.content,
variable_pool
)
file_info = []
for var in file_variables:
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
if isinstance(instence, FileVariable):
file_info.append(FileInput(
type=instence.value.type,
transfer_method=instence.value.transfer_method,
upload_file_id=instence.value.file_id,
url=instence.value.url,
file_type=instence.value.origin_file_type
).model_dump())
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
for file_instence in instence.value:
file_info.append(FileInput(
type=file_instence.value.type,
transfer_method=file_instence.value.transfer_method,
upload_file_id=file_instence.value.file_id,
url=file_instence.value.url,
file_type=file_instence.value.origin_file_type
).model_dump())
messages.append({ messages.append({
"role": message.role, "role": message.role,
"content": self._render_template(message.content, variable_pool) "content": self._render_template(content, variable_pool),
"files": file_info
}) })
write_message_task.delay( write_message_task.delay(
end_user_id, end_user_id=end_user_id,
messages, message=messages,
str(self.typed_config.config_id), config_id=str(self.typed_config.config_id),
"neo4j", storage_type=state["memory_storage_type"],
"" user_rag_memory_id=state["user_rag_memory_id"]
) )
return "success" return "success"

View File

@@ -30,6 +30,9 @@ class MemoryConfig(Base):
llm_id = Column(String, nullable=True, comment="LLM模型配置ID") llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
audio_id = Column(String, nullable=True, comment="语音模型配置ID")
video_id = Column(String, nullable=True, comment="视频模型配置ID")
# 记忆萃取引擎配置 # 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")

View File

@@ -2,10 +2,11 @@ import datetime
import uuid import uuid
from enum import StrEnum from enum import StrEnum
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text
from sqlalchemy.dialects.postgresql import UUID, JSON from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.sql import func from sqlalchemy.sql import func
from app.db import Base from app.db import Base

View File

@@ -9,21 +9,22 @@ Classes:
""" """
import uuid import uuid
from uuid import UUID
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from uuid import UUID
from sqlalchemy import desc, select
from sqlalchemy.orm import Session
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_config_logger, get_db_logger from app.core.logging_config import get_config_logger, get_db_logger
from app.models.memory_config_model import MemoryConfig from app.models.memory_config_model import MemoryConfig
from app.models.workspace_model import Workspace
from app.schemas.memory_storage_schema import ( from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate, ConfigParamsCreate,
ConfigUpdate, ConfigUpdate,
ConfigUpdateExtracted, ConfigUpdateExtracted,
ConfigUpdateForget, ConfigUpdateForget,
) )
from sqlalchemy import desc, select
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
# 获取数据库专用日志器 # 获取数据库专用日志器
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
return memory_config_obj return memory_config_obj
@staticmethod @staticmethod
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig: def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数) """构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args: Args:
@@ -309,57 +310,21 @@ class MemoryConfigRepository:
Returns: Returns:
Optional[MemoryConfig]: 更新后的配置对象不存在则返回None Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
""" """
db_logger.debug(f"更新萃取配置: config_id={update.config_id}") db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
try: try:
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first() stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id)
db_config = db.execute(stmt).scalar_one_or_none()
if not db_config: if not db_config:
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None return None
# 更新字段映射 update_data = update.model_dump(exclude_unset=True)
field_mapping = { update_data.pop("config_id", None)
# 模型选择
"llm_id": "llm_id",
"embedding_id": "embedding_id",
"rerank_id": "rerank_id",
# 记忆萃取引擎
"enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise",
"enable_llm_disambiguation": "enable_llm_disambiguation",
"deep_retrieval": "deep_retrieval",
"t_type_strict": "t_type_strict",
"t_name_strict": "t_name_strict",
"t_overall": "t_overall",
"state": "state",
"chunker_strategy": "chunker_strategy",
# 句子提取
"statement_granularity": "statement_granularity",
"include_dialogue_context": "include_dialogue_context",
"max_context": "max_context",
# 剪枝配置
"pruning_enabled": "pruning_enabled",
"pruning_scene": "pruning_scene",
"pruning_threshold": "pruning_threshold",
# 自我反思配置
"enable_self_reflexion": "enable_self_reflexion",
"iteration_period": "iteration_period",
"reflexion_range": "reflexion_range",
"baseline": "baseline",
}
has_update = False for field, value in update_data.items():
for api_field, db_field in field_mapping.items(): setattr(db_config, field, value)
value = getattr(update, api_field, None)
if value is not None:
setattr(db_config, db_field, value)
has_update = True
if not has_update:
raise ValueError("No fields to update")
db.commit() db.commit()
db.refresh(db_config) db.refresh(db_config)
@@ -443,6 +408,9 @@ class MemoryConfigRepository:
"llm_id": db_config.llm_id, "llm_id": db_config.llm_id,
"embedding_id": db_config.embedding_id, "embedding_id": db_config.embedding_id,
"rerank_id": db_config.rerank_id, "rerank_id": db_config.rerank_id,
"vision_id": db_config.vision_id,
"audio_id": db_config.audio_id,
"video_id": db_config.video_id,
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise, "enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
"enable_llm_disambiguation": db_config.enable_llm_disambiguation, "enable_llm_disambiguation": db_config.enable_llm_disambiguation,
"deep_retrieval": db_config.deep_retrieval, "deep_retrieval": db_config.deep_retrieval,
@@ -527,7 +495,10 @@ class MemoryConfigRepository:
raise raise
@staticmethod @staticmethod
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]: def get_config_with_workspace(
db: Session,
config_id: uuid.UUID | int | str
) -> Optional[tuple[MemoryConfig, Workspace]]:
"""Get memory config and its associated workspace information """Get memory config and its associated workspace information
Args: Args:
@@ -542,8 +513,6 @@ class MemoryConfigRepository:
""" """
import time import time
from app.models.workspace_model import Workspace
start_time = time.time() start_time = time.time()
config_id = resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
@@ -630,7 +599,7 @@ class MemoryConfigRepository:
db_logger.debug( db_logger.debug(
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}") f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
return (config, workspace) return config, workspace
except ValueError: except ValueError:
# Re-raise known business exceptions # Re-raise known business exceptions
@@ -775,9 +744,9 @@ class MemoryConfigRepository:
@staticmethod @staticmethod
def get_with_fallback( def get_with_fallback(
db: Session, db: Session,
config_id: Optional[uuid.UUID], config_id: Optional[uuid.UUID],
workspace_id: uuid.UUID workspace_id: uuid.UUID
) -> Optional[MemoryConfig]: ) -> Optional[MemoryConfig]:
"""获取记忆配置,支持回退到工作空间默认配置 """获取记忆配置,支持回退到工作空间默认配置
@@ -807,4 +776,3 @@ class MemoryConfigRepository:
) )
return MemoryConfigRepository.get_workspace_default(db, workspace_id) return MemoryConfigRepository.get_workspace_default(db, workspace_id)

View File

@@ -1,14 +1,15 @@
from sqlalchemy.orm import Session, joinedload, selectinload
from sqlalchemy import and_, or_, func, desc, select
from typing import List, Optional, Dict, Any, Tuple
import uuid import uuid
from typing import List, Optional, Dict, Any, Tuple
from sqlalchemy import and_, or_, func, desc
from sqlalchemy.orm import Session, joinedload
from app.core.logging_config import get_db_logger
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
from app.schemas.model_schema import ( from app.schemas.model_schema import (
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery, ModelConfigQueryNew ModelConfigQuery, ModelConfigQueryNew
) )
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器 # 获取数据库专用日志器
db_logger = get_db_logger() db_logger = get_db_logger()
@@ -137,6 +138,9 @@ class ModelConfigRepository:
type_values.append(ModelType.LLM) type_values.append(ModelType.LLM)
filters.append(ModelConfig.type.in_(type_values)) filters.append(ModelConfig.type.in_(type_values))
if query.capability:
filters.append(ModelConfig.capability.contains(query.capability))
if query.is_active is not None: if query.is_active is not None:
filters.append(ModelConfig.is_active == query.is_active) filters.append(ModelConfig.is_active == query.is_active)

View File

@@ -1,8 +1,9 @@
from typing import List, Optional
import logging import logging
from typing import List, Optional
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
MEMORY_SUMMARY_NODE_SAVE
# 使用新的仓储层 # 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -12,9 +13,10 @@ logger = logging.getLogger(__name__)
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
"""Delete all nodes in the database.""" """Delete all nodes in the database."""
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n") result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
print(f"All end_user_id: {end_user_id} node and edge deleted successfully") logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully")
return result return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add dialogue nodes to Neo4j database. """Add dialogue nodes to Neo4j database.
@@ -26,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
List of created node UUIDs or None if failed List of created node UUIDs or None if failed
""" """
if not dialogues: if not dialogues:
print("No dialogues to save") logger.info("No dialogues to save")
return [] return []
try: try:
@@ -51,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
) )
created_uuids = [record["uuid"] for record in result] created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
return created_uuids return created_uuids
except Exception as e: except Exception as e:
print(f"Error creating dialogue nodes: {e}") logger.error(f"Error creating dialogue nodes: {e}")
return None return None
@@ -70,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
List of created node UUIDs or None if failed List of created node UUIDs or None if failed
""" """
if not statements: if not statements:
print("No statements to save") logger.info("No statements to save")
return [] return []
try: try:
@@ -123,13 +125,14 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
) )
created_uuids = [record["uuid"] for record in result] created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} statement nodes") logger.info(f"Successfully created {len(created_uuids)} statement nodes")
return created_uuids return created_uuids
except Exception as e: except Exception as e:
print(f"Error creating statement nodes: {e}") logger.error(f"Error creating statement nodes: {e}")
return None return None
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]: async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add chunk nodes to Neo4j in batch. """Add chunk nodes to Neo4j in batch.
@@ -141,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
List of created chunk UUIDs or None if failed List of created chunk UUIDs or None if failed
""" """
if not chunks: if not chunks:
print("No chunk nodes to add") logger.info("No chunk nodes to add")
return [] return []
try: try:
@@ -174,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
) )
created_uuids = [record["uuid"] for record in result] created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} chunk nodes") logger.info(f"Successfully created {len(created_uuids)} chunk nodes")
return created_uuids return created_uuids
except Exception as e: except Exception as e:
print(f"Error creating chunk nodes: {e}") logger.error(f"Error creating chunk nodes: {e}")
return None return None
async def add_memory_summary_nodes(
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]: summaries: List[MemorySummaryNode],
connector: Neo4jConnector
) -> Optional[List[str]]:
"""Add memory summary nodes to Neo4j in batch. """Add memory summary nodes to Neo4j in batch.
Args: Args:
@@ -194,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
List of created summary node ids or None if failed List of created summary node ids or None if failed
""" """
if not summaries: if not summaries:
print("No memory summary nodes to add") logger.info("No memory summary nodes to add")
return [] return []
try: try:
@@ -225,5 +230,3 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
except Exception as e: except Exception as e:
logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}") logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}")
return None return None

View File

@@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id,
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
# Entity Merge Query # Entity Merge Query
MERGE_ENTITIES = """ MERGE_ENTITIES = """
MATCH (canonical:ExtractedEntity {id: $canonical_id}) MATCH (canonical:ExtractedEntity {id: $canonical_id})
@@ -829,9 +828,8 @@ neo4j_query_all = """
other as entity2 other as entity2
""" """
'''针对当前节点下扩长的句子,实体和总结''' '''针对当前节点下扩长的句子,实体和总结'''
Memory_Timeline_ExtractedEntity=""" Memory_Timeline_ExtractedEntity = """
MATCH (n)-[r1]-(e)-[r2]-(ms) MATCH (n)-[r1]-(e)-[r2]-(ms)
WHERE elementId(n) = $id WHERE elementId(n) = $id
AND (ms:ExtractedEntity OR ms:MemorySummary) AND (ms:ExtractedEntity OR ms:MemorySummary)
@@ -869,7 +867,7 @@ RETURN
""" """
Memory_Timeline_MemorySummary=""" Memory_Timeline_MemorySummary = """
MATCH (n)-[r1]-(e)-[r2]-(ms) MATCH (n)-[r1]-(e)-[r2]-(ms)
WHERE elementId(n) =$id WHERE elementId(n) =$id
AND (ms:MemorySummary OR ms:ExtractedEntity) AND (ms:MemorySummary OR ms:ExtractedEntity)
@@ -904,7 +902,7 @@ RETURN
} }
) AS statement; ) AS statement;
""" """
Memory_Timeline_Statement=""" Memory_Timeline_Statement = """
MATCH (n) MATCH (n)
WHERE elementId(n) = $id WHERE elementId(n) = $id
@@ -947,7 +945,7 @@ RETURN
""" """
'''针对当前节点,主要获取更加完整的句子节点''' '''针对当前节点,主要获取更加完整的句子节点'''
Memory_Space_Emotion_Statement=""" Memory_Space_Emotion_Statement = """
MATCH (n) MATCH (n)
WHERE elementId(n) = $id WHERE elementId(n) = $id
RETURN RETURN
@@ -957,7 +955,7 @@ RETURN
n.statement AS statement; n.statement AS statement;
""" """
Memory_Space_Emotion_MemorySummary=""" Memory_Space_Emotion_MemorySummary = """
MATCH (n)-[]-(e) MATCH (n)-[]-(e)
WHERE elementId(n) = $id WHERE elementId(n) = $id
AND EXISTS { AND EXISTS {
@@ -970,7 +968,7 @@ RETURN DISTINCT
e.emotion_type AS emotion_type, e.emotion_type AS emotion_type,
e.statement AS statement; e.statement AS statement;
""" """
Memory_Space_Emotion_ExtractedEntity=""" Memory_Space_Emotion_ExtractedEntity = """
MATCH (n)-[]-(e) MATCH (n)-[]-(e)
WHERE elementId(n) = $id WHERE elementId(n) = $id
AND EXISTS { AND EXISTS {
@@ -985,18 +983,18 @@ RETURN DISTINCT
'''获取实体''' '''获取实体'''
Memory_Space_User=""" Memory_Space_User = """
MATCH (n)-[r]->(m) MATCH (n)-[r]->(m)
WHERE n.end_user_id = $end_user_id AND m.name="用户" WHERE n.end_user_id = $end_user_id AND m.name="用户"
return DISTINCT elementId(m) as id return DISTINCT elementId(m) as id
""" """
Memory_Space_Entity=""" Memory_Space_Entity = """
MATCH (n)-[]-(m) MATCH (n)-[]-(m)
WHERE elementId(m) = $id AND m.entity_type = "Person" WHERE elementId(m) = $id AND m.entity_type = "Person"
RETURN RETURN
DISTINCT m.name as name,m.end_user_id as end_user_id DISTINCT m.name as name,m.end_user_id as end_user_id
""" """
Memory_Space_Associative=""" Memory_Space_Associative = """
MATCH (u)-[]-(x)-[]-(h) MATCH (u)-[]-(x)-[]-(h)
WHERE elementId(u) = $user_id WHERE elementId(u) = $user_id
AND elementId(h) = $id AND elementId(h) = $id
@@ -1005,61 +1003,69 @@ RETURN DISTINCT
""" """
Graph_Node_query = """ Graph_Node_query = """
MATCH (n:MemorySummary) MATCH (n:MemorySummary)
WHERE n.end_user_id = $end_user_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) AS id, elementId(n) AS id,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS properties, properties(n) AS properties,
0 AS priority 0 AS priority
LIMIT $limit LIMIT $limit
UNION ALL UNION ALL
MATCH (n:Dialogue) MATCH (n:Dialogue)
WHERE n.end_user_id = $end_user_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) AS id, elementId(n) AS id,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS properties, properties(n) AS properties,
1 AS priority 1 AS priority
LIMIT 1 LIMIT 1
UNION ALL UNION ALL
MATCH (n:Statement) MATCH (n:Statement)
WHERE n.end_user_id = $end_user_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) AS id, elementId(n) AS id,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS properties, properties(n) AS properties,
1 AS priority 1 AS priority
LIMIT $limit LIMIT $limit
UNION ALL UNION ALL
MATCH (n:ExtractedEntity) MATCH (n:ExtractedEntity)
WHERE n.end_user_id = $end_user_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) AS id, elementId(n) AS id,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS properties, properties(n) AS properties,
2 AS priority 2 AS priority
LIMIT $limit LIMIT $limit
UNION ALL UNION ALL
MATCH (n:Chunk) MATCH (n:Chunk)
WHERE n.end_user_id = $end_user_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) AS id, elementId(n) AS id,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS properties, properties(n) AS properties,
3 AS priority 3 AS priority
LIMIT $limit LIMIT $limit
""" UNION ALL
MATCH (n:Perceptual)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
4 AS priority
"""
# ============================================================ # ============================================================
# Community 节点 & BELONGS_TO_COMMUNITY 边 # Community 节点 & BELONGS_TO_COMMUNITY 边
@@ -1363,3 +1369,36 @@ RETURN s.statement AS statement,
ORDER BY COALESCE(s.activation_value, 0) DESC ORDER BY COALESCE(s.activation_value, 0) DESC
LIMIT $limit LIMIT $limit
""" """
# 感知记忆节点保存
PERCEPTUAL_NODE_SAVE = """
UNWIND $perceptuals AS p
MERGE (n:Perceptual {id: p.id})
SET n += {
id: p.id,
end_user_id: p.end_user_id,
perceptual_type: p.perceptual_type,
file_path: p.file_path,
file_name: p.file_name,
file_ext: p.file_ext,
summary: p.summary,
keywords: p.keywords,
topic: p.topic,
domain: p.domain,
created_at: p.created_at,
file_type: p.file_type,
summary_embedding: p.summary_embedding
}
RETURN n.id AS uuid
"""
# 感知记忆与对话的关联边
PERCEPTUAL_CHUNK_EDGE_SAVE = """
UNWIND $edges AS edge
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id})
MERGE (c)-[r:HAS_PERCEPTUAL]->(p)
ON CREATE SET r.end_user_id = edge.end_user_id,
r.created_at = edge.created_at
RETURN elementId(r) AS uuid
"""

View File

@@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import (
StatementNode, StatementNode,
ExtractedEntityNode, ExtractedEntityNode,
EntityEntityEdge, EntityEntityEdge,
PerceptualNode,
PerceptualEdge,
) )
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def save_entities_and_relationships( async def save_entities_and_relationships(
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
connector: Neo4jConnector connector: Neo4jConnector
): ):
"""Save entities and their relationships using graph models""" """Save entities and their relationships using graph models"""
all_entities = [entity.model_dump() for entity in entity_nodes] all_entities = [entity.model_dump() for entity in entity_nodes]
@@ -73,8 +78,8 @@ async def save_entities_and_relationships(
async def save_chunk_nodes( async def save_chunk_nodes(
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
connector: Neo4jConnector connector: Neo4jConnector
): ):
"""Save chunk nodes using graph models""" """Save chunk nodes using graph models"""
if not chunk_nodes: if not chunk_nodes:
@@ -89,8 +94,8 @@ async def save_chunk_nodes(
async def save_statement_chunk_edges( async def save_statement_chunk_edges(
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
connector: Neo4jConnector connector: Neo4jConnector
): ):
"""Save statement-chunk edges using graph models""" """Save statement-chunk edges using graph models"""
if not statement_chunk_edges: if not statement_chunk_edges:
@@ -118,8 +123,8 @@ async def save_statement_chunk_edges(
async def save_statement_entity_edges( async def save_statement_entity_edges(
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector connector: Neo4jConnector
): ):
"""Save statement-entity edges using graph models""" """Save statement-entity edges using graph models"""
if not statement_entity_edges: if not statement_entity_edges:
@@ -154,9 +159,11 @@ async def save_dialog_and_statements_to_neo4j(
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
perceptual_nodes: List[PerceptualNode],
entity_edges: List[EntityEntityEdge], entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
perceptual_edges: List[PerceptualEdge],
connector: Neo4jConnector, connector: Neo4jConnector,
) -> bool: ) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
@@ -169,9 +176,11 @@ async def save_dialog_and_statements_to_neo4j(
chunk_nodes: List of ChunkNode objects to save chunk_nodes: List of ChunkNode objects to save
statement_nodes: List of StatementNode objects to save statement_nodes: List of StatementNode objects to save
entity_nodes: List of ExtractedEntityNode objects to save entity_nodes: List of ExtractedEntityNode objects to save
perceptual_nodes: List of PerceptualNode objects to save
entity_edges: List of EntityEntityEdge objects to save entity_edges: List of EntityEntityEdge objects to save
statement_chunk_edges: List of StatementChunkEdge objects to save statement_chunk_edges: List of StatementChunkEdge objects to save
statement_entity_edges: List of StatementEntityEdge objects to save statement_entity_edges: List of StatementEntityEdge objects to save
perceptual_edges: List of PerceptualEdge objects to save
connector: Neo4j connector instance connector: Neo4j connector instance
Returns: Returns:
@@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j(
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data) result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
dialogue_uuids = [record["uuid"] async for record in result] dialogue_uuids = [record["uuid"] async for record in result]
results['dialogues'] = dialogue_uuids results['dialogues'] = dialogue_uuids
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
# 2. Save all chunk nodes in batch # 2. Save all chunk nodes in batch
if chunk_nodes: if chunk_nodes:
@@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j(
results['chunks'] = chunk_uuids results['chunks'] = chunk_uuids
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
if perceptual_nodes:
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE
perceptual_data = [node.model_dump() for node in perceptual_nodes]
result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data)
perceptual_uuids = [record["uuid"] async for record in result]
results["perceptuals"] = perceptual_uuids
logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j")
# 3. Save all statement nodes in batch # 3. Save all statement nodes in batch
if statement_nodes: if statement_nodes:
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
@@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j(
results['statement_entity_edges'] = se_uuids results['statement_entity_edges'] = se_uuids
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
if perceptual_edges:
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE
perceptual_edge_data = []
for edge in perceptual_edges:
print(edge.source, edge.target)
perceptual_edge_data.append({
"perceptual_id": edge.source,
"chunk_id": edge.target,
"end_user_id": edge.end_user_id,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
})
result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data)
perceptual_edges_uuids = [record["uuid"] async for record in result]
results['perceptual_chunk_edges'] = perceptual_edges_uuids
logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j")
return results return results
try: try:
@@ -304,9 +337,9 @@ async def save_dialog_and_statements_to_neo4j(
async def _trigger_clustering_sync( async def _trigger_clustering_sync(
entity_nodes: List, entity_nodes: List,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None,
) -> None: ) -> None:
""" """
同步等待聚类完成,避免与其他 LLM 任务并发冲突。 同步等待聚类完成,避免与其他 LLM 任务并发冲突。
@@ -322,14 +355,15 @@ async def _trigger_clustering_sync(
end_user_id = entity_nodes[0].end_user_id end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes] new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id) await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id)
async def _trigger_clustering( async def _trigger_clustering(
new_entity_ids: List[str], new_entity_ids: List[str],
end_user_id: str, end_user_id: str,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None,
) -> None: ) -> None:
""" """
聚类触发函数,自动判断全量初始化还是增量更新。 聚类触发函数,自动判断全量初始化还是增量更新。

View File

@@ -387,6 +387,12 @@ class MemoryConfig:
rerank_model_id: Optional[UUID] = None rerank_model_id: Optional[UUID] = None
rerank_model_name: Optional[str] = None rerank_model_name: Optional[str] = None
video_model_id: Optional[UUID] = None
video_model_name: Optional[str] = None
vision_model_id: Optional[UUID] = None
vision_model_name: Optional[str] = None
audio_model_id: Optional[UUID] = None
audio_model_name: Optional[str] = None
llm_params: Dict[str, Any] = field(default_factory=dict) llm_params: Dict[str, Any] = field(default_factory=dict)
embedding_params: Dict[str, Any] = field(default_factory=dict) embedding_params: Dict[str, Any] = field(default_factory=dict)

View File

@@ -8,9 +8,6 @@ import uuid
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
# ============================================================================ # ============================================================================
# 从 json_schema.py 迁移的 Schema # 从 json_schema.py 迁移的 Schema
# ============================================================================ # ============================================================================
@@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel):
class ConflictResultSchema(BaseModel): class ConflictResultSchema(BaseModel):
"""Schema for the conflict result data in the reflexion_data.json file.""" """Schema for the conflict result data in the reflexion_data.json file."""
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.") data: List[BaseDataSchema] = Field(...,
description="The conflict memory data. Only contains conflicting records when conflict is True.")
conflict: bool = Field(..., description="Whether the memory is in conflict.") conflict: bool = Field(..., description="Whether the memory is in conflict.")
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") quality_assessment: Optional[QualityAssessmentSchema] = Field(None,
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
memory_verify: Optional[MemoryVerifySchema] = Field(None,
description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
@model_validator(mode="before") @model_validator(mode="before")
def _normalize_data(cls, v): def _normalize_data(cls, v):
@@ -105,12 +105,15 @@ class ChangeRecordSchema(BaseModel):
description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
) )
class ResolvedSchema(BaseModel): class ResolvedSchema(BaseModel):
"""Schema for the resolved memory data in the reflexion_data""" """Schema for the resolved memory data in the reflexion_data"""
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).") # resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None,
change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.") description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
change: Optional[List[ChangeRecordSchema]] = Field(None,
description="List of detailed change records with IDs and field information.")
class SingleReflexionResultSchema(BaseModel): class SingleReflexionResultSchema(BaseModel):
@@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel):
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.") resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
type: str = Field("reflexion_result", description="The type identifier.") type: str = Field("reflexion_result", description="The type identifier.")
class ReflexionResultSchema(BaseModel): class ReflexionResultSchema(BaseModel):
"""Schema for the complete reflexion result data - a list of individual conflict resolutions.""" """Schema for the complete reflexion result data - a list of individual conflict resolutions."""
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") results: List[SingleReflexionResultSchema] = Field(...,
description="List of individual conflict resolution results, grouped by conflict type.")
@model_validator(mode="before") @model_validator(mode="before")
def _normalize_resolved(cls, v): def _normalize_resolved(cls, v):
@@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel):
# Composite key identifying a config row # Composite key identifying a config row
class ConfigKey(BaseModel): # 配置参数键模型 class ConfigKey(BaseModel): # 配置参数键模型
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
user_id: str = Field("user_id", description="用户标识(字符串)") user_id: str | None = Field(default=None, description="用户标识(字符串)")
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)")
# Allowed chunking strategies (extendable later) # Allowed chunking strategies (extendable later)
@@ -241,10 +246,12 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致") reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致")
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致")
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
# config_name: str = Field("配置名称", description="配置名称(字符串)") # config_name: str = Field("配置名称", description="配置名称(字符串)")
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID支持UUID、整数或字符串") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID支持UUID、整数或字符串")
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
@@ -255,8 +262,11 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id:Union[uuid.UUID, int, str] = None config_id: Union[uuid.UUID, int, str] = None
llm_id: Optional[str] = Field(None, description="LLM模型配置ID") llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
audio_id: Optional[str] = Field(None, description="语音模型ID")
vision_id: Optional[str] = Field(None, description="视觉模型ID")
video_id: Optional[str] = Field(None, description="视频模型ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
enable_llm_dedup_blockwise: Optional[bool] = None enable_llm_dedup_blockwise: Optional[bool] = None
@@ -322,14 +332,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
# 遗忘引擎配置参数更新模型 # 遗忘引擎配置参数更新模型
config_id:Union[uuid.UUID, int, str] = None config_id: Union[uuid.UUID, int, str] = None
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5") lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5")
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5") lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5")
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0") offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0")
class ConfigPilotRun(BaseModel): # 试运行触发请求模型 class ConfigPilotRun(BaseModel): # 试运行触发请求模型
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID唯一支持UUID、整数或字符串") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID唯一支持UUID、整数或字符串")
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行") custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行")
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -364,11 +374,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None)
def fail( def fail(
msg: str, msg: str,
error_code: str = "ERROR", error_code: str = "ERROR",
data: Optional[Any] = None, data: Optional[Any] = None,
time: Optional[int] = None, time: Optional[int] = None,
query_preview: Optional[str] = None, query_preview: Optional[str] = None,
) -> ApiResponse: ) -> ApiResponse:
payload = data payload = data
if query_preview is not None: if query_preview is not None:
@@ -387,6 +397,7 @@ def fail(
time=time or _now_ms(), time=time or _now_ms(),
) )
class GenerateCacheRequest(BaseModel): class GenerateCacheRequest(BaseModel):
"""缓存生成请求模型""" """缓存生成请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -432,7 +443,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
"""遗忘引擎配置更新请求模型""" """遗忘引擎配置更新请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识UUID或int)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d") decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数") lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数") lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
@@ -472,7 +483,8 @@ class ForgettingStatsResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标") activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
node_distribution: Dict[str, int] = Field(..., description="节点类型分布") node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据每天取最后一次执行") recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
description="最近7个日期的遗忘趋势数据每天取最后一次执行")
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点") pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点")
timestamp: int = Field(..., description="统计时间(时间戳)") timestamp: int = Field(..., description="统计时间(时间戳)")

View File

@@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase):
updated_at: datetime.datetime updated_at: datetime.datetime
api_keys: List["ModelApiKey"] = [] api_keys: List["ModelApiKey"] = []
@staticmethod
def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str:
if not key or len(key) <= prefix + suffix:
return "*" * len(key)
return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[-suffix:]
@field_validator("api_keys", mode="after") @field_validator("api_keys", mode="after")
@classmethod @classmethod
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]: def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
@@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase):
def _serialize_created_at(self, dt: datetime.datetime | None): def _serialize_created_at(self, dt: datetime.datetime | None):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@field_serializer("api_keys", when_used="json")
def _serialize_api_keys(self, api_keys: List["ModelApiKey"]):
result = []
for api_key in api_keys:
data = api_key.model_dump()
data["api_key"] = self.mask_api_key(api_key.api_key)
result.append(data)
return result
@field_serializer("updated_at", when_used="json") @field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime): def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@@ -166,8 +181,8 @@ class ModelApiKey(ModelApiKeyBase):
self.model_config_ids = [ self.model_config_ids = [
mc.id for mc in self.model_configs mc.id for mc in self.model_configs
if hasattr(mc, 'id') if hasattr(mc, 'id')
and not getattr(mc, 'is_composite', False) and not getattr(mc, 'is_composite', False)
and getattr(mc, 'name', None) == self.model_name and getattr(mc, 'name', None) == self.model_name
] ]
# 情况2字典列表 # 情况2字典列表
elif isinstance(self.model_configs, list): elif isinstance(self.model_configs, list):
@@ -193,7 +208,6 @@ class ModelApiKey(ModelApiKeyBase):
validate_assignment=True # 确保赋值触发校验 validate_assignment=True # 确保赋值触发校验
) )
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel):
"""模型配置查询Schema""" """模型配置查询Schema"""
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)") type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)") provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)")
is_active: Optional[bool] = Field(None, description="激活状态筛选") is_active: Optional[bool] = Field(None, description="激活状态筛选")
is_public: Optional[bool] = Field(None, description="公开状态筛选") is_public: Optional[bool] = Field(None, description="公开状态筛选")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255) search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
@@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel):
is_composite: Optional[bool] = Field(None, description="组合模型筛选") is_composite: Optional[bool] = Field(None, description="组合模型筛选")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255) search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
class ModelMarketplace(BaseModel): class ModelMarketplace(BaseModel):
"""模型广场响应Schema""" """模型广场响应Schema"""
llm_models: List[ModelConfig] = [] llm_models: List[ModelConfig] = []
@@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel):
is_deprecated: Optional[bool] = Field(None, description="是否弃用") is_deprecated: Optional[bool] = Field(None, description="是否弃用")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255) search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
"""模型信息Schema""" """模型信息Schema"""
model_name: str = Field(..., description="模型名称") model_name: str = Field(..., description="模型名称")
@@ -336,4 +353,3 @@ class ModelInfo(BaseModel):
is_omni: bool = Field(default=False, description="是否为omni模型") is_omni: bool = Field(default=False, description="是否为omni模型")
model_type: ModelType = Field(..., description="模型类型") model_type: ModelType = Field(..., description="模型类型")
capability: List[str] = Field(default_factory=list, description="模型能力列表") capability: List[str] = Field(default_factory=list, description="模型能力列表")

View File

@@ -140,7 +140,7 @@ class AppChatService:
processed_files = None processed_files = None
if files: if files:
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件") logger.info(f"处理了 {len(processed_files)} 个文件")
# 调用 Agent支持多模态 # 调用 Agent支持多模态
@@ -343,7 +343,7 @@ class AppChatService:
processed_files = None processed_files = None
if files: if files:
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件") logger.info(f"处理了 {len(processed_files)} 个文件")
# 流式调用 Agent支持多模态同时并行启动 TTS # 流式调用 Agent支持多模态同时并行启动 TTS

View File

@@ -603,7 +603,7 @@ class AgentRunService:
# 获取 provider 信息 # 获取 provider 信息
provider = api_key_config.get("provider", "openai") provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}") logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索 # 7. 知识库检索
@@ -846,7 +846,7 @@ class AgentRunService:
# 获取 provider 信息 # 获取 provider 信息
provider = api_key_config.get("provider", "openai") provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}") logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索 # 7. 知识库检索

View File

@@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
from uuid import UUID from uuid import UUID
import redis import redis
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.cache import InterestMemoryCache
from app.core.config import settings from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import ( from app.core.memory.agent.utils.messages_tools import (
merge_multiple_search_results, merge_multiple_search_results,
reorder_output_results, reorder_output_results,
) )
from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas import FileInput
from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import ( from app.services.memory_konwledges_server import (
write_rag, write_rag,
) )
from app.services.memory_perceptual_service import MemoryPerceptualService
try: try:
from app.core.memory.utils.log.audit_logger import audit_logger from app.core.memory.utils.log.audit_logger import audit_logger
@@ -267,8 +270,16 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources") logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic # LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int, async def write_memory(
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: self,
end_user_id: str,
messages: list[dict],
config_id: Optional[uuid.UUID] | int,
db: Session,
storage_type: str,
user_rag_memory_id: str,
language: str = "zh"
) -> str:
""" """
Process write operation with config_id Process write operation with config_id
@@ -297,8 +308,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None: if config_id is None and workspace_id is None:
raise ValueError( raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") f"Please ensure the user has a connected memory configuration.")
except Exception as e: except Exception as e:
if "No memory configuration found" in str(e): if "No memory configuration found" in str(e):
raise # Re-raise our specific error raise # Re-raise our specific error
@@ -334,45 +345,57 @@ class MemoryAgentService:
raise ValueError(error_msg) raise ValueError(error_msg)
perceptual_serivce = MemoryPerceptualService(db)
for message in messages:
message["file_content"] = []
for file in message["files"]:
file_object = await perceptual_serivce.generate_perceptual_memory(
end_user_id=end_user_id,
memory_config=memory_config,
file=FileInput(**file)
)
if file_object is None:
continue
message["file_content"].append((file_object, file["type"]))
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
try: try:
if storage_type == "rag": if storage_type == "rag":
# For RAG storage, convert messages to single string # For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) await write_rag(end_user_id, message_text, user_rag_memory_id)
result = await write_rag(end_user_id, message_text, user_rag_memory_id) return "success"
return result
else: else:
async with make_write_graph() as graph: await write_neo4j(
config = {"configurable": {"thread_id": end_user_id}} end_user_id=end_user_id,
# Convert structured messages to LangChain messages messages=messages,
langchain_messages = [] memory_config=memory_config,
for msg in messages: ref_id='',
if msg['role'] == 'user': language=language
langchain_messages.append(HumanMessage(content=msg['content'])) )
elif msg['role'] == 'assistant': for lang in ["zh", "en"]:
langchain_messages.append(AIMessage(content=msg['content'])) deleted = await InterestMemoryCache.delete_interest_distribution(
# 初始状态 - 包含所有必要字段 end_user_id, lang
initial_state = { )
"messages": langchain_messages, if deleted:
"end_user_id": end_user_id, logger.info(
"memory_config": memory_config, f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
"language": language for message in messages:
message["file_content"] = [
perceptual[0].file_path for perceptual in message["file_content"]
]
return self.writer_messages_deal(
"success",
start_time,
end_user_id,
config_id,
message_text,
{
"status": "success",
"data": messages,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name
} }
)
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
except Exception as e: except Exception as e:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}" error_msg = f"Write operation failed: {str(e)}"

View File

@@ -38,9 +38,9 @@ class MemoryAPIService:
self.db = db self.db = db
def validate_end_user( def validate_end_user(
self, self,
end_user_id: str, end_user_id: str,
workspace_id: uuid.UUID workspace_id: uuid.UUID
) -> EndUser: ) -> EndUser:
"""Validate that end_user exists and belongs to the workspace. """Validate that end_user exists and belongs to the workspace.
@@ -125,13 +125,13 @@ class MemoryAPIService:
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}") logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
async def write_memory( async def write_memory(
self, self,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
end_user_id: str, end_user_id: str,
message: str, message: str,
config_id: str, config_id: str,
storage_type: str = "neo4j", storage_type: str = "neo4j",
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Write memory with validation. """Write memory with validation.
@@ -171,7 +171,7 @@ class MemoryAPIService:
config_id=config_id, config_id=config_id,
db=self.db, db=self.db,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id or "" user_rag_memory_id=user_rag_memory_id or "",
) )
logger.info(f"Memory write successful for end_user: {end_user_id}") logger.info(f"Memory write successful for end_user: {end_user_id}")
@@ -206,14 +206,14 @@ class MemoryAPIService:
) )
async def read_memory( async def read_memory(
self, self,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
end_user_id: str, end_user_id: str,
message: str, message: str,
search_switch: str = "0", search_switch: str = "0",
config_id: str = "", config_id: str = "",
storage_type: str = "neo4j", storage_type: str = "neo4j",
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Read memory with validation. """Read memory with validation.
@@ -244,7 +244,6 @@ class MemoryAPIService:
# Update end user's memory_config_id # Update end user's memory_config_id
self._update_end_user_config(end_user_id, config_id) self._update_end_user_config(end_user_id, config_id)
try: try:
# Delegate to MemoryAgentService # Delegate to MemoryAgentService
result = await MemoryAgentService().read_memory( result = await MemoryAgentService().read_memory(
@@ -282,8 +281,8 @@ class MemoryAPIService:
) )
def list_memory_configs( def list_memory_configs(
self, self,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""List all memory configs for a workspace. """List all memory configs for a workspace.

View File

@@ -154,10 +154,10 @@ class MemoryConfigService:
self.db = db self.db = db
def load_memory_config( def load_memory_config(
self, self,
config_id: Optional[UUID] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None,
service_name: str = "MemoryConfigService", service_name: str = "MemoryConfigService",
) -> MemoryConfig: ) -> MemoryConfig:
""" """
Load memory configuration from database with optional fallback. Load memory configuration from database with optional fallback.
@@ -194,8 +194,8 @@ class MemoryConfigService:
try: try:
# Use get_config_with_fallback if workspace_id is provided # Use get_config_with_fallback if workspace_id is provided
memory_config = None memory_config = None
validated_config_id = None
if workspace_id: if workspace_id:
validated_config_id = None
if config_id: if config_id:
try: try:
validated_config_id = _validate_config_id(config_id, self.db) validated_config_id = _validate_config_id(config_id, self.db)
@@ -243,10 +243,10 @@ class MemoryConfigService:
# Helper function to validate model with workspace fallback # Helper function to validate model with workspace fallback
def _validate_model_with_fallback( def _validate_model_with_fallback(
model_id: str, model_id: str,
model_type: str, model_type: str,
workspace_default: str, workspace_default: str,
required: bool = False required: bool = False
) -> tuple: ) -> tuple:
"""Validate model ID, falling back to workspace default if invalid. """Validate model ID, falling back to workspace default if invalid.
@@ -343,6 +343,35 @@ class MemoryConfigService:
if memory_config.rerank_id or workspace.rerank: if memory_config.rerank_id or workspace.rerank:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s") logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
vision_uuid, vision_name = validate_and_resolve_model_id(
memory_config.vision_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
audio_uuid, audio_name = validate_and_resolve_model_id(
memory_config.audio_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
video_uuid, video_name = validate_and_resolve_model_id(
memory_config.video_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Create immutable MemoryConfig object # Create immutable MemoryConfig object
config = MemoryConfig( config = MemoryConfig(
config_id=memory_config.config_id, config_id=memory_config.config_id,
@@ -356,6 +385,12 @@ class MemoryConfigService:
embedding_model_name=embedding_name, embedding_model_name=embedding_name,
rerank_model_id=rerank_uuid, rerank_model_id=rerank_uuid,
rerank_model_name=rerank_name, rerank_model_name=rerank_name,
video_model_id=video_uuid,
video_model_name=video_name,
vision_model_id=vision_uuid,
vision_model_name=vision_name,
audio_model_id=audio_uuid,
audio_model_name=audio_name,
storage_type=workspace.storage_type or "neo4j", storage_type=workspace.storage_type or "neo4j",
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker", chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
reflexion_enabled=memory_config.enable_self_reflexion or False, reflexion_enabled=memory_config.enable_self_reflexion or False,
@@ -364,24 +399,31 @@ class MemoryConfigService:
reflexion_baseline=memory_config.baseline or "Time", reflexion_baseline=memory_config.baseline or "Time",
loaded_at=datetime.now(), loaded_at=datetime.now(),
# Pipeline config: Deduplication # Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, enable_llm_dedup_blockwise=bool(
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True, deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8, t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8, t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8, t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction # Pipeline config: Statement extraction
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, statement_granularity=int(
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000, include_dialogue_context=bool(
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(
memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine # Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5, lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5, lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0, offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning # Pipeline config: Pruning
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False, pruning_enabled=bool(
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education", pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, pruning_threshold=float(
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
# Ontology scene association # Ontology scene association
scene_id=memory_config.scene_id, scene_id=memory_config.scene_id,
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id), ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
@@ -598,8 +640,8 @@ class MemoryConfigService:
return None return None
def get_workspace_default_config( def get_workspace_default_config(
self, self,
workspace_id: UUID workspace_id: UUID
) -> Optional["MemoryConfigModel"]: ) -> Optional["MemoryConfigModel"]:
"""Get workspace default memory config. """Get workspace default memory config.
@@ -623,9 +665,9 @@ class MemoryConfigService:
return config return config
def get_config_with_fallback( def get_config_with_fallback(
self, self,
memory_config_id: Optional[UUID], memory_config_id: Optional[UUID],
workspace_id: UUID workspace_id: UUID
) -> Optional["MemoryConfigModel"]: ) -> Optional["MemoryConfigModel"]:
"""Get memory config with fallback to workspace default. """Get memory config with fallback to workspace default.
@@ -663,9 +705,9 @@ class MemoryConfigService:
return config return config
def delete_config( def delete_config(
self, self,
config_id: UUID | int, config_id: UUID | int,
force: bool = False force: bool = False
) -> dict: ) -> dict:
"""Delete memory config with protection against in-use configs. """Delete memory config with protection against in-use configs.
@@ -800,9 +842,9 @@ class MemoryConfigService:
# ==================== 记忆配置提取方法 ==================== # ==================== 记忆配置提取方法 ====================
def extract_memory_config_id( def extract_memory_config_id(
self, self,
app_type: str, app_type: str,
config: dict config: dict
) -> tuple[Optional[uuid.UUID], bool]: ) -> tuple[Optional[uuid.UUID], bool]:
"""从发布配置中提取 memory_config_id根据应用类型分发 """从发布配置中提取 memory_config_id根据应用类型分发
@@ -828,8 +870,8 @@ class MemoryConfigService:
return None, False return None, False
def _extract_memory_config_id_from_agent( def _extract_memory_config_id_from_agent(
self, self,
config: dict config: dict
) -> tuple[Optional[uuid.UUID], bool]: ) -> tuple[Optional[uuid.UUID], bool]:
"""从 Agent 应用配置中提取 memory_config_id """从 Agent 应用配置中提取 memory_config_id
@@ -888,8 +930,8 @@ class MemoryConfigService:
return None, False return None, False
def _extract_memory_config_id_from_workflow( def _extract_memory_config_id_from_workflow(
self, self,
config: dict config: dict
) -> tuple[Optional[uuid.UUID], bool]: ) -> tuple[Optional[uuid.UUID], bool]:
"""从 Workflow 应用配置中提取 memory_config_id """从 Workflow 应用配置中提取 memory_config_id

View File

@@ -341,7 +341,7 @@ async def memory_konwledges_up(
) )
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user) db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") return db_document
async def create_document_chunk( async def create_document_chunk(
@@ -350,7 +350,7 @@ async def create_document_chunk(
create_data: ChunkCreate, create_data: ChunkCreate,
db: Session, db: Session,
current_user: User current_user: User
): ) -> DocumentChunk:
""" """
创建文档块 创建文档块
@@ -439,10 +439,10 @@ async def create_document_chunk(
db_document.chunk_num += 1 db_document.chunk_num += 1
db.commit() db.commit()
return success(data=chunk, msg="文档块创建成功") return chunk
async def write_rag(end_user_id, message, user_rag_memory_id): async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk:
""" """
将消息写入 RAG 知识库 将消息写入 RAG 知识库
@@ -482,11 +482,11 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
print('======', document) print('======', document)
api_logger.info(f"查找文档结果: document_id={document}") api_logger.info(f"查找文档结果: document_id={document}")
create_chunks = ChunkCreate(content=message)
if document is not None: if document is not None:
# 文档已存在,直接添加新块 # 文档已存在,直接添加新块
api_logger.info(f"文档已存在,添加新块: document_id={document}") api_logger.info(f"文档已存在,添加新块: document_id={document}")
create_chunks = ChunkCreate(content=message)
result = await create_document_chunk( result = await create_document_chunk(
kb_id=kb_uuid, kb_id=kb_uuid,
document_id=uuid.UUID(document), document_id=uuid.UUID(document),
@@ -498,13 +498,20 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
else: else:
# 文档不存在,创建新文档 # 文档不存在,创建新文档
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}") api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
result = await memory_konwledges_up( document = await memory_konwledges_up(
kb_id=user_rag_memory_id, kb_id=user_rag_memory_id,
parent_id=user_rag_memory_id, parent_id=user_rag_memory_id,
create_data=create_data, create_data=create_data,
db=db, db=db,
current_user=current_user current_user=current_user
) )
result = await create_document_chunk(
kb_id=kb_uuid,
document_id=document.id,
create_data=create_chunks,
db=db,
current_user=current_user
)
# 重新查询刚创建的文档ID # 重新查询刚创建的文档ID
new_document_id = find_document_id_by_kb_and_filename( new_document_id = find_document_id_by_kb_and_filename(
db=db, db=db,

View File

@@ -12,11 +12,12 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.models import FileMetadata from app.models import FileMetadata, ModelApiKey, ModelType
from app.models.memory_perceptual_model import PerceptualType, FileStorageService from app.models.memory_perceptual_model import PerceptualType, FileStorageService
from app.models.prompt_optimizer_model import RoleType from app.models.prompt_optimizer_model import RoleType
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
from app.schemas import FileType from app.schemas import FileType, FileInput
from app.schemas.memory_config_schema import MemoryConfig
from app.schemas.memory_perceptual_schema import ( from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema, PerceptualQuerySchema,
PerceptualTimelineResponse, PerceptualTimelineResponse,
@@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import (
AudioModal, Content, VideoModal, TextModal AudioModal, Content, VideoModal, TextModal
) )
from app.schemas.model_schema import ModelInfo from app.schemas.model_schema import ModelInfo
from app.services.model_service import ModelApiKeyService
from app.services.multimodal_service import MultimodalService
business_logger = get_business_logger() business_logger = get_business_logger()
@@ -195,21 +198,58 @@ class MemoryPerceptualService:
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
def _get_mutlimodal_client(
self,
file_type: FileType,
config: MemoryConfig
) -> tuple[RedBearLLM | None, ModelApiKey | None]:
model_config = None
if file_type == FileType.AUDIO:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.audio_model_id
)
elif file_type == FileType.VIDEO:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.video_model_id
)
elif file_type == FileType.DOCUMENT:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.llm_model_id
)
elif file_type == FileType.IMAGE:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.vision_model_id
)
llm = None
if model_config:
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_config.model_name,
provider=model_config.provider,
api_key=model_config.api_key,
base_url=model_config.api_base,
is_omni=model_config.is_omni
)
)
return llm, model_config
async def generate_perceptual_memory( async def generate_perceptual_memory(
self, self,
end_user_id: str, end_user_id: str,
model_config: ModelInfo, memory_config: MemoryConfig,
file_type: str, file: FileInput
file_url: str,
file_message: dict,
): ):
memories = self.repository.get_by_url(file_url) memories = self.repository.get_by_url(file.url)
if memories: if memories:
business_logger.info(f"Perceptual memory already exists: {file_url}") business_logger.info(f"Perceptual memory already exists: {file.url}")
if end_user_id not in [memory.end_user_id for memory in memories]: if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0] memory_cache = memories[0]
self.repository.create_perceptual_memory( memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id), end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type), perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path, file_path=memory_cache.file_path,
@@ -219,20 +259,33 @@ class MemoryPerceptualService:
meta_data=memory_cache.meta_data meta_data=memory_cache.meta_data
) )
self.db.commit() self.db.commit()
return memory
return else:
llm = RedBearLLM(RedBearModelConfig( for memory in memories:
if memory.end_user_id == uuid.UUID(end_user_id):
return memory
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
multimodel_service = MultimodalService(self.db, ModelInfo(
model_name=model_config.model_name, model_name=model_config.model_name,
provider=model_config.provider, provider=model_config.provider,
api_key=model_config.api_key, api_key=model_config.api_key,
base_url=model_config.api_base, api_base=model_config.api_base,
is_omni=model_config.is_omni is_omni=model_config.is_omni,
), type=model_config.model_type) capability=model_config.capability,
model_type=ModelType.LLM
))
file_message = await multimodel_service.process_files(
files=[file]
)
if not file_message:
logger.warning(f"Unsupport file type {file}, model capability: {model_config.capability}")
return None
file_message = file_message[0]
try: try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read() opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh') rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
except FileNotFoundError: except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
messages = [ messages = [
@@ -242,8 +295,22 @@ class MemoryPerceptualService:
]} ]}
] ]
result = await llm.ainvoke(messages) result = await llm.ainvoke(messages)
content = json_repair.repair_json(result.content, return_objects=True) content = result.content
path = urlparse(file_url).path final_output = ""
if isinstance(content, list):
for msg in content:
if isinstance(msg, dict):
final_output += msg.get("text", "")
elif isinstance(msg, str):
final_output += msg
elif isinstance(content, dict):
final_output += content.get("text", "")
elif isinstance(content, str):
final_output = content
else:
raise ValueError(f"Unexcept Model Output Type: {result.content}")
content = json_repair.repair_json(final_output, return_objects=True)
path = urlparse(file.url).path
filename = os.path.basename(path) filename = os.path.basename(path)
filename = unquote(filename) filename = unquote(filename)
file_ext = os.path.splitext(filename)[1] file_ext = os.path.splitext(filename)[1]
@@ -252,21 +319,21 @@ class MemoryPerceptualService:
stmt = select(FileMetadata).where( stmt = select(FileMetadata).where(
FileMetadata.id == file_id FileMetadata.id == file_id
) )
file = self.db.execute(stmt).scalar_one_or_none() file_obj = self.db.execute(stmt).scalar_one_or_none()
if file: if file_obj:
filename = file.file_name filename = file_obj.file_name
file_ext = file.file_ext file_ext = file_obj.file_ext
except ValueError: except ValueError:
business_logger.debug(f"Remote file, file_id={filename}") business_logger.debug(f"Remote file, file_id={filename}")
if not file_ext: if not file_ext:
if file_type == FileType.AUDIO: if file.type == FileType.AUDIO:
file_ext = ".mp3" file_ext = ".mp3"
elif file_type == FileType.VIDEO: elif file.type == FileType.VIDEO:
file_ext = ".mp4" file_ext = ".mp4"
elif file_type == FileType.DOCUMENT: elif file.type == FileType.DOCUMENT:
file_ext = ".txt" file_ext = ".txt"
elif file_type == FileType.IMAGE: elif file.type == FileType.IMAGE:
file_ext = ".jpg" file_ext = ".jpg"
filename += file_ext filename += file_ext
file_content = { file_content = {
@@ -274,11 +341,11 @@ class MemoryPerceptualService:
"topic": content.get("topic"), "topic": content.get("topic"),
"domain": content.get("domain") "domain": content.get("domain")
} }
if file_type in [FileType.IMAGE, FileType.VIDEO]: if file.type in [FileType.IMAGE, FileType.VIDEO]:
file_modalities = { file_modalities = {
"scene": content.get("scene", []) "scene": content.get("scene", [])
} }
elif file_type in [FileType.DOCUMENT]: elif file.type in [FileType.DOCUMENT]:
file_modalities = { file_modalities = {
"section_count": content.get("section_count", 0), "section_count": content.get("section_count", 0),
"title": content.get("title", ""), "title": content.get("title", ""),
@@ -288,10 +355,10 @@ class MemoryPerceptualService:
file_modalities = { file_modalities = {
"speaker_count": content.get("speaker_count", 0) "speaker_count": content.get("speaker_count", 0)
} }
self.repository.create_perceptual_memory( memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id), end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType.trans_from_file_type(file_type), perceptual_type=PerceptualType.trans_from_file_type(file.type),
file_path=file_url, file_path=file.url,
file_name=filename, file_name=filename,
file_ext=file_ext, file_ext=file_ext,
summary=content.get('summary', ""), summary=content.get('summary', ""),
@@ -301,3 +368,4 @@ class MemoryPerceptualService:
} }
) )
self.db.commit() self.db.commit()
return memory

View File

@@ -11,9 +11,11 @@ import time
from datetime import datetime from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from app.core.logging_config import get_config_logger, get_logger from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.analytics.hot_memory_tags import ( from app.core.memory.analytics.hot_memory_tags import (
get_hot_memory_tags,
get_raw_tags_from_db, get_raw_tags_from_db,
filter_tags_with_llm, filter_tags_with_llm,
) )
@@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import (
) )
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.utils.sse_utils import format_sse_message from app.utils.sse_utils import format_sse_message
from dotenv import load_dotenv
from sqlalchemy.orm import Session
logger = get_logger(__name__) logger = get_logger(__name__)
config_logger = get_config_logger() config_logger = get_config_logger()
@@ -69,7 +69,7 @@ class MemoryStorageService:
return result return result
class DataConfigService: # 数据配置服务类PostgreSQL class DataConfigService: # 数据配置服务类PostgreSQL
"""Service layer for config params CRUD. """Service layer for config params CRUD.
使用 SQLAlchemy ORM 进行数据库操作。 使用 SQLAlchemy ORM 进行数据库操作。
@@ -114,7 +114,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
return data_list return data_list
# --- Create --- # --- Create ---
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
# 业务层检查同一工作空间下是否已存在同名配置 # 业务层检查同一工作空间下是否已存在同名配置
if params.workspace_id and params.config_name: if params.workspace_id and params.config_name:
from app.models.memory_config_model import MemoryConfig from app.models.memory_config_model import MemoryConfig
@@ -183,20 +183,20 @@ class DataConfigService: # 数据配置服务类PostgreSQL
return None return None
# --- Delete --- # --- Delete ---
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数按配置ID def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数按配置ID
success = MemoryConfigRepository.delete(self.db, key.config_id) success = MemoryConfigRepository.delete(self.db, key.config_id)
if not success: if not success:
raise ValueError("未找到配置") raise ValueError("未找到配置")
return {"affected": 1} return {"affected": 1}
# --- Update --- # --- Update ---
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
config = MemoryConfigRepository.update(self.db, update) config = MemoryConfigRepository.update(self.db, update)
if not config: if not config:
raise ValueError("未找到配置") raise ValueError("未找到配置")
return {"affected": 1} return {"affected": 1}
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
config = MemoryConfigRepository.update_extracted(self.db, update) config = MemoryConfigRepository.update_extracted(self.db, update)
if not config: if not config:
raise ValueError("未找到配置") raise ValueError("未找到配置")
@@ -207,14 +207,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config() # 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config()
# --- Read --- # --- Read ---
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id) result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id)
if not result: if not result:
raise ValueError("未找到配置") raise ValueError("未找到配置")
return result return result
# --- Read All --- # --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数
results = MemoryConfigRepository.get_all(self.db, workspace_id) results = MemoryConfigRepository.get_all(self.db, workspace_id)
# 检查并修正 pruning_scene 与 scene_name 不一致的记录 # 检查并修正 pruning_scene 与 scene_name 不一致的记录
@@ -241,11 +241,10 @@ class DataConfigService: # 数据配置服务类PostgreSQL
except (ValueError, TypeError): except (ValueError, TypeError):
config_id_old = None config_id_old = None
if config_id_old: if config_id_old:
memory_config=config_id_old memory_config = config_id_old
else: else:
memory_config=config.config_id memory_config = config.config_id
config_dict = { config_dict = {
"config_id": memory_config, "config_id": memory_config,
"config_name": config.config_name, "config_name": config.config_name,
@@ -289,7 +288,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式 # 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
return self._convert_timestamps_to_format(data_list) return self._convert_timestamps_to_format(data_list)
async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]: async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]:
""" """
流式执行试运行,产生 SSE 格式的进度事件 流式执行试运行,产生 SSE 格式的进度事件
@@ -344,11 +342,13 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 关联了本体场景,优先使用 custom_text # 关联了本体场景,优先使用 custom_text
if hasattr(payload, 'custom_text') and payload.custom_text: if hasattr(payload, 'custom_text') and payload.custom_text:
dialogue_text = payload.custom_text.strip() dialogue_text = payload.custom_text.strip()
logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") logger.info(
f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
else: else:
# 如果没有提供 custom_text回退到 dialogue_text # 如果没有提供 custom_text回退到 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") logger.info(
f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
else: else:
# 没有关联本体场景,使用 dialogue_text # 没有关联本体场景,使用 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
@@ -360,7 +360,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}") logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
# 步骤 2: 创建进度回调函数捕获管线进度 # 步骤 2: 创建进度回调函数捕获管线进度
# 使用队列在回调和生成器之间传递进度事件 # 使用队列在回调和生成器之间传递进度事件
progress_queue: asyncio.Queue = asyncio.Queue() progress_queue: asyncio.Queue = asyncio.Queue()
@@ -382,7 +381,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
try: try:
from app.services.pilot_run_service import run_pilot_extraction from app.services.pilot_run_service import run_pilot_extraction
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") logger.info(
f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
await run_pilot_extraction( await run_pilot_extraction(
memory_config=memory_config, memory_config=memory_config,
dialogue_text=dialogue_text, dialogue_text=dialogue_text,
@@ -483,11 +483,10 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"time": int(time.time() * 1000) "time": int(time.time() * 1000)
}) })
async def _compute_ontology_coverage( async def _compute_ontology_coverage(
self, self,
extracted_result: Dict[str, Any], extracted_result: Dict[str, Any],
memory_config, memory_config,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。 """根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
@@ -580,8 +579,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) -------------------- # -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
# Ensure env for connector (e.g., NEO4J_PASSWORD) # Ensure env for connector (e.g., NEO4J_PASSWORD)
load_dotenv()
_neo4j_connector = Neo4jConnector()
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
@@ -701,10 +698,11 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
) )
return result return result
async def analytics_hot_memory_tags( async def analytics_hot_memory_tags(
db: Session, db: Session,
current_user: User, current_user: User,
limit: int = 10 limit: int = 10
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
获取热门记忆标签按数量排序并返回前N个 获取热门记忆标签按数量排序并返回前N个
@@ -815,11 +813,11 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
source = "log" source = "log"
total = ( total = (
stats.get("chunk_count", 0) stats.get("chunk_count", 0)
+ stats.get("statements_count", 0) + stats.get("statements_count", 0)
+ stats.get("triplet_entities_count", 0) + stats.get("triplet_entities_count", 0)
+ stats.get("triplet_relations_count", 0) + stats.get("triplet_relations_count", 0)
+ stats.get("temporal_count", 0) + stats.get("temporal_count", 0)
) )
# 计算"最新一次活动多久前"(仅日志来源时有效) # 计算"最新一次活动多久前"(仅日志来源时有效)
@@ -845,5 +843,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source} data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
return data return data

View File

@@ -9,17 +9,15 @@
- OpenAI: 支持 URL 和 base64 格式 - OpenAI: 支持 URL 和 base64 格式
""" """
import base64 import base64
import csv
import io import io
import uuid import json
import zipfile import zipfile
import chardet
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
import csv
import json
import PyPDF2 import PyPDF2
import chardet
import httpx import httpx
import magic import magic
import openpyxl import openpyxl
@@ -35,7 +33,6 @@ from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.schemas.model_schema import ModelInfo from app.schemas.model_schema import ModelInfo
from app.services.audio_transcription_service import AudioTranscriptionService from app.services.audio_transcription_service import AudioTranscriptionService
from app.tasks import write_perceptual_memory
logger = get_business_logger() logger = get_business_logger()
@@ -342,15 +339,12 @@ class MultimodalService:
async def process_files( async def process_files(
self, self,
end_user_id: uuid.UUID | str,
files: Optional[List[FileInput]], files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
处理文件列表,返回 LLM 可用的格式 处理文件列表,返回 LLM 可用的格式
Args: Args:
end_user_id: 用户ID
files: 文件输入列表 files: 文件输入列表
Returns: Returns:
@@ -358,8 +352,6 @@ class MultimodalService:
""" """
if not files: if not files:
return [] return []
if isinstance(end_user_id, uuid.UUID):
end_user_id = str(end_user_id)
# 获取对应的策略 # 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式 # dashscope 的 omni 模型使用 OpenAI 兼容格式
@@ -380,23 +372,15 @@ class MultimodalService:
if file.type == FileType.IMAGE and "vision" in self.capability: if file.type == FileType.IMAGE and "vision" in self.capability:
is_support, content = await self._process_image(file, strategy) is_support, content = await self._process_image(file, strategy)
result.append(content) result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.DOCUMENT: elif file.type == FileType.DOCUMENT:
is_support, content = await self._process_document(file, strategy) is_support, content = await self._process_document(file, strategy)
result.append(content) result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.AUDIO and "audio" in self.capability: elif file.type == FileType.AUDIO and "audio" in self.capability:
is_support, content = await self._process_audio(file, strategy) is_support, content = await self._process_audio(file, strategy)
result.append(content) result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.VIDEO and "video" in self.capability: elif file.type == FileType.VIDEO and "video" in self.capability:
is_support, content = await self._process_video(file, strategy) is_support, content = await self._process_video(file, strategy)
result.append(content) result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
else: else:
logger.warning(f"不支持的文件类型: {file.type}") logger.warning(f"不支持的文件类型: {file.type}")
except Exception as e: except Exception as e:
@@ -483,17 +467,6 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}") logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result return result
def write_perceptual_memory(
self,
end_user_id: str,
file_type: str,
file_url: str,
file_message: dict
):
"""写入感知记忆"""
if end_user_id and self.api_config:
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
""" """
处理图片文件 处理图片文件

View File

@@ -297,9 +297,12 @@ async def run_pilot_extraction(
chunk_nodes, chunk_nodes,
statement_nodes, statement_nodes,
entity_nodes, entity_nodes,
_,
statement_chunk_edges, statement_chunk_edges,
statement_entity_edges, statement_entity_edges,
entity_edges, entity_edges,
_,
_
) = extraction_result ) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file) log_time("Extraction Pipeline", time.time() - step_start, log_file)

View File

@@ -1887,7 +1887,8 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
"Chunk": ["content", "created_at"], "Chunk": ["content", "created_at"],
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"], "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"], "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 "MemorySummary": ["summary", "content", "created_at", "caption"], # 添加 content 字段
"Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"]
} }
# 获取该节点类型的白名单字段 # 获取该节点类型的白名单字段

View File

@@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db from app.db import get_db
from app.models import App from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories import knowledge_repository
from app.repositories.workflow_repository import ( from app.repositories.workflow_repository import (
WorkflowConfigRepository, WorkflowConfigRepository,
WorkflowExecutionRepository, WorkflowExecutionRepository,
@@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.services.workspace_service import get_workspace_storage_type_without_auth
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -540,6 +542,25 @@ class WorkflowService:
mapped = internal_event mapped = internal_event
return mapped return mapped
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
user_rag_memory_id = ""
if storage_type == "rag":
knowledge = knowledge_repository.get_knowledge_by_name(
db=self.db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
logger.warning(
f"No knowledge base named 'USER_RAG_MEMORY' found, "
f"workspace_id: {workspace_id}, will use neo4j storage"
)
storage_type = 'neo4j'
return storage_type, user_rag_memory_id
# ==================== 工作流执行 ==================== # ==================== 工作流执行 ====================
async def run( async def run(
@@ -607,6 +628,7 @@ class WorkflowService:
try: try:
files = await self._handle_file_input(payload.files) files = await self._handle_file_input(payload.files)
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files input_data["files"] = files
message_id = uuid.uuid4() message_id = uuid.uuid4()
# 更新状态为运行中 # 更新状态为运行中
@@ -631,7 +653,9 @@ class WorkflowService:
input_data=input_data, input_data=input_data,
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id=str(workspace_id), workspace_id=str(workspace_id),
user_id=payload.user_id user_id=payload.user_id,
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
) )
# 更新执行结果 # 更新执行结果
if result.get("status") == "completed": if result.get("status") == "completed":
@@ -780,6 +804,7 @@ class WorkflowService:
try: try:
files = await self._handle_file_input(payload.files) files = await self._handle_file_input(payload.files)
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files input_data["files"] = files
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
@@ -801,6 +826,8 @@ class WorkflowService:
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id=str(workspace_id), workspace_id=str(workspace_id),
user_id=payload.user_id, user_id=payload.user_id,
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
): ):
if event.get("event") == "workflow_end": if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status") status = event.get("data", {}).get("status")

View File

@@ -863,7 +863,7 @@ def get_workspace_storage_type(
def get_workspace_storage_type_without_auth( def get_workspace_storage_type_without_auth(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
) -> Optional[str]: ) -> str:
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景) """获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
Args: Args:

View File

@@ -1073,9 +1073,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
@celery_app.task(name="app.core.memory.agent.write_message", bind=True) @celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, def write_message_task(
user_rag_memory_id: str, self,
language: str = "zh") -> Dict[str, Any]: end_user_id: str,
message: list[dict],
config_id: str | int,
storage_type: str,
user_rag_memory_id: str,
language: str = "zh"
) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService. """Celery task to process a write message via MemoryAgentService.
Args: Args:
end_user_id: Group ID for the memory agent (also used as end_user_id) end_user_id: Group ID for the memory agent (also used as end_user_id)
@@ -1091,7 +1097,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
Raises: Raises:
Exception on failure Exception on failure
""" """
logger.info( logger.info(
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
f"config_id={config_id} (type: {type(config_id).__name__}), " f"config_id={config_id} (type: {type(config_id).__name__}), "
@@ -1105,14 +1110,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
try: try:
with get_db_context() as db: with get_db_context() as db:
actual_config_id = resolve_config_id(config_id, db) actual_config_id = resolve_config_id(config_id, db)
print(100 * '-') logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} "
print(actual_config_id) f"(type: {type(actual_config_id).__name__})")
print(100 * '-')
logger.info(
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
except (ValueError, AttributeError) as e: except (ValueError, AttributeError) as e:
logger.error( logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} "
f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") f"(type: {type(config_id).__name__}), error: {e}")
return { return {
"status": "FAILURE", "status": "FAILURE",
"error": f"Invalid config_id format: {config_id} - {str(e)}", "error": f"Invalid config_id format: {config_id} - {str(e)}",
@@ -1151,8 +1153,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info( logger.info(f"[CELERY WRITE] Task completed successfully "
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
try: try:
@@ -1167,7 +1169,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
) )
except Exception as _e: except Exception as _e:
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}") logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
return { return {
"status": "SUCCESS", "status": "SUCCESS",
"result": result, "result": result,
@@ -2611,57 +2612,6 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
} }
@celery_app.task(
name="app.tasks.write_perceptual_memory",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=3600,
soft_time_limit=3300,
)
def write_perceptual_memory(
self,
end_user_id: str,
model_api_config: dict,
file_type: str,
file_url: str,
file_message: dict
):
"""
Write perceptual memory for a user into PostgreSQL and Neo4j.
This task generates or updates the user's perceptual memory
in the backend databases. It is intended to be executed asynchronously
via Celery.
Args:
end_user_id (uuid.UUID): The unique identifier of the end user.
model_api_config (ModelInfo): API configuration for the model
used to generate perceptual memory.
file_type (str): The file type
file_url (url): The url of file
file_message (dict): The file message containing details about the file
to be processed.
Returns:
None
"""
file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
set_asyncio_event_loop()
with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
model_info = ModelInfo(**model_api_config)
with get_db_context() as db:
memory_perceptual_service = MemoryPerceptualService(db)
return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
end_user_id,
model_info,
file_type,
file_url,
file_message,
))
# ============================================================================= # =============================================================================
# 社区聚类补全任务(触发型) # 社区聚类补全任务(触发型)
# ============================================================================= # =============================================================================
@@ -2672,7 +2622,7 @@ def write_perceptual_memory(
ignore_result=False, ignore_result=False,
max_retries=0, max_retries=0,
acks_late=False, acks_late=False,
time_limit=7200, # 2小时硬超时 time_limit=7200, # 2小时硬超时
soft_time_limit=6900, soft_time_limit=6900,
) )
def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]: def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]:
@@ -2787,7 +2737,8 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
embedding_model_id=embedding_model_id, embedding_model_id=embedding_model_id,
) )
logger.info(f"[CommunityCluster] 用户 {end_user_id}{len(entities)} 个实体开始全量聚类llm_model_id={llm_model_id}") logger.info(
f"[CommunityCluster] 用户 {end_user_id}{len(entities)} 个实体开始全量聚类llm_model_id={llm_model_id}")
await engine.full_clustering(end_user_id) await engine.full_clustering(end_user_id)
initialized += 1 initialized += 1
logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成") logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")
@@ -2810,12 +2761,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
} }
try: try:
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
loop = set_asyncio_event_loop() loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time result["elapsed_time"] = time.time() - start_time