Merge branch 'develop' into feature/ui_upgrade_zy

This commit is contained in:
zhaoying
2026-03-20 11:49:00 +08:00
286 changed files with 23406 additions and 5328 deletions

View File

@@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \ apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \ apt install -y libjemalloc-dev && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \ apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y ghostscript apt install -y ghostscript && \
apt install -y libmagic1
RUN if [ "$NEED_MIRROR" == "1" ]; then \ RUN if [ "$NEED_MIRROR" == "1" ]; then \
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \ pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \

View File

@@ -60,7 +60,12 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# are written from script.py.mako # are written from script.py.mako
# output_encoding = utf-8 # output_encoding = utf-8
sqlalchemy.url = postgresql://user:password@localhost/dbname # Database connection URL - DO NOT hardcode credentials here!
# Connection string is set dynamically from environment variables in migrations/env.py
# Required env vars: DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
# Example: postgresql://user:password@localhost:5432/dbname
; sqlalchemy.url = postgresql://user:password@host:port/dbname
sqlalchemy.url = driver://user:password@host:port/dbname
[post_write_hooks] [post_write_hooks]

View File

@@ -1,10 +1,11 @@
import os
import asyncio import asyncio
import json import json
import logging import logging
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
import redis.asyncio as redis import redis.asyncio as redis
from redis.asyncio import ConnectionPool from redis.asyncio import ConnectionPool
from app.core.config import settings from app.core.config import settings
# 设置日志记录器 # 设置日志记录器

View File

@@ -63,9 +63,9 @@ celery_app.conf.update(
accept_content=['json'], accept_content=['json'],
result_serializer='json', result_serializer='json',
# 时区 # # 时区
timezone='Asia/Shanghai', # timezone='Asia/Shanghai',
enable_utc=True, # enable_utc=False,
# 任务追踪 # 任务追踪
task_track_started=True, task_track_started=True,
@@ -96,6 +96,7 @@ celery_app.conf.update(
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'}, 'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
# Long-term storage tasks → memory_tasks queue (batched write strategies) # Long-term storage tasks → memory_tasks queue (batched write strategies)
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
@@ -113,6 +114,9 @@ celery_app.conf.update(
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'}, 'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'}, 'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'}, 'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
}, },
) )
@@ -129,7 +133,7 @@ implicit_emotions_update_schedule = crontab(
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE, minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
) )
#构建定时任务配置 # 构建定时任务配置
beat_schedule_config = { beat_schedule_config = {
"run-workspace-reflection": { "run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task", "task": "app.tasks.workspace_reflection_task",

View File

@@ -16,6 +16,7 @@ from . import (
file_controller, file_controller,
file_storage_controller, file_storage_controller,
home_page_controller, home_page_controller,
i18n_controller,
implicit_memory_controller, implicit_memory_controller,
knowledge_controller, knowledge_controller,
knowledgeshare_controller, knowledgeshare_controller,
@@ -94,5 +95,6 @@ manager_router.include_router(memory_working_controller.router)
manager_router.include_router(file_storage_controller.router) manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router) manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router) manager_router.include_router(skill_controller.router)
manager_router.include_router(i18n_controller.router)
__all__ = ["manager_router"] __all__ = ["manager_router"]

View File

@@ -1,10 +1,12 @@
import uuid import uuid
import io
from typing import Optional, Annotated from typing import Optional, Annotated
import yaml import yaml
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from urllib.parse import quote
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
@@ -25,6 +27,7 @@ from app.services.app_service import AppService
from app.services.app_statistics_service import AppStatisticsService from app.services.app_statistics_service import AppStatisticsService
from app.services.workflow_import_service import WorkflowImportService from app.services.workflow_import_service import WorkflowImportService
from app.services.workflow_service import WorkflowService, get_workflow_service from app.services.workflow_service import WorkflowService, get_workflow_service
from app.services.app_dsl_service import AppDslService
router = APIRouter(prefix="/apps", tags=["Apps"]) router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger() logger = get_business_logger()
@@ -50,6 +53,7 @@ def list_apps(
status: str | None = None, status: str | None = None,
search: str | None = None, search: str | None = None,
include_shared: bool = True, include_shared: bool = True,
shared_only: bool = False,
page: int = 1, page: int = 1,
pagesize: int = 10, pagesize: int = 10,
ids: Optional[str] = None, ids: Optional[str] = None,
@@ -81,6 +85,7 @@ def list_apps(
status=status, status=status,
search=search, search=search,
include_shared=include_shared, include_shared=include_shared,
shared_only=shared_only,
page=page, page=page,
pagesize=pagesize, pagesize=pagesize,
) )
@@ -90,6 +95,37 @@ def list_apps(
return success(data=PageData(page=meta, items=items)) return success(data=PageData(page=meta, items=items))
@router.get("/my-shared-out", summary="列出本工作空间主动分享出去的记录")
@cur_workspace_access_guard()
def list_my_shared_out(
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """List every share record this workspace has granted to other workspaces ("my shares").

    The caller's ``current_workspace_id`` is used as the sharing workspace.
    Each record returned by the service is validated into
    ``app_schema.AppShare`` before being wrapped in the success envelope.
    """
    own_workspace_id = current_user.current_workspace_id
    share_records = app_service.AppService(db).list_my_shared_out(
        workspace_id=own_workspace_id,
    )
    return success(
        data=[app_schema.AppShare.model_validate(record) for record in share_records]
    )
@router.delete("/share/{target_workspace_id}", summary="取消对某工作空间的所有应用分享")
@cur_workspace_access_guard()
def unshare_all_apps_to_workspace(
    target_workspace_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Cancel all app shares from the current workspace to one target workspace.

    Args:
        target_workspace_id: Workspace whose access to this workspace's apps is revoked.
        db: Database session.
        current_user: Authenticated user; its ``current_workspace_id`` is the sharing side.

    Returns:
        Success envelope whose ``data`` carries ``count``, the value returned by the service.
    """
    sharing_workspace_id = current_user.current_workspace_id
    count = app_service.AppService(db).unshare_all_apps_to_workspace(
        target_workspace_id=target_workspace_id,
        workspace_id=sharing_workspace_id,
    )
    return success(msg=f"已取消 {count} 个应用的分享", data={"count": count})
@router.get("/{app_id}", summary="获取应用详情") @router.get("/{app_id}", summary="获取应用详情")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def get_app( def get_app(
@@ -158,6 +194,7 @@ def delete_app(
def copy_app( def copy_app(
app_id: uuid.UUID, app_id: uuid.UUID,
new_name: Optional[str] = None, new_name: Optional[str] = None,
payload: app_schema.CopyAppRequest = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
@@ -169,6 +206,8 @@ def copy_app(
- 不影响原应用 - 不影响原应用
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# body takes precedence over query param for backward compatibility
new_name = (payload.new_name if payload else None) or new_name
logger.info( logger.info(
"用户请求复制应用", "用户请求复制应用",
extra={ extra={
@@ -218,6 +257,27 @@ def get_agent_config(
return success(data=app_schema.AgentConfig.model_validate(cfg)) return success(data=app_schema.AgentConfig.model_validate(cfg))
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
@cur_workspace_access_guard()
def get_opening(
    app_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Return the opening statement and preset questions for the chat UI to show on start-up.

    Args:
        app_id: Application whose agent configuration is read.
        db: Database session.
        current_user: Authenticated user; its ``current_workspace_id`` scopes the lookup.

    Returns:
        Success envelope wrapping an ``app_schema.OpeningResponse``. When the
        configuration has no usable ``opening_statement`` entry, the response
        is disabled with no statement and an empty question list.
    """
    workspace_id = current_user.current_workspace_id
    cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)

    features = cfg.features or {}
    if hasattr(features, "model_dump"):
        # Features may be a model object rather than a plain dict; normalise to a dict.
        features = features.model_dump()

    opening = features.get("opening_statement")
    if not isinstance(opening, dict):
        # The key can be present with a null (or otherwise non-mapping) value,
        # in which case a ``.get`` default would not apply and ``opening.get``
        # would raise AttributeError.
        opening = {}

    return success(data=app_schema.OpeningResponse(
        enabled=opening.get("enabled", False),
        statement=opening.get("statement"),
        # An explicit null list is treated the same as a missing one.
        suggested_questions=opening.get("suggested_questions") or [],
    ))
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)") @router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def publish_app( def publish_app(
@@ -299,7 +359,8 @@ def share_app(
app_id=app_id, app_id=app_id,
target_workspace_ids=payload.target_workspace_ids, target_workspace_ids=payload.target_workspace_ids,
user_id=current_user.id, user_id=current_user.id,
workspace_id=workspace_id workspace_id=workspace_id,
permission=payload.permission
) )
data = [app_schema.AppShare.model_validate(s) for s in shares] data = [app_schema.AppShare.model_validate(s) for s in shares]
@@ -330,6 +391,32 @@ def unshare_app(
return success(msg="应用分享已取消") return success(msg="应用分享已取消")
@router.patch("/{app_id}/share/{target_workspace_id}", summary="更新共享权限")
@cur_workspace_access_guard()
def update_share_permission(
    app_id: uuid.UUID,
    target_workspace_id: uuid.UUID,
    payload: app_schema.UpdateSharePermissionRequest,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Update the permission of an existing share (readonly <-> editable).

    - Only the sharing permission of apps in the caller's own workspace can be changed.
    """
    owner_workspace_id = current_user.current_workspace_id
    updated_share = app_service.AppService(db).update_share_permission(
        app_id=app_id,
        target_workspace_id=target_workspace_id,
        permission=payload.permission,
        workspace_id=owner_workspace_id,
    )
    return success(data=app_schema.AppShare.model_validate(updated_share))
@router.get("/{app_id}/shares", summary="列出应用的分享记录") @router.get("/{app_id}/shares", summary="列出应用的分享记录")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def list_app_shares( def list_app_shares(
@@ -353,6 +440,46 @@ def list_app_shares(
return success(data=data) return success(data=data)
@router.delete("/shared/{source_workspace_id}", summary="批量移除某来源工作空间的所有共享应用")
@cur_workspace_access_guard()
def remove_all_shared_apps_from_workspace(
    source_workspace_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Remove all shared apps that came from one source workspace (recipient-side operation).

    Args:
        source_workspace_id: Workspace that originally shared the apps.
        db: Database session.
        current_user: Authenticated user; its ``current_workspace_id`` is the recipient side.

    Returns:
        Success envelope whose ``data`` carries ``count``, the value returned by the service.
    """
    recipient_workspace_id = current_user.current_workspace_id
    count = app_service.AppService(db).remove_all_shared_apps_from_workspace(
        source_workspace_id=source_workspace_id,
        workspace_id=recipient_workspace_id,
    )
    return success(msg=f"已移除 {count} 个共享应用", data={"count": count})
@router.delete("/{app_id}/shared", summary="移除共享给我的应用")
@cur_workspace_access_guard()
def remove_shared_app(
    app_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Let the recipient remove a shared app from its own workspace.

    - The source app is not deleted; only the share record is removed.
    - Only apps shared to the caller's own workspace can be removed.
    """
    recipient_workspace_id = current_user.current_workspace_id
    app_service.AppService(db).remove_shared_app(
        app_id=app_id,
        workspace_id=recipient_workspace_id,
    )
    return success(msg="已移除共享应用")
@router.post("/{app_id}/draft/run", summary="试运行 Agent使用当前草稿配置") @router.post("/{app_id}/draft/run", summary="试运行 Agent使用当前草稿配置")
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def draft_run( async def draft_run(
@@ -393,7 +520,7 @@ async def draft_run(
# 提前验证和准备(在流式响应开始前完成) # 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService from app.services.app_service import AppService
from app.services.multi_agent_service import MultiAgentService from app.services.multi_agent_service import MultiAgentService
from app.models import AgentConfig, ModelConfig from app.models import AgentConfig, ModelConfig, AppRelease
from sqlalchemy import select from sqlalchemy import select
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.services.draft_run_service import AgentRunService from app.services.draft_run_service import AgentRunService
@@ -410,11 +537,12 @@ async def draft_run(
service._validate_app_accessible(app, workspace_id) service._validate_app_accessible(app, workspace_id)
if payload.user_id is None: if payload.user_id is None:
# 先获取 app 的 workspace_id
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id, app_id=app_id,
workspace_id=app.workspace_id,
other_id=str(current_user.id), other_id=str(current_user.id),
original_user_id=str(current_user.id) # Save original user_id to other_id
) )
payload.user_id = str(new_end_user.id) payload.user_id = str(new_end_user.id)
@@ -431,18 +559,29 @@ async def draft_run(
service._check_agent_config(app_id) service._check_agent_config(app_id)
# 2. 获取 Agent 配置 # 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) # 共享应用:从最新发布版本读配置快照,而非草稿
agent_cfg = db.scalars(stmt).first() is_shared = app.workspace_id != workspace_id
if not agent_cfg: if is_shared:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) if not app.current_release_id:
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
release = db.get(AppRelease, app.current_release_id)
if not release:
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
agent_cfg = service._agent_config_from_release(release)
model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None
else:
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置 # 3. 获取模型配置
model_config = None model_config = None
if agent_cfg.default_model_config_id: if agent_cfg.default_model_config_id:
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id) model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config: if not model_config:
from app.core.exceptions import ResourceNotFoundException from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id)) raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
# 流式返回 # 流式返回
if payload.stream: if payload.stream:
@@ -598,7 +737,17 @@ async def draft_run(
msg="多 Agent 任务执行成功" msg="多 Agent 任务执行成功"
) )
elif app.type == AppType.WORKFLOW: # 工作流 elif app.type == AppType.WORKFLOW: # 工作流
config = workflow_service.check_config(app_id) # 共享应用:从最新发布版本读配置快照,而非草稿
is_shared = app.workspace_id != workspace_id
if is_shared:
if not app.current_release_id:
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
release = db.get(AppRelease, app.current_release_id)
if not release:
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
config = service._workflow_config_from_release(release)
else:
config = workflow_service.check_config(app_id)
# 3. 流式返回 # 3. 流式返回
if payload.stream: if payload.stream:
logger.debug( logger.debug(
@@ -741,6 +890,16 @@ async def draft_run_compare(
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
service._validate_app_accessible(app, workspace_id) service._validate_app_accessible(app, workspace_id)
if payload.user_id is None:
# 先获取 app 的 workspace_id
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
workspace_id=app.workspace_id,
other_id=str(current_user.id),
)
payload.user_id = str(new_end_user.id)
# 2. 获取 Agent 配置 # 2. 获取 Agent 配置
from sqlalchemy import select from sqlalchemy import select
from app.models import AgentConfig from app.models import AgentConfig
@@ -786,6 +945,13 @@ async def draft_run_compare(
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id "conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
}) })
# 从 features 中读取功能开关(与 draft_run 保持一致)
features_config: dict = agent_cfg.features or {}
if hasattr(features_config, 'model_dump'):
features_config = features_config.model_dump()
web_search_feature = features_config.get("web_search", {})
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
# 流式返回 # 流式返回
if payload.stream: if payload.stream:
async def event_generator(): async def event_generator():
@@ -797,11 +963,11 @@ async def draft_run_compare(
message=payload.message, message=payload.message,
workspace_id=workspace_id, workspace_id=workspace_id,
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id), user_id=payload.user_id,
variables=payload.variables, variables=payload.variables,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
web_search=True, web_search=web_search,
memory=True, memory=True,
parallel=payload.parallel, parallel=payload.parallel,
timeout=payload.timeout or 60, timeout=payload.timeout or 60,
@@ -828,11 +994,11 @@ async def draft_run_compare(
message=payload.message, message=payload.message,
workspace_id=workspace_id, workspace_id=workspace_id,
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id), user_id=payload.user_id,
variables=payload.variables, variables=payload.variables,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
web_search=True, web_search=web_search,
memory=True, memory=True,
parallel=payload.parallel, parallel=payload.parallel,
timeout=payload.timeout or 60, timeout=payload.timeout or 60,
@@ -1010,3 +1176,57 @@ def get_workspace_api_statistics(
) )
return success(data=result) return success(data=result)
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
@cur_workspace_access_guard()
async def export_app(
    app_id: uuid.UUID,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)],
    release_id: Optional[uuid.UUID] = None
):
    """Export an agent / multi_agent / workflow app configuration as a YAML file stream.

    Args:
        app_id: Application to export.
        db: Database session.
        current_user: Authenticated user (not read in this function body).
        release_id: Id of the published release to export; when omitted the
            current draft configuration is exported.

    Returns:
        A ``StreamingResponse`` carrying the UTF-8 encoded YAML as an attachment.
    """
    yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)

    # Percent-encode the name so non-ASCII characters fit in an HTTP header.
    # ``safe="."`` replaces the default safe set, so "/" is escaped as well.
    # NOTE(review): the encoded value is sent in the plain ``filename=`` parameter;
    # confirm the client decodes it, otherwise consider the RFC 6266 ``filename*`` form.
    encoded = quote(filename, safe=".")
    yaml_bytes = yaml_str.encode("utf-8")

    response_headers = {
        "Content-Disposition": f"attachment; filename={encoded}",
        "Content-Length": str(len(yaml_bytes)),
    }
    # A freshly constructed BytesIO is already positioned at offset 0.
    return StreamingResponse(
        io.BytesIO(yaml_bytes),
        media_type="application/octet-stream; charset=utf-8",
        headers=response_headers,
    )
@router.post("/import", summary="从 YAML 文件导入应用")
@cur_workspace_access_guard()
async def import_app(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Import an agent / multi_agent / workflow app from a YAML file.

    For cross-workspace / cross-tenant imports, models, tools and knowledge
    bases are matched by name; anything that cannot be matched is left empty
    and reported in ``warnings``.

    Args:
        file: Uploaded ``.yaml`` / ``.yml`` file, expected to be UTF-8 encoded.
        db: Database session.
        current_user: Authenticated user; supplies the target workspace, tenant and creator id.

    Returns:
        Success envelope with the created app and the list of warnings, or a
        ``BAD_REQUEST`` failure envelope when the upload is not a usable YAML
        document with a top-level ``app`` field.
    """
    # The upload may arrive without a filename; treat that as "not a YAML file"
    # instead of failing on ``None.lower()``.
    filename = file.filename or ""
    if not filename.lower().endswith((".yaml", ".yml")):
        return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)

    try:
        raw = (await file.read()).decode("utf-8")
        dsl = yaml.safe_load(raw)
    except (UnicodeDecodeError, yaml.YAMLError):
        # Non-UTF-8 bytes or malformed YAML are client errors, not server faults.
        return fail(msg="YAML 格式无效,无法解析文件内容", code=BizCode.BAD_REQUEST)

    # The document root must be a mapping; a list or scalar root would otherwise
    # slip through (or crash) the membership test below.
    if not isinstance(dsl, dict) or "app" not in dsl:
        return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)

    new_app, warnings = AppDslService(db).import_dsl(
        dsl=dsl,
        workspace_id=current_user.current_workspace_id,
        tenant_id=current_user.tenant_id,
        user_id=current_user.id,
    )
    return success(
        data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
        msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
    )

View File

@@ -1,4 +1,5 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Callable
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.dependencies import get_current_user, oauth2_scheme from app.dependencies import get_current_user, oauth2_scheme
from app.models.user_model import User from app.models.user_model import User
from app.i18n.dependencies import get_translator
# 获取专用日志器 # 获取专用日志器
auth_logger = get_auth_logger() auth_logger = get_auth_logger()
@@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"])
@router.post("/token", response_model=ApiResponse) @router.post("/token", response_model=ApiResponse)
async def login_for_access_token( async def login_for_access_token(
form_data: TokenRequest, form_data: TokenRequest,
db: Session = Depends(get_db) db: Session = Depends(get_db),
t: Callable = Depends(get_translator)
): ):
"""用户登录获取token""" """用户登录获取token"""
auth_logger.info(f"用户登录请求: {form_data.email}") auth_logger.info(f"用户登录请求: {form_data.email}")
@@ -40,10 +43,10 @@ async def login_for_access_token(
invite_info = workspace_service.validate_invite_token(db, form_data.invite) invite_info = workspace_service.validate_invite_token(db, form_data.invite)
if not invite_info.is_valid: if not invite_info.is_valid:
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST) raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
if invite_info.email != form_data.email: if invite_info.email != form_data.email:
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST) raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}") auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
try: try:
# 尝试认证用户 # 尝试认证用户
@@ -69,7 +72,7 @@ async def login_for_access_token(
elif e.code == BizCode.PASSWORD_ERROR: elif e.code == BizCode.PASSWORD_ERROR:
# 用户存在但密码错误 # 用户存在但密码错误
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}") auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED) raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
else: else:
# 其他认证失败情况,直接抛出 # 其他认证失败情况,直接抛出
raise raise
@@ -82,7 +85,7 @@ async def login_for_access_token(
except BusinessException as e: except BusinessException as e:
# 其他认证失败情况,直接抛出 # 其他认证失败情况,直接抛出
raise BusinessException(e.message,BizCode.LOGIN_FAILED) raise BusinessException(e.message, BizCode.LOGIN_FAILED)
# 创建 tokens # 创建 tokens
access_token, access_token_id = security.create_access_token(subject=user.id) access_token, access_token_id = security.create_access_token(subject=user.id)
@@ -110,14 +113,15 @@ async def login_for_access_token(
expires_at=access_expires_at, expires_at=access_expires_at,
refresh_expires_at=refresh_expires_at refresh_expires_at=refresh_expires_at
), ),
msg="登录成功" msg=t("auth.login.success")
) )
@router.post("/refresh", response_model=ApiResponse) @router.post("/refresh", response_model=ApiResponse)
async def refresh_token( async def refresh_token(
refresh_request: RefreshTokenRequest, refresh_request: RefreshTokenRequest,
db: Session = Depends(get_db) db: Session = Depends(get_db),
t: Callable = Depends(get_translator)
): ):
"""刷新token""" """刷新token"""
auth_logger.info("收到token刷新请求") auth_logger.info("收到token刷新请求")
@@ -125,18 +129,18 @@ async def refresh_token(
# 验证 refresh token # 验证 refresh token
userId = security.verify_token(refresh_request.refresh_token, "refresh") userId = security.verify_token(refresh_request.refresh_token, "refresh")
if not userId: if not userId:
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID) raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
# 检查用户是否存在 # 检查用户是否存在
user = auth_service.get_user_by_id(db, userId) user = auth_service.get_user_by_id(db, userId)
if not user: if not user:
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 检查 refresh token 黑名单 # 检查 refresh token 黑名单
if settings.ENABLE_SINGLE_SESSION: if settings.ENABLE_SINGLE_SESSION:
refresh_token_id = security.get_token_id(refresh_request.refresh_token) refresh_token_id = security.get_token_id(refresh_request.refresh_token)
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id): if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED) raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
# 生成新 tokens # 生成新 tokens
new_access_token, new_access_token_id = security.create_access_token(subject=user.id) new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
@@ -167,7 +171,7 @@ async def refresh_token(
expires_at=access_expires_at, expires_at=access_expires_at,
refresh_expires_at=refresh_expires_at refresh_expires_at=refresh_expires_at
), ),
msg="token刷新成功" msg=t("auth.token.refresh_success")
) )
@@ -175,14 +179,15 @@ async def refresh_token(
async def logout( async def logout(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db),
t: Callable = Depends(get_translator)
): ):
"""登出当前用户加入token黑名单并清理会话""" """登出当前用户加入token黑名单并清理会话"""
auth_logger.info(f"用户 {current_user.username} 请求登出") auth_logger.info(f"用户 {current_user.username} 请求登出")
token_id = security.get_token_id(token) token_id = security.get_token_id(token)
if not token_id: if not token_id:
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID) raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
# 加入黑名单 # 加入黑名单
await SessionService.blacklist_token(token_id) await SessionService.blacklist_token(token_id)
@@ -192,5 +197,5 @@ async def logout(
await SessionService.clear_user_session(current_user.username) await SessionService.clear_user_session(current_user.username)
auth_logger.info(f"用户 {current_user.username} 登出成功") auth_logger.info(f"用户 {current_user.username} 登出成功")
return success(msg="登出成功") return success(msg=t("auth.logout.success"))

View File

@@ -15,7 +15,7 @@ import os
import uuid import uuid
from typing import Any from typing import Any
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, RedirectResponse from fastapi.responses import FileResponse, RedirectResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -47,6 +47,19 @@ router = APIRouter(
) )
def _match_scheme(request: Request, url: str) -> str:
    """Rewrite the scheme of a presigned URL to match the incoming request (http/https).

    Addresses reverse-proxy deployments where the presigned URL's scheme
    differs from the scheme the client used.

    Args:
        request: Incoming request; the ``x-forwarded-proto`` header is preferred
            and ``request.url.scheme`` is the fallback.
        url: Presigned URL to adjust.

    Returns:
        ``url`` with ``http://`` and ``https://`` swapped when it disagrees with
        the incoming scheme; otherwise ``url`` unchanged.
    """
    forwarded = request.headers.get("x-forwarded-proto") or ""
    # The header value is compared after normalisation: case is ignored and, if
    # a proxy chain sent a comma-separated list, only the first entry is used.
    # An empty result falls back to the request's own scheme.
    incoming_scheme = forwarded.split(",", 1)[0].strip().lower() or request.url.scheme

    if url.startswith("http://") and incoming_scheme == "https":
        return "https://" + url[len("http://"):]
    if url.startswith("https://") and incoming_scheme == "http":
        return "http://" + url[len("https://"):]
    return url
@router.post("/files", response_model=ApiResponse) @router.post("/files", response_model=ApiResponse)
async def upload_file( async def upload_file(
file: UploadFile = File(...), file: UploadFile = File(...),
@@ -280,6 +293,7 @@ async def upload_file_with_share_token(
@router.get("/files/{file_id}", response_model=Any) @router.get("/files/{file_id}", response_model=Any)
async def download_file( async def download_file(
request: Request,
file_id: uuid.UUID, file_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
@@ -327,6 +341,7 @@ async def download_file(
else: else:
try: try:
presigned_url = await storage_service.get_file_url(file_key, expires=3600) presigned_url = await storage_service.get_file_url(file_key, expires=3600)
presigned_url = _match_scheme(request, presigned_url)
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}") api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except FileNotFoundError: except FileNotFoundError:
@@ -400,6 +415,7 @@ async def delete_file(
@router.get("/files/{file_id}/url", response_model=ApiResponse) @router.get("/files/{file_id}/url", response_model=ApiResponse)
async def get_file_url( async def get_file_url(
request: Request,
file_id: uuid.UUID, file_id: uuid.UUID,
expires: int = None, expires: int = None,
permanent: bool = False, permanent: bool = False,
@@ -463,6 +479,7 @@ async def get_file_url(
else: else:
# For remote storage (OSS/S3), get presigned URL # For remote storage (OSS/S3), get presigned URL
url = await storage_service.get_file_url(file_key, expires=expires) url = await storage_service.get_file_url(file_key, expires=expires)
url = _match_scheme(request, url)
api_logger.info(f"Generated file URL: file_id={file_id}") api_logger.info(f"Generated file URL: file_id={file_id}")
return success( return success(
@@ -484,6 +501,7 @@ async def get_file_url(
@router.get("/public/{file_id}", response_model=Any) @router.get("/public/{file_id}", response_model=Any)
async def public_download_file( async def public_download_file(
request: Request,
file_id: uuid.UUID, file_id: uuid.UUID,
expires: int = 0, expires: int = 0,
signature: str = "", signature: str = "",
@@ -555,6 +573,7 @@ async def public_download_file(
# For remote storage, redirect to presigned URL # For remote storage, redirect to presigned URL
try: try:
presigned_url = await storage_service.get_file_url(file_key, expires=3600) presigned_url = await storage_service.get_file_url(file_key, expires=3600)
presigned_url = _match_scheme(request, presigned_url)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e: except Exception as e:
api_logger.error(f"Failed to get presigned URL: {e}") api_logger.error(f"Failed to get presigned URL: {e}")
@@ -566,6 +585,7 @@ async def public_download_file(
@router.get("/permanent/{file_id}", response_model=Any) @router.get("/permanent/{file_id}", response_model=Any)
async def permanent_download_file( async def permanent_download_file(
request: Request,
file_id: uuid.UUID, file_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service), storage_service: FileStorageService = Depends(get_file_storage_service),
@@ -625,6 +645,7 @@ async def permanent_download_file(
try: try:
# Use a very long expiration (7 days max for most cloud providers) # Use a very long expiration (7 days max for most cloud providers)
presigned_url = await storage_service.get_file_url(file_key, expires=604800) presigned_url = await storage_service.get_file_url(file_key, expires=604800)
presigned_url = _match_scheme(request, presigned_url)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e: except Exception as e:
api_logger.error(f"Failed to get presigned URL: {e}") api_logger.error(f"Failed to get presigned URL: {e}")

View File

@@ -0,0 +1,833 @@
"""
I18n Management API Controller
This module provides management APIs for:
- Language management (list, get, add, update languages)
- Translation management (get, update, reload translations)
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import Callable, Optional
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, get_current_superuser
from app.i18n.dependencies import get_translator
from app.i18n.service import get_translation_service
from app.models.user_model import User
from app.schemas.i18n_schema import (
LanguageInfo,
LanguageListResponse,
LanguageCreateRequest,
LanguageUpdateRequest,
TranslationResponse,
TranslationUpdateRequest,
MissingTranslationsResponse,
ReloadResponse
)
from app.schemas.response_schema import ApiResponse
# Logger obtained from the project's logging configuration; used by every
# endpoint in this module.
api_logger = get_api_logger()

# Router for the i18n management endpoints; every path declared on it is
# served under the "/i18n" prefix and tagged "I18n Management".
router = APIRouter(
    prefix="/i18n",
    tags=["I18n Management"],
)
# ============================================================================
# Language Management APIs
# ============================================================================
@router.get("/languages", response_model=ApiResponse)
def get_languages(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    List every locale known to the translation service.

    Each locale is reported with its code, an upper-cased display name, a
    native name (the code itself when no native name is known) and two
    flags derived from the application settings: whether the locale is
    enabled and whether it is the default language.

    Args:
        t: Translator callable used to localize the response message.
        current_user: Authenticated user issuing the request.

    Returns:
        Success response whose data is the serialized LanguageListResponse.
    """
    api_logger.info(f"Get languages request from user: {current_user.username}")

    from app.core.config import settings

    translation_service = get_translation_service()

    # Native display names. Built once per request rather than once per
    # locale inside the loop.
    native_names = {
        "zh": "中文(简体)",
        "en": "English",
        "ja": "日本語",
        "ko": "한국어",
        "fr": "Français",
        "de": "Deutsch",
        "es": "Español"
    }

    languages = [
        LanguageInfo(
            code=locale,
            name=locale.upper(),
            native_name=native_names.get(locale, locale),
            is_enabled=locale in settings.I18N_SUPPORTED_LANGUAGES,
            is_default=locale == settings.I18N_DEFAULT_LANGUAGE
        )
        for locale in translation_service.get_available_locales()
    ]

    response = LanguageListResponse(languages=languages)
    api_logger.info(f"Returning {len(languages)} languages")
    return success(data=response.dict(), msg=t("common.success.retrieved"))
@router.get("/languages/{locale}", response_model=ApiResponse)
def get_language(
    locale: str,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Return the description of a single language.

    Args:
        locale: Language code taken from the path (e.g. 'zh', 'en').
        t: Translator callable used to localize messages.
        current_user: Authenticated user issuing the request.

    Returns:
        Success response whose data is the serialized LanguageInfo.

    Raises:
        HTTPException: 404 when the translation service does not list
            the requested locale.
    """
    api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")

    from app.core.config import settings

    service = get_translation_service()

    # Guard clause: unknown locales are rejected before anything is built.
    if locale not in service.get_available_locales():
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    display_names = {
        "zh": "中文(简体)",
        "en": "English",
        "ja": "日本語",
        "ko": "한국어",
        "fr": "Français",
        "de": "Deutsch",
        "es": "Español"
    }

    info = LanguageInfo(
        code=locale,
        name=f"{locale.upper()}",
        native_name=display_names.get(locale, locale),
        is_enabled=locale in settings.I18N_SUPPORTED_LANGUAGES,
        is_default=locale == settings.I18N_DEFAULT_LANGUAGE
    )

    api_logger.info(f"Returning language info for: {locale}")
    return success(data=info.dict(), msg=t("common.success.retrieved"))
@router.post("/languages", response_model=ApiResponse)
def add_language(
    request: LanguageCreateRequest,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Validate a request to add a language (admin only).

    Nothing is created here: the endpoint only checks that the language
    code is not already known and answers with instructions that point at
    the configured core locales directory.

    Args:
        request: Language creation payload; only its ``code`` is read.
        t: Translator callable used to localize messages.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response carrying the localized instructions message.

    Raises:
        HTTPException: 400 when the language code already exists.
    """
    api_logger.info(
        f"Add language request: code={request.code}, admin={current_user.username}"
    )

    from app.core.config import settings

    known_locales = get_translation_service().get_available_locales()

    # Reject duplicates up front.
    if request.code in known_locales:
        api_logger.warning(f"Language already exists: {request.code}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=t("i18n.language.already_exists", locale=request.code)
        )

    # The translation files themselves still have to be created by hand;
    # this endpoint is a validation and documentation point only.
    api_logger.info(
        f"Language addition validated: {request.code}. "
        "Translation files need to be created manually."
    )

    instructions = t(
        "i18n.language.add_instructions",
        locale=request.code,
        dir=settings.I18N_CORE_LOCALES_DIR
    )
    return success(msg=instructions)
@router.put("/languages/{locale}", response_model=ApiResponse)
def update_language(
    locale: str,
    request: LanguageUpdateRequest,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Validate a request to update a language's configuration (admin only).

    No configuration is changed here: the endpoint only checks that the
    locale exists and answers with a localized instructions message.

    Args:
        locale: Language code taken from the path.
        request: Language update payload (not read by this endpoint).
        t: Translator callable used to localize messages.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response carrying the localized instructions message.

    Raises:
        HTTPException: 404 when the locale is not available.
    """
    api_logger.info(
        f"Update language request: locale={locale}, admin={current_user.username}"
    )

    known_locales = get_translation_service().get_available_locales()

    # The locale must already be known to the translation service.
    if locale not in known_locales:
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    # Real configuration changes happen outside this endpoint; it is a
    # validation and documentation point only.
    api_logger.info(
        f"Language update validated: {locale}. "
        "Configuration changes require environment variable updates."
    )

    instructions = t("i18n.language.update_instructions", locale=locale)
    return success(msg=instructions)
# ============================================================================
# Translation Management APIs
# ============================================================================
@router.get("/translations", response_model=ApiResponse)
def get_all_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Return cached translations for every locale, or for one locale only.

    Args:
        locale: Optional locale filter supplied as a query parameter.
        t: Translator callable used to localize messages.
        current_user: Authenticated user issuing the request.

    Returns:
        Success response whose data is the serialized TranslationResponse,
        keyed by locale.

    Raises:
        HTTPException: 404 when a locale filter is given but that locale
            is not available.
    """
    api_logger.info(
        f"Get all translations request: locale={locale}, user={current_user.username}"
    )

    service = get_translation_service()

    if not locale:
        # No filter: hand back the service's whole translation cache.
        translations = service._cache
    else:
        if locale not in service.get_available_locales():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=t("i18n.language.not_found", locale=locale)
            )
        # Keep the locale -> translations shape even for a single locale.
        translations = {locale: service._cache.get(locale, {})}

    payload = TranslationResponse(translations=translations)

    api_logger.info(f"Returning translations for: {locale or 'all locales'}")
    return success(data=payload.dict(), msg=t("common.success.retrieved"))
def _collect_translation_keys(data, prefix=""):
    """Return the dotted key path of every non-dict leaf in a nested dict."""
    keys = []
    for key, value in data.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            keys.extend(_collect_translation_keys(value, full_key))
        else:
            keys.append(full_key)
    return keys


def _locale_key_set(locale_translations):
    """Return the set of dotted keys across all namespaces of one locale."""
    keys = set()
    for namespace, translations in locale_translations.items():
        keys.update(_collect_translation_keys(translations, namespace))
    return keys


# This static route is declared BEFORE "/translations/{locale}" on purpose:
# routes are matched in declaration order, so if it came later a request
# for "/translations/missing" would be handled by the parametrized route
# with locale="missing".
@router.get("/translations/missing", response_model=ApiResponse)
def get_missing_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Report translation keys that exist in the default locale but not in
    the target locale(s).

    Args:
        locale: Optional locale to check. When omitted, every available
            locale except the default one is checked. A locale that is
            not available is skipped, producing an empty report for it.
        t: Translator callable used to localize messages.
        current_user: Authenticated user issuing the request.

    Returns:
        Success response whose data is the serialized
        MissingTranslationsResponse: locale -> sorted list of missing keys.
        Locales with nothing missing are left out.
    """
    api_logger.info(
        f"Get missing translations request: locale={locale}, user={current_user.username}"
    )

    from app.core.config import settings

    translation_service = get_translation_service()
    default_locale = settings.I18N_DEFAULT_LANGUAGE
    available_locales = translation_service.get_available_locales()

    # Keys of the default locale are the reference set.
    default_keys = _locale_key_set(
        translation_service._cache.get(default_locale, {})
    )

    target_locales = [locale] if locale else [
        loc for loc in available_locales if loc != default_locale
    ]

    missing_by_locale = {}
    for target_locale in target_locales:
        if target_locale not in available_locales:
            continue

        target_keys = _locale_key_set(
            translation_service._cache.get(target_locale, {})
        )
        missing_keys = default_keys - target_keys
        if missing_keys:
            missing_by_locale[target_locale] = sorted(missing_keys)

    response = MissingTranslationsResponse(missing_translations=missing_by_locale)

    total_missing = sum(len(keys) for keys in missing_by_locale.values())
    api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
    return success(data=response.dict(), msg=t("common.success.retrieved"))


@router.get("/translations/{locale}", response_model=ApiResponse)
def get_locale_translations(
    locale: str,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Return all cached translations of one locale, organized by namespace.

    Args:
        locale: Language code taken from the path.
        t: Translator callable used to localize messages.
        current_user: Authenticated user issuing the request.

    Returns:
        Success response with ``locale`` and its ``translations``.

    Raises:
        HTTPException: 404 when the locale is not available.
    """
    api_logger.info(
        f"Get locale translations request: locale={locale}, user={current_user.username}"
    )

    translation_service = get_translation_service()

    available_locales = translation_service.get_available_locales()
    if locale not in available_locales:
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    translations = translation_service._cache.get(locale, {})

    api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
    return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved"))


@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
def get_namespace_translations(
    locale: str,
    namespace: str,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Return the cached translations of one namespace within a locale.

    Args:
        locale: Language code taken from the path.
        namespace: Translation namespace (e.g. 'common', 'auth').
        t: Translator callable used to localize messages.
        current_user: Authenticated user issuing the request.

    Returns:
        Success response with ``locale``, ``namespace`` and the
        namespace's ``translations``.

    Raises:
        HTTPException: 404 when the locale is not available, or when the
            namespace is absent or empty for that locale.
    """
    api_logger.info(
        f"Get namespace translations request: locale={locale}, "
        f"namespace={namespace}, user={current_user.username}"
    )

    translation_service = get_translation_service()

    available_locales = translation_service.get_available_locales()
    if locale not in available_locales:
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    locale_translations = translation_service._cache.get(locale, {})
    namespace_translations = locale_translations.get(namespace, {})

    # An empty namespace is reported the same way as a missing one.
    if not namespace_translations:
        api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
        )

    api_logger.info(
        f"Returning translations for namespace: {namespace} in locale: {locale}"
    )
    return success(
        data={
            "locale": locale,
            "namespace": namespace,
            "translations": namespace_translations
        },
        msg=t("common.success.retrieved")
    )


@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
def update_translation(
    locale: str,
    key: str,
    request: TranslationUpdateRequest,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Validate a request to update a single translation (admin only).

    No translation is modified here: the endpoint only checks the locale
    and the key format and answers with a localized instructions message.

    Args:
        locale: Language code taken from the path.
        key: Translation key (format: "namespace.key.subkey").
        request: Translation update payload (not read by this endpoint).
        t: Translator callable used to localize messages.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response carrying the localized instructions message.

    Raises:
        HTTPException: 404 when the locale is not available; 400 when the
            key contains no '.' separator.
    """
    api_logger.info(
        f"Update translation request: locale={locale}, key={key}, "
        f"admin={current_user.username}"
    )

    translation_service = get_translation_service()

    available_locales = translation_service.get_available_locales()
    if locale not in available_locales:
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    # A key must at least name a namespace and one entry inside it.
    if "." not in key:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=t("i18n.translation.invalid_key_format", key=key)
        )

    # Actual translation updates require modifying the translation files;
    # this endpoint is a validation and documentation point only.
    api_logger.info(
        f"Translation update validated: {locale}/{key}. "
        "Translation files need to be updated manually."
    )

    return success(
        msg=t("i18n.translation.update_instructions", locale=locale, key=key)
    )
@router.post("/reload", response_model=ApiResponse)
def reload_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Trigger a hot reload of translation files (admin only).

    Args:
        locale: Optional locale to reload; all locales are reloaded when
            omitted.
        t: Translator callable used to localize messages.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response whose data is the serialized ReloadResponse
        (reloaded locales and the total number of available locales).

    Raises:
        HTTPException: 403 when hot reload is disabled in the settings;
            500 when reloading or building the response fails.
    """
    api_logger.info(
        f"Reload translations request: locale={locale or 'all'}, "
        f"admin={current_user.username}"
    )

    from app.core.config import settings

    if not settings.I18N_ENABLE_HOT_RELOAD:
        api_logger.warning("Hot reload is disabled in configuration")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=t("i18n.reload.disabled")
        )

    translation_service = get_translation_service()

    try:
        translation_service.reload(locale)

        # Statistics are read after the reload so they reflect the new state.
        available_locales = translation_service.get_available_locales()
        reloaded_locales = [locale] if locale else available_locales

        response = ReloadResponse(
            success=True,
            reloaded_locales=reloaded_locales,
            total_locales=len(available_locales)
        )

        api_logger.info(
            f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
        )

        return success(data=response.dict(), msg=t("i18n.reload.success"))

    except Exception as e:
        api_logger.error(f"Failed to reload translations: {str(e)}")
        # Chain the original error so it stays attached as the cause of
        # the HTTP error.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=t("i18n.reload.failed", error=str(e))
        ) from e
# ============================================================================
# Performance Monitoring APIs
# ============================================================================
@router.get("/metrics", response_model=ApiResponse)
def get_metrics(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Return the translation service's metrics summary (admin only).

    Args:
        t: Translator callable used to localize the response message.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response whose data is whatever
        ``get_metrics_summary()`` of the translation service returns.
    """
    api_logger.info(f"Get metrics request: admin={current_user.username}")

    summary = get_translation_service().get_metrics_summary()

    api_logger.info("Returning i18n metrics")
    return success(data=summary, msg=t("common.success.retrieved"))
@router.get("/metrics/cache", response_model=ApiResponse)
def get_cache_stats(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Return cache statistics and memory usage of the translation service
    (admin only).

    Args:
        t: Translator callable used to localize the response message.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response whose data has two entries: ``cache`` (result of
        ``get_cache_stats()``) and ``memory`` (result of
        ``get_memory_usage()``).
    """
    api_logger.info(f"Get cache stats request: admin={current_user.username}")

    service = get_translation_service()
    stats = service.get_cache_stats()
    memory = service.get_memory_usage()

    api_logger.info("Returning cache statistics")
    return success(
        data={"cache": stats, "memory": memory},
        msg=t("common.success.retrieved")
    )
@router.get("/metrics/prometheus")
def get_prometheus_metrics(
    current_user: User = Depends(get_current_superuser)
):
    """
    Export the i18n metrics as plain text (admin only).

    Args:
        current_user: Authenticated superuser issuing the request.

    Returns:
        PlainTextResponse containing the output of the metrics object's
        ``export_prometheus()``.
    """
    api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")

    from app.i18n.metrics import get_metrics
    exported = get_metrics().export_prometheus()

    from fastapi.responses import PlainTextResponse
    return PlainTextResponse(content=exported)
@router.post("/metrics/reset", response_model=ApiResponse)
def reset_metrics(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Reset the i18n metrics and the translation cache statistics
    (admin only).

    Args:
        t: Translator callable used to localize the response message.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response carrying the localized confirmation message.
    """
    api_logger.info(f"Reset metrics request: admin={current_user.username}")

    # First the metrics object, then the cache counters of the service.
    from app.i18n.metrics import get_metrics
    get_metrics().reset()

    service = get_translation_service()
    service.cache.reset_stats()

    api_logger.info("Metrics reset completed")
    return success(msg=t("i18n.metrics.reset_success"))
# ============================================================================
# Missing Translation Logging and Reporting APIs
# ============================================================================
@router.get("/logs/missing", response_model=ApiResponse)
def get_missing_translation_logs(
    locale: Optional[str] = None,
    limit: Optional[int] = 100,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Return the logged missing translations together with recent context
    entries and overall statistics (admin only).

    Args:
        locale: Optional locale filter.
        limit: Maximum number of context entries requested (default: 100).
        t: Translator callable used to localize the response message.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response whose data holds ``missing_translations``,
        ``recent_context`` and ``statistics``.
    """
    api_logger.info(
        f"Get missing translation logs request: locale={locale}, "
        f"limit={limit}, admin={current_user.username}"
    )

    missing_log = get_translation_service().translation_logger

    # Query order kept: missing keys, then context, then statistics.
    missing = missing_log.get_missing_translations(locale)
    context = missing_log.get_missing_with_context(locale, limit)
    stats = missing_log.get_statistics()

    payload = {
        "missing_translations": missing,
        "recent_context": context,
        "statistics": stats
    }

    api_logger.info(
        f"Returning {stats['total_missing']} missing translations"
    )
    return success(data=payload, msg=t("common.success.retrieved"))
@router.get("/logs/missing/report", response_model=ApiResponse)
def generate_missing_translation_report(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Build a report of missing translations (admin only).

    Args:
        locale: Optional locale filter passed on to the report generator.
        t: Translator callable used to localize the response message.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response whose data is the report produced by the
        translation logger's ``generate_report``.
    """
    api_logger.info(
        f"Generate missing translation report request: locale={locale}, "
        f"admin={current_user.username}"
    )

    missing_log = get_translation_service().translation_logger
    report = missing_log.generate_report(locale)

    api_logger.info(
        f"Generated report with {report['total_missing']} missing translations"
    )
    return success(data=report, msg=t("common.success.retrieved"))
@router.post("/logs/missing/export", response_model=ApiResponse)
def export_missing_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Export the logged missing translations to a JSON file (admin only).

    The file is written under ``logs/i18n/`` and its name carries a locale
    suffix and a timestamp. Because the locale comes from the request, it
    is reduced to letters, digits, '_' and '-' before it is used in the
    path; ordinary locale codes are unaffected.

    Args:
        locale: Optional locale, used for the file-name suffix ("_all"
            when omitted or when nothing usable remains after cleaning).
        t: Translator callable used to localize the response message.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response whose data holds the ``file_path`` written.
    """
    api_logger.info(
        f"Export missing translations request: locale={locale}, "
        f"admin={current_user.username}"
    )

    import re
    from datetime import datetime

    translation_service = get_translation_service()
    translation_logger = translation_service.translation_logger

    # `locale` is caller-supplied and becomes part of a file path: strip
    # path separators, dots and anything else that could redirect the
    # write outside logs/i18n/.
    safe_locale = re.sub(r"[^A-Za-z0-9_-]", "", locale) if locale else ""
    locale_suffix = f"_{safe_locale}" if safe_locale else "_all"

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"

    # NOTE(review): the locale is not passed to export_to_json, so it
    # appears to affect only the file name, not the exported content -
    # confirm against the translation logger's API.
    translation_logger.export_to_json(output_file)

    api_logger.info(f"Missing translations exported to: {output_file}")
    return success(
        data={"file_path": output_file},
        msg=t("i18n.logs.export_success", file=output_file)
    )
@router.delete("/logs/missing", response_model=ApiResponse)
def clear_missing_translation_logs(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Clear the missing-translation logs (admin only).

    Args:
        locale: Optional locale passed to the logger's ``clear``; the
            request is logged as covering all locales when omitted.
        t: Translator callable used to localize the response message.
        current_user: Authenticated superuser issuing the request.

    Returns:
        Success response carrying the localized confirmation message.
    """
    api_logger.info(
        f"Clear missing translation logs request: locale={locale or 'all'}, "
        f"admin={current_user.username}"
    )

    missing_log = get_translation_service().translation_logger
    missing_log.clear(locale)

    api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
    return success(msg=t("i18n.logs.clear_success"))

View File

@@ -19,7 +19,7 @@ from app.models import mcp_market_config_model
from app.models.user_model import User from app.models.user_model import User
from app.schemas import mcp_market_config_schema from app.schemas import mcp_market_config_schema
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import mcp_market_config_service from app.services import mcp_market_config_service, mcp_market_service
# Obtain a dedicated API logger # Obtain a dedicated API logger
api_logger = get_api_logger() api_logger = get_api_logger()
@@ -55,6 +55,12 @@ async def get_mcp_servers(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0" detail="The paging parameter must be greater than 0"
) )
if page * pagesize > 100:
api_logger.warning(f"Paging parameters exceed ModelScope limit: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The maximum number of MCP services can view is 100. Please visit the ModelScope MCP Plaza."
)
# 2. Query mcp market config information from the database # 2. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}") api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
@@ -64,14 +70,16 @@ async def get_mcp_servers(
if not db_mcp_market_config: if not db_mcp_market_config:
api_logger.warning( api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException( return success(msg='The mcp market config does not exist or access is denied')
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
# 3. Execute paged query # 3. Execute paged query
api = MCPApi()
token = db_mcp_market_config.token token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token) api.login(token)
body = { body = {
@@ -115,6 +123,17 @@ async def get_mcp_servers(
"has_next": True if page * pagesize < total else False "has_next": True if page * pagesize < total else False
} }
} }
# 5. Update mcp_market.mcp_count
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=db_mcp_market_config.mcp_market_id, current_user=current_user)
if not db_mcp_market:
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={db_mcp_market_config.mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market does not exist or access is denied"
)
db_mcp_market.mcp_count = total
db.commit()
db.refresh(db_mcp_market)
return success(data=result, msg="Query of mcp servers list successful") return success(data=result, msg="Query of mcp servers list successful")
@@ -140,14 +159,16 @@ async def get_operational_mcp_servers(
if not db_mcp_market_config: if not db_mcp_market_config:
api_logger.warning( api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException( return success(msg='The mcp market config does not exist or access is denied')
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
# 2. Execute paged query # 2. Execute paged query
api = MCPApi()
token = db_mcp_market_config.token token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token) api.login(token)
url = f'{api.mcp_base_url}/operational' url = f'{api.mcp_base_url}/operational'
@@ -198,14 +219,16 @@ async def get_mcp_server(
if not db_mcp_market_config: if not db_mcp_market_config:
api_logger.warning( api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException( return success(msg='The mcp market config does not exist or access is denied')
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
# 2. Get detailed information for a specific MCP Server # 2. Get detailed information for a specific MCP Server
api = MCPApi()
token = db_mcp_market_config.token token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token) api.login(token)
result = api.get_mcp_server(server_id=server_id) result = api.get_mcp_server(server_id=server_id)
@@ -226,7 +249,26 @@ async def create_mcp_market_config(
try: try:
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}") api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
# 1. Check if the mcp market name already exists # 1. Validate token can access ModelScope MCP market
if not create_data.token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Token is required to access ModelScope MCP market"
)
try:
api = MCPApi()
api.login(create_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(create_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
raise_for_http_status(r)
except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
)
# 2. Check if the mcp market name already exists
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user) db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
if db_mcp_market_config_exist: if db_mcp_market_config_exist:
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}") api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
@@ -234,6 +276,30 @@ async def create_mcp_market_config(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The mcp market id already exists: {create_data.mcp_market_id}" detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
) )
# 2. verify token
create_data.status = 1
try:
api = MCPApi()
token = create_data.token
api.login(token)
body = {
'filter': {},
'page_number': 1,
'page_size': 20,
'search': ""
}
cookies = api.get_cookies(token)
r = api.session.put(
url=api.mcp_base_url,
headers=api.builder_headers(api.headers),
json=body,
cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"Failed to get MCP servers: {str(e)}")
create_data.status = 0
# 3. create mcp_market_config
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user) db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
api_logger.info( api_logger.info(
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})") f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
@@ -262,10 +328,7 @@ async def get_mcp_market_config(
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user) db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
if not db_mcp_market_config: if not db_mcp_market_config:
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException( return success(msg='The mcp market config does not exist or access is denied')
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})") api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)), return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
@@ -295,10 +358,7 @@ async def get_mcp_market_config_by_mcp_market_id(
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user) db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
if not db_mcp_market_config: if not db_mcp_market_config:
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}") api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
raise HTTPException( return success(msg='The mcp market config does not exist or access is denied')
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})") api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)), return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
@@ -324,12 +384,25 @@ async def update_mcp_market_config(
if not db_mcp_market_config: if not db_mcp_market_config:
api_logger.warning( api_logger.warning(
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}") f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException( return success(msg='The mcp market config does not exist or access is denied')
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or you do not have permission to access it"
)
# 2. Update fields (only update non-null fields) # 2. Validate new token if provided
if update_data.token is not None:
try:
api = MCPApi()
api.login(update_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(update_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
raise_for_http_status(r)
except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
)
# 3. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}") api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
update_dict = update_data.dict(exclude_unset=True) update_dict = update_data.dict(exclude_unset=True)
updated_fields = [] updated_fields = []
@@ -344,7 +417,7 @@ async def update_mcp_market_config(
if updated_fields: if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}") api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 3. Save to database # 4. Save to database
try: try:
db.commit() db.commit()
db.refresh(db_mcp_market_config) db.refresh(db_mcp_market_config)
@@ -357,7 +430,7 @@ async def update_mcp_market_config(
detail=f"The mcp market config update failed: {str(e)}" detail=f"The mcp market config update failed: {str(e)}"
) )
# 4. Return the updated mcp market config # 5. Return the updated mcp market config
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)), return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
msg="The mcp market config information updated successfully") msg="The mcp market config information updated successfully")
@@ -381,10 +454,7 @@ async def delete_mcp_market_config(
if not db_mcp_market_config: if not db_mcp_market_config:
api_logger.warning( api_logger.warning(
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}") f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException( return success(msg='The mcp market config does not exist or access is denied')
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or you do not have permission to access it"
)
# 2. Deleting mcp market config # 2. Deleting mcp market config
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user) mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)

View File

@@ -1,4 +1,5 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import Optional from typing import Optional
from app.core.response_utils import success from app.core.response_utils import success
@@ -149,6 +150,21 @@ async def get_workspace_end_users(
return {uid: {"total": 0} for uid in end_user_ids} return {uid: {"total": 0} for uid in end_user_ids}
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
try:
from app.celery_app import celery_app as _celery_app
_celery_app.send_task(
"app.tasks.init_implicit_emotions_for_users",
kwargs={"end_user_ids": end_user_ids},
)
_celery_app.send_task(
"app.tasks.init_interest_distribution_for_users",
kwargs={"end_user_ids": end_user_ids},
)
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
except Exception as e:
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
# 并发执行配置查询和记忆数量查询 # 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather( memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(), get_memory_configs(),
@@ -178,6 +194,15 @@ async def get_workspace_end_users(
except Exception as e: except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}") api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
# 触发社区聚类补全任务(异步,不阻塞接口响应)
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
try:
from app.tasks import init_community_clustering_for_users
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
except Exception as e:
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功") return success(data=result, msg="宿主列表获取成功")
@@ -387,14 +412,15 @@ def get_current_user_rag_total_num(
@router.get("/rag_content", response_model=ApiResponse) @router.get("/rag_content", response_model=ApiResponse)
def get_rag_content( def get_rag_content(
end_user_id: str = Query(..., description="宿主ID"), end_user_id: str = Query(..., description="宿主ID"),
limit: int = Query(15, description="返回记录数"), page: int = Query(1, gt=0, description="页码从1开始"),
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
""" """
获取当前宿主知识库中的chunk内容 获取当前宿主知识库中的chunk内容(分页)
""" """
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user) data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
return success(data=data, msg="宿主RAGchunk数据获取成功") return success(data=data, msg="宿主RAGchunk数据获取成功")
@@ -407,25 +433,17 @@ async def get_chunk_summary_tag(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
""" """
获取chunk总结、提取的标签和人物形象 读取RAG摘要、标签和人物形象纯读库不触发生成
返回格式: 返回格式:
{ {
"summary": "chunk内容的总结", "summary": "用户摘要",
"tags": [ "tags": [{"tag": "标签1", "frequency": 5}, ...],
{"tag": "标签1", "frequency": 5}, "personas": ["产品设计师", ...],
{"tag": "标签2", "frequency": 3}, "generated": true/false // false表示尚未生产请调用 /generate_rag_profile
...
],
"personas": [
"产品设计师",
"旅行爱好者",
"摄影发烧友",
...
]
} }
""" """
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id}chunk摘要标签人物形象") api_logger.info(f"用户 {current_user.username} 取宿主 {end_user_id}RAG摘要/标签/人物形象")
data = await memory_dashboard_service.get_chunk_summary_and_tags( data = await memory_dashboard_service.get_chunk_summary_and_tags(
end_user_id=end_user_id, end_user_id=end_user_id,
@@ -435,8 +453,7 @@ async def get_chunk_summary_tag(
current_user=current_user current_user=current_user
) )
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象") return success(data=data, msg="获取成功")
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
@router.get("/chunk_insight", response_model=ApiResponse) @router.get("/chunk_insight", response_model=ApiResponse)
@@ -447,14 +464,18 @@ async def get_chunk_insight(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
""" """
获取chunk的洞察内容 读取RAG洞察报告纯读库不触发生成
返回格式: 返回格式:
{ {
"insight": "对chunk内容的深度洞察分析" "insight": "总体概述",
"behavior_pattern": "行为模式",
"key_findings": "关键发现",
"growth_trajectory": "成长轨迹",
"generated": true/false // false表示尚未生产请调用 /generate_rag_profile
} }
""" """
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id}chunk洞察") api_logger.info(f"用户 {current_user.username} 取宿主 {end_user_id}RAG洞察")
data = await memory_dashboard_service.get_chunk_insight( data = await memory_dashboard_service.get_chunk_insight(
end_user_id=end_user_id, end_user_id=end_user_id,
@@ -463,8 +484,37 @@ async def get_chunk_insight(
current_user=current_user current_user=current_user
) )
api_logger.info("成功获取chunk洞察") return success(data=data, msg="获取成功")
return success(data=data, msg="chunk洞察获取成功")
class GenerateRagProfileRequest(BaseModel):
    """Request body accepted by the POST /generate_rag_profile endpoint."""

    # Required: identifier of the end user whose RAG profile is generated.
    end_user_id: str = Field(..., description="宿主ID")
    # Optional, defaults to 15: upper bound on the number of chunks used.
    limit: int = Field(default=15, description="参与生成的chunk数量上限")
    # Optional, defaults to 10: maximum number of tags.
    max_tags: int = Field(default=10, description="最大标签数量")
@router.post("/generate_rag_profile", response_model=ApiResponse)
async def generate_rag_profile(
    body: GenerateRagProfileRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Trigger generation of the RAG profile for one end user.

    Forwards the request to ``memory_dashboard_service.generate_rag_profile``
    and wraps whatever it returns in the standard success response.

    Per the original notes on this endpoint, the service regenerates the
    complete profile for an end user in RAG storage mode on every request,
    persists it to the end_user table, and overwrites existing data. That
    behaviour lives in the service and is not visible here.

    Args:
        body: Target ``end_user_id`` plus the ``limit`` and ``max_tags``
            values passed through to the service.
        db: Database session injected by ``get_db``.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        The result of ``success(...)`` carrying the service output as ``data``.
    """
    target_end_user_id = body.end_user_id
    api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={target_end_user_id}")

    profile = await memory_dashboard_service.generate_rag_profile(
        end_user_id=target_end_user_id,
        limit=body.limit,
        max_tags=body.max_tags,
        db=db,
        current_user=current_user,
    )

    # NOTE(review): this interpolates the entire service result into the log
    # line; confirm that logging the full payload at INFO level is intended.
    api_logger.info(f"RAG画像生产完成: {profile}")
    return success(data=profile, msg="RAG画像生产完成")
@router.get("/dashboard_data", response_model=ApiResponse) @router.get("/dashboard_data", response_model=ApiResponse)
@@ -553,9 +603,12 @@ async def dashboard_data(
) )
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0) neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
# total_app: 统计当前空间下的所有app数量 # total_app: 统计当前空间下的所有app数量
from app.repositories import app_repository # 包含自有app + 被分享给本工作空间的app
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) from app.services import app_service as _app_svc
neo4j_data["total_app"] = len(apps_orm) _, total_app = _app_svc.AppService(db).list_apps(
workspace_id=workspace_id, include_shared=True, pagesize=1
)
neo4j_data["total_app"] = total_app
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}") api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
except Exception as e: except Exception as e:
api_logger.warning(f"获取记忆总量失败: {str(e)}") api_logger.warning(f"获取记忆总量失败: {str(e)}")

View File

@@ -1,3 +1,19 @@
"""
Memory Reflection Controller
This module provides REST API endpoints for managing memory reflection configurations
and operations. It handles reflection engine setup, configuration management, and
execution of self-reflection processes across memory systems.
Key Features:
- Reflection configuration management (save, retrieve, update)
- Workspace-wide reflection execution across multiple applications
- Individual configuration-based reflection runs
- Multi-language support for reflection outputs
- Integration with Neo4j memory storage and LLM models
- Comprehensive error handling and logging
"""
import asyncio import asyncio
import time import time
import uuid import uuid
@@ -28,9 +44,13 @@ from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
# Load environment variables for configuration
load_dotenv() load_dotenv()
# Initialize API logger for request tracking and debugging
api_logger = get_api_logger() api_logger = get_api_logger()
# Configure router with prefix and tags for API organization
router = APIRouter( router = APIRouter(
prefix="/memory", prefix="/memory",
tags=["Memory"], tags=["Memory"],
@@ -43,7 +63,38 @@ async def save_reflection_config(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""Save reflection configuration to data_comfig table""" """
Save reflection configuration to memory config table
Persists reflection engine configuration settings to the data_config table,
including reflection parameters, model settings, and evaluation criteria.
Validates configuration parameters and ensures data consistency.
Args:
request: Memory reflection configuration data including:
- config_id: Configuration identifier to update
- reflection_enabled: Whether reflection is enabled
- reflection_period_in_hours: Reflection execution interval
- reflexion_range: Scope of reflection (partial/all)
- baseline: Reflection strategy (time/fact/hybrid)
- reflection_model_id: LLM model for reflection operations
- memory_verify: Enable memory verification checks
- quality_assessment: Enable quality assessment evaluation
current_user: Authenticated user saving the configuration
db: Database session for data operations
Returns:
dict: Success response with saved reflection configuration data
Raises:
HTTPException 400: If config_id is missing or parameters are invalid
HTTPException 500: If configuration save operation fails
Database Operations:
- Updates memory_config table with reflection settings
- Commits transaction and refreshes entity
- Maintains configuration consistency
"""
try: try:
config_id = request.config_id config_id = request.config_id
config_id = resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
@@ -54,6 +105,7 @@ async def save_reflection_config(
) )
api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}") api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")
# Update reflection configuration in database
memory_config = MemoryConfigRepository.update_reflection_config( memory_config = MemoryConfigRepository.update_reflection_config(
db, db,
config_id=config_id, config_id=config_id,
@@ -66,6 +118,7 @@ async def save_reflection_config(
quality_assessment=request.quality_assessment quality_assessment=request.quality_assessment
) )
# Commit transaction and refresh entity
db.commit() db.commit()
db.refresh(memory_config) db.refresh(memory_config)
@@ -102,13 +155,55 @@ async def start_workspace_reflection(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""启动工作空间中所有匹配应用的反思功能""" """
Start reflection functionality for all matching applications in workspace
Initiates reflection processes across all applications within the user's current
workspace that have valid memory configurations. Processes each application's
configurations and associated end users, executing reflection operations
with proper error isolation and transaction management.
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
that reflection failures for individual users don't affect other operations.
Args:
current_user: Authenticated user initiating workspace reflection
db: Database session for configuration queries
Returns:
dict: Success response with reflection results for all processed applications:
- app_id: Application identifier
- config_id: Memory configuration identifier
- end_user_id: End user identifier
- reflection_result: Individual reflection operation result
Processing Logic:
1. Retrieve all applications in the current workspace
2. Filter applications with valid memory configurations
3. For each configuration, find matching releases
4. Execute reflection for each end user with isolated transactions
5. Aggregate results with error handling per user
Error Handling:
- Individual user reflection failures are isolated
- Failed operations are logged and included in results
- Database transactions are isolated per user to prevent cascading failures
- Comprehensive error reporting for debugging
Raises:
HTTPException 500: If workspace reflection initialization fails
Performance Notes:
- Uses independent database sessions for each user operation
- Prevents transaction failures from affecting other users
- Comprehensive logging for operation tracking
"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
try: try:
api_logger.info(f"用户 {current_user.username} 启动workspace反思workspace_id: {workspace_id}") api_logger.info(f"用户 {current_user.username} 启动workspace反思workspace_id: {workspace_id}")
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败 # Use independent database session to get workspace app details, avoiding transaction failures
from app.db import get_db_context from app.db import get_db_context
with get_db_context() as query_db: with get_db_context() as query_db:
service = WorkspaceAppService(query_db) service = WorkspaceAppService(query_db)
@@ -116,8 +211,9 @@ async def start_workspace_reflection(
reflection_results = [] reflection_results = []
# Process each application in the workspace
for data in result['apps_detailed_info']: for data in result['apps_detailed_info']:
# 跳过没有配置的应用 # Skip applications without configurations
if not data['memory_configs']: if not data['memory_configs']:
api_logger.debug(f"应用 {data['id']} 没有memory_configs跳过") api_logger.debug(f"应用 {data['id']} 没有memory_configs跳过")
continue continue
@@ -126,22 +222,22 @@ async def start_workspace_reflection(
memory_configs = data['memory_configs'] memory_configs = data['memory_configs']
end_users = data['end_users'] end_users = data['end_users']
# 为每个配置和用户组合执行反思 # Execute reflection for each configuration and user combination
for config in memory_configs: for config in memory_configs:
config_id_str = str(config['config_id']) config_id_str = str(config['config_id'])
# 找到匹配此配置的所有release # Find all releases matching this configuration
matching_releases = [r for r in releases if str(r['config']) == config_id_str] matching_releases = [r for r in releases if str(r['config']) == config_id_str]
if not matching_releases: if not matching_releases:
api_logger.debug(f"配置 {config_id_str} 没有匹配的release") api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
continue continue
# 为每个用户执行反思 - 使用独立的数据库会话 # Execute reflection for each user - using independent database sessions
for user in end_users: for user in end_users:
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config_id_str}") api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config_id_str}")
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户 # Create independent database session for each user to avoid transaction failure impact
with get_db_context() as user_db: with get_db_context() as user_db:
try: try:
reflection_service = MemoryReflectionService(user_db) reflection_service = MemoryReflectionService(user_db)
@@ -184,14 +280,51 @@ async def start_reflection_configs(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""通过config_id查询memory_config表中的反思配置信息""" """
Query reflection configuration information by config_id
Retrieves detailed reflection configuration settings from the memory_config
table for a specific configuration ID. Provides comprehensive reflection
parameters including model settings, evaluation criteria, and operational flags.
Args:
config_id: Configuration identifier (UUID or integer) to query
current_user: Authenticated user making the request
db: Database session for data operations
Returns:
dict: Success response with detailed reflection configuration:
- config_id: Resolved configuration identifier
- reflection_enabled: Whether reflection is enabled for this config
- reflection_period_in_hours: Reflection execution interval
- reflexion_range: Scope of reflection operations (partial/all)
- baseline: Reflection strategy (time/fact/hybrid)
- reflection_model_id: LLM model identifier for reflection
- memory_verify: Memory verification flag
- quality_assessment: Quality assessment flag
Database Operations:
- Queries memory_config table by resolved config_id
- Retrieves all reflection-related configuration fields
- Resolves configuration ID for consistent formatting
Raises:
HTTPException 404: If configuration with specified ID is not found
HTTPException 500: If configuration query operation fails
ID Resolution:
- Supports both UUID and integer config_id formats
- Automatically resolves to appropriate internal format
- Maintains consistency across different ID representations
"""
config_id = resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
try: try:
config_id=resolve_config_id(config_id,db) config_id=resolve_config_id(config_id,db)
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}") api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id) result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
memory_config_id = resolve_config_id(result.config_id, db) memory_config_id = resolve_config_id(result.config_id, db)
# 构建返回数据
# Build response data with comprehensive configuration details
reflection_config = { reflection_config = {
"config_id": memory_config_id, "config_id": memory_config_id,
"reflection_enabled": result.enable_self_reflexion, "reflection_enabled": result.enable_self_reflexion,
@@ -205,9 +338,11 @@ async def start_reflection_configs(
api_logger.info(f"成功查询反思配置config_id: {config_id}") api_logger.info(f"成功查询反思配置config_id: {config_id}")
return success(data=reflection_config, msg="反思配置查询成功") return success(data=reflection_config, msg="反思配置查询成功")
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
return success(data=reflection_config, msg="Reflection configuration query successful")
except HTTPException: except HTTPException:
# 重新抛出HTTP异常 # Re-raise HTTP exceptions without modification
raise raise
except Exception as e: except Exception as e:
api_logger.error(f"查询反思配置失败: {str(e)}") api_logger.error(f"查询反思配置失败: {str(e)}")
@@ -223,13 +358,66 @@ async def reflection_run(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""Activate the reflection function for all matching applications in the workspace""" """
# 使用集中化的语言校验 Execute reflection engine with specified configuration
Runs the reflection engine using configuration parameters from the database.
Validates model availability, sets up the reflection engine with proper
configuration, and executes the reflection process with multi-language support.
This endpoint provides a test run capability for reflection configurations,
allowing users to validate their reflection settings and see results before
deploying to production environments.
Args:
config_id: Configuration identifier (UUID or integer) for reflection settings
language_type: Language preference header for output localization (optional)
current_user: Authenticated user executing the reflection
db: Database session for configuration queries
Returns:
dict: Success response with reflection execution results including:
- baseline: Reflection strategy used
- source_data: Input data processed
- memory_verifies: Memory verification results (if enabled)
- quality_assessments: Quality assessment results (if enabled)
- reflexion_data: Generated reflection insights and solutions
Configuration Validation:
- Verifies configuration exists in database
- Validates LLM model availability
- Falls back to default model if specified model is unavailable
- Ensures all required parameters are properly set
Reflection Engine Setup:
- Creates ReflectionConfig with database parameters
- Initializes Neo4j connector for memory access
- Sets up ReflectionEngine with validated model
- Configures language preferences for output
Error Handling:
- Model validation with fallback to default
- Configuration validation and error reporting
- Comprehensive logging for debugging
- Graceful handling of missing configurations
Raises:
HTTPException 404: If configuration is not found
HTTPException 500: If reflection execution fails
Performance Notes:
- Direct database query for configuration retrieval
- Model validation to prevent runtime failures
- Efficient reflection engine initialization
- Language-aware output processing
"""
# Use centralized language validation for consistent localization
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}") api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
config_id = resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
# 使用MemoryConfigRepository查询反思配置
# Query reflection configuration using MemoryConfigRepository
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id) result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
if not result: if not result:
raise HTTPException( raise HTTPException(
@@ -239,7 +427,7 @@ async def reflection_run(
api_logger.info(f"成功查询反思配置config_id: {config_id}") api_logger.info(f"成功查询反思配置config_id: {config_id}")
# 验证模型ID是否存在 # Validate model ID existence
model_id = result.reflection_model_id model_id = result.reflection_model_id
if model_id: if model_id:
try: try:
@@ -250,6 +438,7 @@ async def reflection_run(
# 可以设置为None让反思引擎使用默认模型 # 可以设置为None让反思引擎使用默认模型
model_id = None model_id = None
# Create reflection configuration with database parameters
config = ReflectionConfig( config = ReflectionConfig(
enabled=result.enable_self_reflexion, enabled=result.enable_self_reflexion,
iteration_period=result.iteration_period, iteration_period=result.iteration_period,
@@ -262,11 +451,13 @@ async def reflection_run(
model_id=model_id, model_id=model_id,
language_type=language_type language_type=language_type
) )
# Initialize Neo4j connector and reflection engine
connector = Neo4jConnector() connector = Neo4jConnector()
engine = ReflectionEngine( engine = ReflectionEngine(
config=config, config=config,
neo4j_connector=connector, neo4j_connector=connector,
llm_client=model_id # 传入验证后的 model_id llm_client=model_id # Pass validated model_id
) )
result=await (engine.reflection_run()) result=await (engine.reflection_run())

View File

@@ -1,3 +1,18 @@
"""
Memory Short Term Controller
This module provides REST API endpoints for managing short-term and long-term memory
data retrieval and analysis. It handles memory system statistics, data aggregation,
and provides comprehensive memory insights for end users.
Key Features:
- Short-term memory data retrieval and statistics
- Long-term memory data aggregation
- Entity count integration
- Multi-language response support
- Memory system analytics and reporting
"""
from typing import Optional from typing import Optional
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -13,9 +28,13 @@ from app.models.user_model import User
from app.services.memory_short_service import LongService, ShortService from app.services.memory_short_service import LongService, ShortService
from app.services.memory_storage_service import search_entity from app.services.memory_storage_service import search_entity
# Load environment variables for configuration
load_dotenv() load_dotenv()
# Initialize API logger for request tracking and debugging
api_logger = get_api_logger() api_logger = get_api_logger()
# Configure router with prefix and tags for API organization
router = APIRouter( router = APIRouter(
prefix="/memory/short", prefix="/memory/short",
tags=["Memory"], tags=["Memory"],
@@ -27,24 +46,73 @@ async def short_term_configs(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
# 使用集中化的语言校验 """
Retrieve comprehensive short-term and long-term memory statistics
Provides a comprehensive overview of memory system data for a specific end user,
including short-term memory entries, long-term memory aggregations, entity counts,
and retrieval statistics. Supports multi-language responses based on request headers.
This endpoint serves as a central dashboard for memory system analytics, combining
data from multiple memory subsystems to provide a holistic view of user memory state.
Args:
end_user_id: Unique identifier for the end user whose memory data to retrieve
language_type: Language preference header for response localization (optional)
current_user: Authenticated user making the request (injected by dependency)
db: Database session for data operations (injected by dependency)
Returns:
dict: Success response containing comprehensive memory statistics:
- short_term: List of short-term memory entries with detailed data
- long_term: List of long-term memory aggregations and summaries
- entity: Count of entities associated with the end user
- retrieval_number: Total count of short-term memory retrievals
- long_term_number: Total count of long-term memory entries
Response Structure:
{
"code": 200,
"msg": "Short-term memory system data retrieved successfully",
"data": {
"short_term": [...], # Short-term memory entries
"long_term": [...], # Long-term memory data
"entity": 42, # Entity count
"retrieval_number": 156, # Short-term retrieval count
"long_term_number": 23 # Long-term memory count
}
}
Raises:
HTTPException: If end_user_id is invalid or data retrieval fails
Performance Notes:
- Combines multiple service calls for comprehensive data
- Entity search is performed asynchronously for better performance
- Response time depends on memory data volume for the specified user
"""
# Use centralized language validation for consistent localization
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
# 获取短期记忆数据 # Retrieve short-term memory data and statistics
short_term=ShortService(end_user_id, db) short_term = ShortService(end_user_id, db)
short_result=short_term.get_short_databasets() short_result = short_term.get_short_databasets() # Get short-term memory entries
short_count=short_term.get_short_count() short_count = short_term.get_short_count() # Get short-term retrieval count
long_term=LongService(end_user_id, db) # Retrieve long-term memory data and aggregations
long_result=long_term.get_long_databasets() long_term = LongService(end_user_id, db)
long_result = long_term.get_long_databasets() # Get long-term memory entries
# Get entity count for the specified end user
entity_result = await search_entity(end_user_id) entity_result = await search_entity(end_user_id)
# Compile comprehensive memory statistics response
result = { result = {
'short_term': short_result, 'short_term': short_result, # Short-term memory entries
'long_term': long_result, 'long_term': long_result, # Long-term memory data
'entity': entity_result.get('num', 0), 'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
"retrieval_number":short_count, "retrieval_number": short_count, # Short-term retrieval statistics
"long_term_number":len(long_result) "long_term_number": len(long_result) # Long-term memory entry count
} }
return success(data=result, msg="短期记忆系统数据获取成功") return success(data=result, msg="短期记忆系统数据获取成功")

View File

@@ -8,6 +8,7 @@ from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.models import User from app.models import User
from app.schemas import conversation_schema
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
@@ -90,11 +91,7 @@ def get_messages(
conversation_id, conversation_id,
) )
messages = [ messages = [
{ conversation_schema.Message.model_validate(message)
"role": message.role,
"content": message.content,
"created_at": int(message.created_at.timestamp() * 1000),
}
for message in messages_obj for message in messages_obj
] ]
return success(data=messages, msg="get conversation history success") return success(data=messages, msg="get conversation history success")

View File

@@ -13,7 +13,6 @@ from app.core.logging_config import get_business_logger
from app.core.response_utils import success, fail from app.core.response_utils import success, fail
from app.db import get_db, get_db_read from app.db import get_db, get_db_read
from app.dependencies import get_share_user_id, ShareTokenData from app.dependencies import get_share_user_id, ShareTokenData
from app.models.app_model import App
from app.models.app_model import AppType from app.models.app_model import AppType
from app.repositories import knowledge_repository from app.repositories import knowledge_repository
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
@@ -22,6 +21,7 @@ from app.schemas import release_share_schema, conversation_schema
from app.schemas.response_schema import PageData, PageMeta from app.schemas.response_schema import PageData, PageMeta
from app.services import workspace_service from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.app_service import AppService
from app.services.auth_service import create_access_token from app.services.auth_service import create_access_token
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService from app.services.release_share_service import ReleaseShareService
@@ -215,8 +215,11 @@ def list_conversations(
service = SharedChatService(db) service = SharedChatService(db)
share, release = service.get_release_by_share_token(share_data.share_token, password) share, release = service.get_release_by_share_token(share_data.share_token, password)
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
app_service = AppService(db)
app = app_service._get_app_or_404(share.app_id)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id, app_id=share.app_id,
workspace_id=app.workspace_id,
other_id=other_id other_id=other_id
) )
logger.debug(new_end_user.id) logger.debug(new_end_user.id)
@@ -308,25 +311,29 @@ async def chat(
# Store end_user_id in database with original user_id # Store end_user_id in database with original user_id
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
app_service = AppService(db)
app = app_service._get_app_or_404(share.app_id)
workspace_id = app.workspace_id
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id, app_id=share.app_id,
workspace_id=workspace_id,
other_id=other_id, other_id=other_id,
original_user_id=user_id # Save original user_id to other_id original_user_id=user_id
) )
end_user_id = str(new_end_user.id) end_user_id = str(new_end_user.id)
appid = share.app_id # appid = share.app_id
"""获取存储类型和工作空间的ID""" """获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app仅查询未删除的应用 # 直接通过 SQLAlchemy 查询 app仅查询未删除的应用
app = db.query(App).filter( # app = db.query(App).filter(
App.id == appid, # App.id == appid,
App.is_active.is_(True) # App.is_active.is_(True)
).first() # ).first()
if not app: # if not app:
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND) # raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
workspace_id = app.workspace_id # workspace_id = app.workspace_id
# 直接从 workspace 获取 storage_type公开分享场景无需权限检查 # 直接从 workspace 获取 storage_type公开分享场景无需权限检查
storage_type = workspace_service.get_workspace_storage_type_without_auth( storage_type = workspace_service.get_workspace_storage_type_without_auth(
@@ -610,11 +617,11 @@ async def chat(
# 多 Agent 非流式返回 # 多 Agent 非流式返回
result = await app_chat_service.workflow_chat( result = await app_chat_service.workflow_chat(
message=payload.message, message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串 user_id=end_user_id, # 转换为字符串
variables=payload.variables, variables=payload.variables,
files=payload.files,
config=config, config=config,
web_search=payload.web_search, web_search=payload.web_search,
memory=payload.memory, memory=payload.memory,
@@ -654,17 +661,21 @@ async def config_query(
workflow_service = WorkflowService(db) workflow_service = WorkflowService(db)
content = { content = {
"app_type": release.app.type, "app_type": release.app.type,
"variables": workflow_service.get_start_node_variables(release.config) "variables": workflow_service.get_start_node_variables(release.config),
"memory": workflow_service.is_memory_enable(release.config),
"features": release.config.get("features")
} }
elif release.app.type == AppType.AGENT: elif release.app.type == AppType.AGENT:
content = { content = {
"app_type": release.app.type, "app_type": release.app.type,
"variables": release.config.get("variables") "variables": release.config.get("variables"),
"features": release.config.get("features")
} }
elif release.app.type == AppType.MULTI_AGENT: elif release.app.type == AppType.MULTI_AGENT:
content = { content = {
"app_type": release.app.type, "app_type": release.app.type,
"variables": [] "variables": [],
"features": release.config.get("features")
} }
else: else:
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED) return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)

View File

@@ -95,8 +95,8 @@ async def chat(
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=app.id, app_id=app.id,
workspace_id=workspace_id,
other_id=other_id, other_id=other_id,
original_user_id=other_id # Save original user_id to other_id
) )
end_user_id = str(new_end_user.id) end_user_id = str(new_end_user.id)
web_search = True web_search = True
@@ -280,6 +280,7 @@ async def chat(
memory=memory, memory=memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
files=payload.files,
app_id=app.id, app_id=app.id,
workspace_id=workspace_id, workspace_id=workspace_id,
release_id=app.current_release.id release_id=app.current_release.id

View File

@@ -3,8 +3,11 @@ from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.schemas.tool_schema import ( from app.schemas.tool_schema import (
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
CustomToolTestRequest, ToolActiveUpdate
) )
from app.core.response_utils import success from app.core.response_utils import success
@@ -14,6 +17,7 @@ from app.models import User
from app.models.tool_model import ToolType, ToolStatus, AuthType from app.models.tool_model import ToolType, ToolStatus, AuthType
from app.services.tool_service import ToolService from app.services.tool_service import ToolService
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.core.exceptions import BusinessException
router = APIRouter(prefix="/tools", tags=["Tool System"]) router = APIRouter(prefix="/tools", tags=["Tool System"])
@@ -103,7 +107,7 @@ async def create_tool(
val = getattr(request, key, None) val = getattr(request, key, None)
if val is not None: if val is not None:
request.config[key] = val request.config[key] = val
tool_id = service.create_tool( tool_id = await service.create_tool(
name=request.name, name=request.name,
tool_type=request.tool_type, tool_type=request.tool_type,
tenant_id=current_user.tenant_id, tenant_id=current_user.tenant_id,
@@ -113,6 +117,8 @@ async def create_tool(
tags=request.tags tags=request.tags
) )
return success(data={"tool_id": tool_id}, msg="工具创建成功") return success(data={"tool_id": tool_id}, msg="工具创建成功")
except BusinessException as e:
raise HTTPException(status_code=400, detail=e.message)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
except Exception as e: except Exception as e:
@@ -153,7 +159,7 @@ async def delete_tool(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service) service: ToolService = Depends(get_tool_service)
): ):
"""删除工具""" """删除工具逻辑删除is_active=False"""
try: try:
success_flag = service.delete_tool(tool_id, current_user.tenant_id) success_flag = service.delete_tool(tool_id, current_user.tenant_id)
if not success_flag: if not success_flag:
@@ -165,6 +171,30 @@ async def delete_tool(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.patch("/{tool_id}/active", response_model=ApiResponse)
async def set_tool_active(
    tool_id: str,
    request: ToolActiveUpdate,
    current_user: User = Depends(get_current_user),
    service: ToolService = Depends(get_tool_service)
):
    """Set whether a tool is active (enable / disable).

    - is_active=true: enable the tool
    - is_active=false: disable the tool (equivalent to deleting it, but recoverable)

    Args:
        tool_id: Identifier of the tool, taken from the URL path.
        request: Body carrying the desired ``is_active`` flag.
        current_user: Authenticated user; its ``tenant_id`` is passed to the service.
        service: Tool service that performs the update.

    Raises:
        HTTPException: 404 when the service call returns a falsy result,
            400 when the service raises ``ValueError``, and 500 for any
            other unexpected exception.
    """
    try:
        success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
        if not success_flag:
            raise HTTPException(status_code=404, detail="工具不存在")

        action = "启用" if request.is_active else "禁用"
        return success(msg=f"工具已{action}")
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # deliberate 404 above would be caught by the generic handler below
        # and reported to the client as a 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/execution/execute", response_model=ApiResponse) @router.post("/execution/execute", response_model=ApiResponse)
async def execute_tool( async def execute_tool(
request: ToolExecuteRequest, request: ToolExecuteRequest,
@@ -222,8 +252,10 @@ async def sync_mcp_tools(
try: try:
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id) result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
if not result.get("success", False): if not result.get("success", False):
raise HTTPException(status_code=400, detail=result.get("message", "同步失败")) raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
return success(data=result, msg="MCP工具列表同步完成") return success(data=result, msg="MCP工具列表同步完成")
except BusinessException:
raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -246,8 +278,10 @@ async def test_tool_connection(
# 普通连接测试 # 普通连接测试
result = await service.test_connection(tool_id, current_user.tenant_id) result = await service.test_connection(tool_id, current_user.tenant_id)
if result["success"] is False: if result["success"] is False:
raise HTTPException(status_code=400, detail=result["message"]) raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
return success(data=result, msg="连接测试完成") return success(data=result, msg="连接测试完成")
except BusinessException:
raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
import uuid import uuid
from typing import Callable
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
@@ -19,6 +20,7 @@ from app.services import user_service
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.response_utils import success from app.core.response_utils import success
from app.core.security import verify_password from app.core.security import verify_password
from app.i18n.dependencies import get_translator
# 获取API专用日志器 # 获取API专用日志器
api_logger = get_api_logger() api_logger = get_api_logger()
@@ -33,7 +35,8 @@ router = APIRouter(
def create_superuser( def create_superuser(
user: user_schema.UserCreate, user: user_schema.UserCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_superuser: User = Depends(get_current_superuser) current_superuser: User = Depends(get_current_superuser),
t: Callable = Depends(get_translator)
): ):
"""创建超级管理员(仅超级管理员可访问)""" """创建超级管理员(仅超级管理员可访问)"""
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}") api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
@@ -42,7 +45,7 @@ def create_superuser(
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})") api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
result_schema = user_schema.User.model_validate(result) result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="超级管理员创建成功") return success(data=result_schema, msg=t("users.create.superuser_success"))
@router.delete("/{user_id}", response_model=ApiResponse) @router.delete("/{user_id}", response_model=ApiResponse)
@@ -50,6 +53,7 @@ def delete_user(
user_id: uuid.UUID, user_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
): ):
"""停用用户(软删除)""" """停用用户(软删除)"""
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}") api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
@@ -57,13 +61,14 @@ def delete_user(
db=db, user_id_to_deactivate=user_id, current_user=current_user db=db, user_id_to_deactivate=user_id, current_user=current_user
) )
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})") api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
return success(msg="用户停用成功") return success(msg=t("users.delete.deactivate_success"))
@router.post("/{user_id}/activate", response_model=ApiResponse) @router.post("/{user_id}/activate", response_model=ApiResponse)
def activate_user( def activate_user(
user_id: uuid.UUID, user_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
): ):
"""激活用户""" """激活用户"""
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}") api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
@@ -74,13 +79,14 @@ def activate_user(
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})") api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
result_schema = user_schema.User.model_validate(result) result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="用户激活成功") return success(data=result_schema, msg=t("users.activate.success"))
@router.get("", response_model=ApiResponse) @router.get("", response_model=ApiResponse)
def get_current_user_info( def get_current_user_info(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
): ):
"""获取当前用户信息""" """获取当前用户信息"""
api_logger.info(f"当前用户信息请求: {current_user.username}") api_logger.info(f"当前用户信息请求: {current_user.username}")
@@ -105,7 +111,7 @@ def get_current_user_info(
break break
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}") api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
return success(data=result_schema, msg="用户信息获取成功") return success(data=result_schema, msg=t("users.info.get_success"))
@router.get("/superusers", response_model=ApiResponse) @router.get("/superusers", response_model=ApiResponse)
@@ -113,6 +119,7 @@ def get_tenant_superusers(
include_inactive: bool = False, include_inactive: bool = False,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
t: Callable = Depends(get_translator)
): ):
"""获取当前租户下的超管账号列表(仅超级管理员可访问)""" """获取当前租户下的超管账号列表(仅超级管理员可访问)"""
api_logger.info(f"获取租户超管列表请求: {current_user.username}") api_logger.info(f"获取租户超管列表请求: {current_user.username}")
@@ -125,7 +132,7 @@ def get_tenant_superusers(
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}") api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
superusers_schema = [user_schema.User.model_validate(u) for u in superusers] superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
return success(data=superusers_schema, msg="租户超管列表获取成功") return success(data=superusers_schema, msg=t("users.list.superusers_success"))
@@ -134,6 +141,7 @@ def get_user_info_by_id(
user_id: uuid.UUID, user_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
): ):
"""根据用户ID获取用户信息""" """根据用户ID获取用户信息"""
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}") api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
@@ -144,7 +152,7 @@ def get_user_info_by_id(
api_logger.info(f"用户信息获取成功: {result.username}") api_logger.info(f"用户信息获取成功: {result.username}")
result_schema = user_schema.User.model_validate(result) result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="用户信息获取成功") return success(data=result_schema, msg=t("users.info.get_success"))
@router.put("/change-password", response_model=ApiResponse) @router.put("/change-password", response_model=ApiResponse)
@@ -152,6 +160,7 @@ async def change_password(
request: ChangePasswordRequest, request: ChangePasswordRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
): ):
"""修改当前用户密码""" """修改当前用户密码"""
api_logger.info(f"用户密码修改请求: {current_user.username}") api_logger.info(f"用户密码修改请求: {current_user.username}")
@@ -164,7 +173,7 @@ async def change_password(
current_user=current_user current_user=current_user
) )
api_logger.info(f"用户密码修改成功: {current_user.username}") api_logger.info(f"用户密码修改成功: {current_user.username}")
return success(msg="密码修改成功") return success(msg=t("auth.password.change_success"))
@router.put("/admin/change-password", response_model=ApiResponse) @router.put("/admin/change-password", response_model=ApiResponse)
@@ -172,6 +181,7 @@ async def admin_change_password(
request: AdminChangePasswordRequest, request: AdminChangePasswordRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
t: Callable = Depends(get_translator)
): ):
"""超级管理员修改指定用户的密码""" """超级管理员修改指定用户的密码"""
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}") api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
@@ -186,16 +196,17 @@ async def admin_change_password(
# 根据是否生成了随机密码来构造响应 # 根据是否生成了随机密码来构造响应
if request.new_password: if request.new_password:
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}") api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
return success(msg="密码修改成功") return success(msg=t("auth.password.change_success"))
else: else:
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成") api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
return success(data=generated_password, msg="密码重置成功") return success(data=generated_password, msg=t("auth.password.reset_success"))
@router.post("/verify_pwd", response_model=ApiResponse) @router.post("/verify_pwd", response_model=ApiResponse)
def verify_pwd( def verify_pwd(
request: VerifyPasswordRequest, request: VerifyPasswordRequest,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
): ):
"""验证当前用户密码""" """验证当前用户密码"""
api_logger.info(f"用户验证密码请求: {current_user.username}") api_logger.info(f"用户验证密码请求: {current_user.username}")
@@ -203,8 +214,8 @@ def verify_pwd(
is_valid = verify_password(request.password, current_user.hashed_password) is_valid = verify_password(request.password, current_user.hashed_password)
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}") api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
if not is_valid: if not is_valid:
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED) raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
return success(data={"valid": is_valid}, msg="验证完成") return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
@router.post("/send-email-code", response_model=ApiResponse) @router.post("/send-email-code", response_model=ApiResponse)
@@ -212,6 +223,7 @@ async def send_email_code(
request: SendEmailCodeRequest, request: SendEmailCodeRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
): ):
"""发送邮箱验证码""" """发送邮箱验证码"""
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}") api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
@@ -219,7 +231,7 @@ async def send_email_code(
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id) await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
api_logger.info(f"邮箱验证码已发送: {current_user.username}") api_logger.info(f"邮箱验证码已发送: {current_user.username}")
return success(msg="验证码已发送到您的邮箱,请查收") return success(msg=t("users.email.code_sent"))
@router.put("/change-email", response_model=ApiResponse) @router.put("/change-email", response_model=ApiResponse)
@@ -227,6 +239,7 @@ async def change_email(
request: VerifyEmailCodeRequest, request: VerifyEmailCodeRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
): ):
"""验证验证码并修改邮箱""" """验证验证码并修改邮箱"""
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}") api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
@@ -239,4 +252,51 @@ async def change_email(
) )
api_logger.info(f"用户邮箱修改成功: {current_user.username}") api_logger.info(f"用户邮箱修改成功: {current_user.username}")
return success(msg="邮箱修改成功") return success(msg=t("users.email.change_success"))
@router.get("/me/language", response_model=ApiResponse)
def get_current_user_language(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    t: Callable = Depends(get_translator)
):
    """Return the language preference of the current user.

    The preference is looked up through ``user_service`` for the
    authenticated user and wrapped in a ``LanguagePreferenceResponse``;
    the response message is produced by the injected translator ``t``.
    """
    api_logger.info(f"获取用户语言偏好: {current_user.username}")

    # Resolve the stored preference via the service layer.
    lookup_args = {
        "db": db,
        "user_id": current_user.id,
        "current_user": current_user,
    }
    language = user_service.get_user_language_preference(**lookup_args)

    api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")

    payload = user_schema.LanguagePreferenceResponse(language=language)
    message = t("users.language.get_success")
    return success(data=payload, msg=message)
@router.put("/me/language", response_model=ApiResponse)
def update_current_user_language(
    request: user_schema.LanguagePreferenceRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    t: Callable = Depends(get_translator)
):
    """Set the language preference of the current user.

    The requested language is handed to ``user_service``; the response
    echoes the ``preferred_language`` of the user object the service
    returns, with a message produced by the injected translator ``t``.
    """
    api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")

    # Persist the new preference through the service layer.
    update_args = {
        "db": db,
        "user_id": current_user.id,
        "language": request.language,
        "current_user": current_user,
    }
    updated_user = user_service.update_user_language_preference(**update_args)

    api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")

    payload = user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language)
    message = t("users.language.update_success")
    return success(data=payload, msg=message)

View File

@@ -17,6 +17,7 @@ from app.services.user_memory_service import (
UserMemoryService, UserMemoryService,
analytics_memory_types, analytics_memory_types,
analytics_graph_data, analytics_graph_data,
analytics_community_graph_data,
) )
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
@@ -295,6 +296,42 @@ async def get_graph_data_api(
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
@router.get("/analytics/community_graph", response_model=ApiResponse)
async def get_community_graph_data_api(
    end_user_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Query the community graph data for an end user.

    Requires the caller to have a current workspace selected; otherwise an
    INVALID_PARAMETER failure response is returned. Any exception raised
    while querying is logged and reported as an INTERNAL_ERROR failure
    response rather than propagated.
    """
    workspace_id = current_user.current_workspace_id

    # Guard: the query is only allowed once a workspace has been selected.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(
        f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
        f"workspace={workspace_id}"
    )

    try:
        result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)

        if "message" in result and result["statistics"]["total_nodes"] == 0:
            # Empty graph accompanied by an explanatory message: surface
            # that message to the client instead of the generic one.
            api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}")
            reply_msg = result.get("message", "查询成功")
        else:
            api_logger.info(
                f"成功获取社区图谱: end_user_id={end_user_id}, "
                f"nodes={result['statistics']['total_nodes']}, "
                f"edges={result['statistics']['total_edges']}"
            )
            reply_msg = "查询成功"

        return success(data=result, msg=reply_msg)
    except Exception as e:
        api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
@router.get("/read_end_user/profile", response_model=ApiResponse) @router.get("/read_end_user/profile", response_model=ApiResponse)
async def get_end_user_profile( async def get_end_user_profile(
end_user_id: str, end_user_id: str,

View File

@@ -14,6 +14,12 @@ from app.dependencies import (
get_current_user, get_current_user,
workspace_access_guard, workspace_access_guard,
) )
from app.i18n.dependencies import get_current_language, get_translator
from app.i18n.serializers import (
WorkspaceSerializer,
WorkspaceMemberSerializer,
WorkspaceInviteSerializer
)
from app.models.tenant_model import Tenants from app.models.tenant_model import Tenants
from app.models.user_model import User from app.models.user_model import User
from app.models.workspace_model import InviteStatus from app.models.workspace_model import InviteStatus
@@ -65,7 +71,9 @@ def get_workspaces(
include_current: bool = Query(True, description="是否包含当前工作空间"), include_current: bool = Query(True, description="是否包含当前工作空间"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
current_tenant: Tenants = Depends(get_current_tenant) current_tenant: Tenants = Depends(get_current_tenant),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
): ):
"""获取当前租户下用户参与的所有工作空间 """获取当前租户下用户参与的所有工作空间
@@ -88,8 +96,13 @@ def get_workspaces(
) )
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间") api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
return success(data=workspaces_schema, msg="工作空间列表获取成功") # 使用序列化器添加国际化字段
serializer = WorkspaceSerializer()
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
@router.post("", response_model=ApiResponse) @router.post("", response_model=ApiResponse)
@@ -98,6 +111,8 @@ def create_workspace(
language_type: str = Header(default="zh", alias="X-Language-Type"), language_type: str = Header(default="zh", alias="X-Language-Type"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
): ):
"""创建新的工作空间""" """创建新的工作空间"""
from app.core.language_utils import get_language_from_header from app.core.language_utils import get_language_from_header
@@ -118,8 +133,13 @@ def create_workspace(
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, " f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
f"创建者: {current_user.username}, language={language}" f"创建者: {current_user.username}, language={language}"
) )
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间创建成功") # 使用序列化器添加国际化字段
serializer = WorkspaceSerializer()
result_data = WorkspaceResponse.model_validate(result).model_dump()
result_i18n = serializer.serialize(result_data, language)
return success(data=result_i18n, msg=t("workspace.created"))
@router.put("", response_model=ApiResponse) @router.put("", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
@@ -127,6 +147,8 @@ def update_workspace(
workspace: WorkspaceUpdate, workspace: WorkspaceUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
): ):
"""更新工作空间""" """更新工作空间"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -139,14 +161,21 @@ def update_workspace(
user=current_user, user=current_user,
) )
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}") api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间更新成功") # 使用序列化器添加国际化字段
serializer = WorkspaceSerializer()
result_data = WorkspaceResponse.model_validate(result).model_dump()
result_i18n = serializer.serialize(result_data, language)
return success(data=result_i18n, msg=t("workspace.updated"))
@router.get("/members", response_model=ApiResponse) @router.get("/members", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
def get_cur_workspace_members( def get_cur_workspace_members(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
): ):
"""获取工作空间成员列表(关系序列化)""" """获取工作空间成员列表(关系序列化)"""
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表") api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
@@ -157,8 +186,14 @@ def get_cur_workspace_members(
user=current_user, user=current_user,
) )
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}") api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
# 转换为表格项并使用序列化器添加国际化字段
table_items = _convert_members_to_table_items(members) table_items = _convert_members_to_table_items(members)
return success(data=table_items, msg="工作空间成员列表获取成功") serializer = WorkspaceMemberSerializer()
members_data = [item.model_dump() for item in table_items]
members_i18n = serializer.serialize_list(members_data, language)
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
@router.put("/members", response_model=ApiResponse) @router.put("/members", response_model=ApiResponse)
@@ -168,6 +203,7 @@ def update_workspace_members(
updates: List[WorkspaceMemberUpdate], updates: List[WorkspaceMemberUpdate],
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色") api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
@@ -178,7 +214,7 @@ def update_workspace_members(
user=current_user, user=current_user,
) )
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}") api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
return success(msg="成员角色更新成功") return success(msg=t("workspace.members.role_updated"))
@router.delete("/members/{member_id}", response_model=ApiResponse) @router.delete("/members/{member_id}", response_model=ApiResponse)
@@ -187,6 +223,7 @@ def delete_workspace_member(
member_id: uuid.UUID, member_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}") api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
@@ -198,7 +235,7 @@ def delete_workspace_member(
user=current_user, user=current_user,
) )
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}") api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
return success(msg="成员删除成功") return success(msg=t("workspace.members.deleted"))
# 创建空间协作邀请 # 创建空间协作邀请
@@ -208,6 +245,8 @@ def create_workspace_invite(
invite_data: WorkspaceInviteCreate, invite_data: WorkspaceInviteCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
): ):
"""创建工作空间邀请""" """创建工作空间邀请"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -220,7 +259,12 @@ def create_workspace_invite(
user=current_user user=current_user
) )
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}") api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
return success(data=result, msg="邀请创建成功")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
result_i18n = serializer.serialize(result, language)
return success(data=result_i18n, msg=t("workspace.invites.created"))
@router.get("/invites", response_model=ApiResponse) @router.get("/invites", response_model=ApiResponse)
@@ -232,6 +276,8 @@ def get_workspace_invites(
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
): ):
"""获取工作空间邀请列表""" """获取工作空间邀请列表"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -246,18 +292,30 @@ def get_workspace_invites(
offset=offset offset=offset
) )
api_logger.info(f"成功获取 {len(invites)} 个邀请记录") api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
return success(data=invites, msg="邀请列表获取成功")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
invites_i18n = serializer.serialize_list(invites, language)
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
def get_workspace_invite_info(
    token: str,
    db: Session = Depends(get_db),
    language: str = Depends(get_current_language),
    t: callable = Depends(get_translator)
):
    """Validate a workspace invite token and return the invite details.

    Registered on the public router and declares no current-user dependency,
    so it is reachable without authentication.

    Args:
        token: Invite token taken from the URL path.
        db: Database session, injected by ``get_db``.
        language: Response language, injected by ``get_current_language``.
        t: Translator callable, injected by ``get_translator``.

    Returns:
        A ``success(...)`` payload with the serialized invite data and a
        translated confirmation message.
    """
    invite_info = workspace_service.validate_invite_token(db=db, token=token)
    api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")

    # Run the result through the invite serializer to add the i18n fields.
    localized = WorkspaceInviteSerializer().serialize(invite_info, language)
    return success(data=localized, msg=t("workspace.invites.validated"))
@router.delete("/invites/{invite_id}", response_model=ApiResponse) @router.delete("/invites/{invite_id}", response_model=ApiResponse)
@@ -267,6 +325,8 @@ def revoke_workspace_invite(
invite_id: uuid.UUID, invite_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
): ):
"""撤销工作空间邀请""" """撤销工作空间邀请"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -279,7 +339,12 @@ def revoke_workspace_invite(
user=current_user user=current_user
) )
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}") api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
return success(data=result, msg="邀请撤销成功")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
result_i18n = serializer.serialize(result, language)
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
# ==================== 公开邀请接口(无需认证) ==================== # ==================== 公开邀请接口(无需认证) ====================
@@ -302,6 +367,7 @@ def switch_workspace(
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
): ):
"""切换工作空间""" """切换工作空间"""
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}") api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
@@ -312,7 +378,7 @@ def switch_workspace(
user=current_user, user=current_user,
) )
api_logger.info(f"成功切换工作空间为 {workspace_id}") api_logger.info(f"成功切换工作空间为 {workspace_id}")
return success(msg="工作空间切换成功") return success(msg=t("workspace.switched"))
@router.get("/storage", response_model=ApiResponse) @router.get("/storage", response_model=ApiResponse)
@@ -320,6 +386,7 @@ def switch_workspace(
def get_workspace_storage_type( def get_workspace_storage_type(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
): ):
"""获取当前工作空间的存储类型""" """获取当前工作空间的存储类型"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -331,7 +398,7 @@ def get_workspace_storage_type(
user=current_user user=current_user
) )
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}") api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
return success(data={"storage_type": storage_type}, msg="存储类型获取成功") return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
@router.get("/workspace_models", response_model=ApiResponse) @router.get("/workspace_models", response_model=ApiResponse)
@@ -339,6 +406,8 @@ def get_workspace_storage_type(
def workspace_models_configs( def workspace_models_configs(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
): ):
"""获取当前工作空间的模型配置llm, embedding, rerank""" """获取当前工作空间的模型配置llm, embedding, rerank"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -354,14 +423,14 @@ def workspace_models_configs(
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问") api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="工作空间不存在或无权访问" detail=t("workspace.not_found")
) )
api_logger.info( api_logger.info(
f"成功获取工作空间 {workspace_id} 的模型配置: " f"成功获取工作空间 {workspace_id} 的模型配置: "
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}" f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
) )
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功") return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
@router.put("/workspace_models", response_model=ApiResponse) @router.put("/workspace_models", response_model=ApiResponse)
@@ -370,6 +439,7 @@ def update_workspace_models_configs(
models_update: WorkspaceModelsUpdate, models_update: WorkspaceModelsUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
): ):
"""更新当前工作空间的模型配置llm, embedding, rerank""" """更新当前工作空间的模型配置llm, embedding, rerank"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -386,5 +456,5 @@ def update_workspace_models_configs(
f"成功更新工作空间 {workspace_id} 的模型配置: " f"成功更新工作空间 {workspace_id} 的模型配置: "
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}" f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
) )
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功") return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))

View File

@@ -1,7 +1,6 @@
import json
import os import os
from pathlib import Path from pathlib import Path
from typing import Annotated, Any, Dict, Optional from typing import Annotated, Optional
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import Field, TypeAdapter from pydantic import Field, TypeAdapter
@@ -115,6 +114,7 @@ class Settings:
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "") S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "") S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "") S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
# VOLC ASR settings # VOLC ASR settings
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "") VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
@@ -162,6 +162,44 @@ class Settings:
# This controls the language used for memory summary titles and other generated content # This controls the language used for memory summary titles and other generated content
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh") DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
# ========================================================================
# Internationalization (i18n) Configuration
# ========================================================================
# Default language for API responses
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
# Supported languages (comma-separated)
I18N_SUPPORTED_LANGUAGES: list[str] = [
lang.strip()
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
if lang.strip()
]
# Core locales directory (community edition)
# Use absolute path to work from any working directory
I18N_CORE_LOCALES_DIR: str = os.getenv(
"I18N_CORE_LOCALES_DIR",
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
)
# Premium locales directory (enterprise edition, optional)
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
# Enable translation cache
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
# LRU cache size for hot translations
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
# Enable hot reload of translation files
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
# Fallback language when translation is missing
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
# Log missing translations
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
# Logging settings # Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO") LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s") LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

View File

@@ -1,16 +1,45 @@
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
from app.schemas.memory_agent_schema import AgentMemoryDataset
def content_input_node(state: ReadState) -> ReadState:
    """
    Start node for the read graph.

    Takes the content of the first message in ``state`` (or an empty string
    when there are no messages), replaces every entry of
    ``AgentMemoryDataset.PRONOUN`` with ``AgentMemoryDataset.NAME``, and
    returns the result under the ``data`` key.

    Args:
        state: ReadState holding the incoming messages

    Returns:
        ReadState: Partial state update of the form ``{"data": <text>}``
    """
    messages = state.get('messages')
    if messages:
        text = messages[0].content
    else:
        text = ''

    # Substitute each configured pronoun with the configured name, in order.
    for token in AgentMemoryDataset.PRONOUN:
        text = text.replace(token, AgentMemoryDataset.NAME)

    return {"data": text}
def content_input_write(state: WriteState) -> WriteState:
    """
    Start node for the write graph.

    Takes the content of the first message in ``state`` (or an empty string
    when there are no messages), replaces every entry of
    ``AgentMemoryDataset.PRONOUN`` with ``AgentMemoryDataset.NAME``, and
    returns the result under the ``data`` key.

    Args:
        state: WriteState holding the incoming messages

    Returns:
        WriteState: Partial state update of the form ``{"data": <text>}``
    """
    messages = state.get('messages')
    if messages:
        text = messages[0].content
    else:
        text = ''

    # Substitute each configured pronoun with the configured name, in order.
    for token in AgentMemoryDataset.PRONOUN:
        text = text.replace(token, AgentMemoryDataset.NAME)

    return {"data": text}

View File

@@ -19,19 +19,39 @@ logger = get_agent_logger(__name__)
class ProblemNodeService(LLMServiceMixin):
    """
    Service object backing the problem-processing nodes.

    Extends ``LLMServiceMixin`` and adds a template service built from the
    module-level ``template_root``.

    Attributes:
        template_service: ``TemplateService`` instance rooted at ``template_root``
    """

    def __init__(self):
        # Initialise the mixin first, then attach the template service.
        super().__init__()
        self.template_service = TemplateService(template_root)
# 创建全局服务实例 # Create global service instance
problem_service = ProblemNodeService() problem_service = ProblemNodeService()
async def Split_The_Problem(state: ReadState) -> ReadState: async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点""" """
Problem decomposition node
Breaks down complex user queries into smaller, more manageable sub-problems.
Uses LLM to analyze the input and generate structured problem decomposition
with question types and reasoning.
Args:
state: ReadState containing user input and configuration
Returns:
ReadState: Updated state with problem decomposition results
"""
# 从状态中获取数据 # 从状态中获取数据
content = state.get('data', '') content = state.get('data', '')
end_user_id = state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
@@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
# 添加更详细的日志记录 # 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
# 验证结构化响应 # Validate structured response
if not structured or not hasattr(structured, 'root'): if not structured or not hasattr(structured, 'root'):
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确") logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
split_result = json.dumps([], ensure_ascii=False) split_result = json.dumps([], ensure_ascii=False)
@@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
exc_info=True exc_info=True
) )
# 提供更详细的错误信息 # Provide more detailed error information
error_details = { error_details = {
"error_type": type(e).__name__, "error_type": type(e).__name__,
"error_message": str(e), "error_message": str(e),
@@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
logger.error(f"Split_The_Problem error details: {error_details}") logger.error(f"Split_The_Problem error details: {error_details}")
# 创建默认的空结果 # Create default empty result
result = { result = {
"context": json.dumps([], ensure_ascii=False), "context": json.dumps([], ensure_ascii=False),
"original": content, "original": content,
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
} }
} }
# 返回更新后的状态,包含spit_context字段 # Return updated state including the spit_data field
return {"spit_data": result} return {"spit_data": result}
async def Problem_Extension(state: ReadState) -> ReadState: async def Problem_Extension(state: ReadState) -> ReadState:
"""问题扩展节点""" """
# 获取原始数据和分解结果 Problem extension node
Extends the decomposed problems from Split_The_Problem node by generating
additional related questions and organizing them by original question.
Uses LLM to create comprehensive question extensions for better memory retrieval.
Args:
state: ReadState containing decomposed problems and configuration
Returns:
ReadState: Updated state with extended problem results
"""
# Get original data and decomposition results
start = time.time() start = time.time()
content = state.get('data', '') content = state.get('data', '')
data = state.get('spit_data', '')['context'] data = state.get('spit_data', '')['context']
@@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
# 验证结构化响应 # Validate structured response
if not response_content or not hasattr(response_content, 'root'): if not response_content or not hasattr(response_content, 'root'):
logger.warning("Problem_Extension: 结构化响应为空或格式不正确") logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
aggregated_dict = {} aggregated_dict = {}
@@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
exc_info=True exc_info=True
) )
# 提供更详细的错误信息 # Provide more detailed error information
error_details = { error_details = {
"error_type": type(e).__name__, "error_type": type(e).__name__,
"error_message": str(e), "error_message": str(e),

View File

@@ -29,6 +29,18 @@ logger = get_agent_logger(__name__)
async def rag_config(state): async def rag_config(state):
"""
Configure RAG (Retrieval-Augmented Generation) settings
Creates configuration for knowledge base retrieval including similarity thresholds,
weights, and reranker settings.
Args:
state: Current state containing user_rag_memory_id
Returns:
dict: RAG configuration dictionary
"""
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = { kb_config = {
"knowledge_bases": [ "knowledge_bases": [
@@ -48,6 +60,19 @@ async def rag_config(state):
async def rag_knowledge(state, question): async def rag_knowledge(state, question):
"""
Retrieve knowledge using RAG approach
Performs knowledge retrieval from configured knowledge bases using the
provided question and returns formatted results.
Args:
state: Current state containing configuration
question: Question to search for
Returns:
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
"""
kb_config = await rag_config(state) kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '') user_rag_memory_id = state.get("user_rag_memory_id", '')
@@ -68,12 +93,24 @@ async def rag_knowledge(state, question):
async def llm_infomation(state: ReadState) -> ReadState: async def llm_infomation(state: ReadState) -> ReadState:
"""
Get LLM configuration information from state
Retrieves model configuration details including model ID and tenant ID
from the memory configuration in the current state.
Args:
state: ReadState containing memory configuration
Returns:
ReadState: Model configuration as Pydantic model
"""
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
model_id = memory_config.llm_model_id model_id = memory_config.llm_model_id
tenant_id = memory_config.tenant_id tenant_id = memory_config.tenant_id
# 使用现有的 memory_config 而不是重新查询数据库 # Use existing memory_config instead of re-querying database
# 或者使用线程安全的数据库访问 # or use thread-safe database access
with get_db_context() as db: with get_db_context() as db:
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id) result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
result_pydantic = model_schema.ModelConfig.model_validate(result_orm) result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
@@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
async def clean_databases(data) -> str: async def clean_databases(data) -> str:
""" """
简化的数据库搜索结果清理函数 Simplified database search result cleaning function
Processes and cleans search results from various sources including
reranked results and time-based search results. Extracts text content
from structured data and returns as formatted string.
Args: Args:
data: 搜索结果数据 data: Search result data (can be string, dict, or other types)
Returns: Returns:
清理后的内容字符串 str: Cleaned content string
""" """
try: try:
# 解析JSON字符串 # Parse JSON string
if isinstance(data, str): if isinstance(data, str):
try: try:
data = json.loads(data) data = json.loads(data)
@@ -101,24 +142,24 @@ async def clean_databases(data) -> str:
if not isinstance(data, dict): if not isinstance(data, dict):
return str(data) return str(data)
# 获取结果数据 # Get result data
# with open("搜索结果.json","w",encoding='utf-8') as f: # with open("搜索结果.json","w",encoding='utf-8') as f:
# f.write(json.dumps(data, indent=4, ensure_ascii=False)) # f.write(json.dumps(data, indent=4, ensure_ascii=False))
results = data.get('results', data) results = data.get('results', data)
if not isinstance(results, dict): if not isinstance(results, dict):
return str(results) return str(results)
# 收集所有内容 # Collect all content
content_list = [] content_list = []
# 处理重排序结果 # Process reranked results
reranked = results.get('reranked_results', {}) reranked = results.get('reranked_results', {})
if reranked: if reranked:
for category in ['summaries', 'statements', 'chunks', 'entities']: for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
items = reranked.get(category, []) items = reranked.get(category, [])
if isinstance(items, list): if isinstance(items, list):
content_list.extend(items) content_list.extend(items)
# 处理时间搜索结果 # Process time search results
time_search = results.get('time_search', {}) time_search = results.get('time_search', {})
if time_search: if time_search:
if isinstance(time_search, dict): if isinstance(time_search, dict):
@@ -128,11 +169,18 @@ async def clean_databases(data) -> str:
elif isinstance(time_search, list): elif isinstance(time_search, list):
content_list.extend(time_search) content_list.extend(time_search)
# 提取文本内容 # Extract text content; deduplicate communities by name (repeated tool calls can produce duplicates)
text_parts = [] text_parts = []
seen_community_names = set()
for item in content_list: for item in content_list:
if isinstance(item, dict): if isinstance(item, dict):
text = item.get('statement') or item.get('content', '') # community 节点用 name 去重
if 'member_count' in item or 'core_entities' in item:
community_name = item.get('name') or item.get('id', '')
if community_name in seen_community_names:
continue
seen_community_names.add(community_name)
text = item.get('statement') or item.get('content') or item.get('summary', '')
if text: if text:
text_parts.append(text) text_parts.append(text)
elif isinstance(item, str): elif isinstance(item, str):
@@ -146,10 +194,19 @@ async def clean_databases(data) -> str:
async def retrieve_nodes(state: ReadState) -> ReadState: async def retrieve_nodes(state: ReadState) -> ReadState:
''' """
Retrieve information using simplified search approach
模型信息 Processes extended problems from previous nodes and performs retrieval
''' using either RAG or hybrid search based on storage type. Handles concurrent
processing of multiple questions and deduplicates results.
Args:
state: ReadState containing problem extensions and configuration
Returns:
ReadState: Updated state with retrieval results and intermediate outputs
"""
problem_extension = state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
@@ -163,7 +220,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
problem_list.append(data) problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# 创建异步任务处理单个问题 # Create async task to process individual questions
async def process_question_nodes(idx, question): async def process_question_nodes(idx, question):
try: try:
# Prepare search parameters based on storage type # Prepare search parameters based on storage type
@@ -209,7 +266,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
} }
} }
# 并发处理所有问题 # Process all questions concurrently
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)] tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
databases_anser = await asyncio.gather(*tasks) databases_anser = await asyncio.gather(*tasks)
databases_data = { databases_data = {
@@ -257,7 +314,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
async def retrieve(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id """
Advanced retrieve function using LangChain agents and tools
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
to perform sophisticated information retrieval. Supports both RAG and traditional
memory storage approaches with concurrent processing and result deduplication.
Args:
state: ReadState containing problem extensions and configuration
Returns:
ReadState: Updated state with retrieval results and intermediate outputs
"""
# Get end_user_id from state
import time import time
start = time.time() start = time.time()
problem_extension = state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
@@ -291,7 +361,11 @@ async def retrieve(state: ReadState) -> ReadState:
) )
time_retrieval_tool = create_time_retrieval_tool(end_user_id) time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = {"end_user_id": end_user_id, "return_raw_results": True} search_params = {
"end_user_id": end_user_id,
"return_raw_results": True,
"include": ["summaries", "statements", "chunks", "entities", "communities"],
}
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params) hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent( agent = create_agent(
llm, llm,
@@ -299,21 +373,21 @@ async def retrieve(state: ReadState) -> ReadState:
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
) )
# 创建异步任务处理单个问题 # Create async task to process individual questions
import asyncio import asyncio
# 在模块级别定义信号量,限制最大并发数 # Define semaphore at module level to limit maximum concurrency
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作 SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
async def process_question(idx, question): async def process_question(idx, question):
async with SEMAPHORE: # 限制并发 async with SEMAPHORE: # Limit concurrency
try: try:
if storage_type == "rag" and user_rag_memory_id: if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question) question)
else: else:
cleaned_query = question cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke # Use asyncio to run synchronous agent.invoke in thread pool
import asyncio import asyncio
response = await asyncio.get_event_loop().run_in_executor( response = await asyncio.get_event_loop().run_in_executor(
None, None,
@@ -327,8 +401,32 @@ async def retrieve(state: ReadState) -> ReadState:
raw_results = tool_results['content'] raw_results = tool_results['content']
clean_content = await clean_databases(raw_results) clean_content = await clean_databases(raw_results)
# 社区展开:从 tool 返回结果中提取命中的 community
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
_expanded_stmts_to_write = []
try:
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
reranked = results_dict.get('reranked_results', {})
community_hits = reranked.get('communities', [])
if not community_hits:
community_hits = results_dict.get('communities', [])
if community_hits:
from app.core.memory.agent.services.search_service import expand_communities_to_statements
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
community_results=community_hits,
end_user_id=end_user_id,
existing_content=clean_content,
)
if new_texts:
clean_content = clean_content + '\n' + '\n'.join(new_texts)
except Exception as parse_err:
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
try: try:
raw_results = raw_results['results'] raw_results = raw_results['results']
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
if _expanded_stmts_to_write and isinstance(raw_results, dict):
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
except Exception: except Exception:
raw_results = [] raw_results = []
@@ -362,7 +460,7 @@ async def retrieve(state: ReadState) -> ReadState:
} }
} }
# 并发处理所有问题 # Process all questions concurrently
import asyncio import asyncio
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)] tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
databases_anser = await asyncio.gather(*tasks) databases_anser = await asyncio.gather(*tasks)

View File

@@ -23,18 +23,39 @@ logger = get_agent_logger(__name__)
class SummaryNodeService(LLMServiceMixin): class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类""" """
Summary node service class
Handles summary generation operations using LLM services. Inherits from
LLMServiceMixin to provide structured LLM calling capabilities for
generating summaries from retrieved information.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.template_service = TemplateService(template_root) self.template_service = TemplateService(template_root)
# 创建全局服务实例 # Create global service instance
summary_service = SummaryNodeService() summary_service = SummaryNodeService()
async def rag_config(state): async def rag_config(state):
"""
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
Creates configuration for knowledge base retrieval including similarity thresholds,
weights, and reranker settings specifically for summary generation.
Args:
state: Current state containing user_rag_memory_id
Returns:
dict: RAG configuration dictionary with knowledge base settings
"""
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = { kb_config = {
"knowledge_bases": [ "knowledge_bases": [
@@ -54,6 +75,23 @@ async def rag_config(state):
async def rag_knowledge(state, question): async def rag_knowledge(state, question):
"""
Retrieve knowledge using RAG approach for summary generation
Performs knowledge retrieval from configured knowledge bases using the
provided question and returns formatted results for summary processing.
Args:
state: Current state containing configuration
question: Question to search for in knowledge base
Returns:
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
- retrieval_knowledge: List of retrieved knowledge chunks
- clean_content: Formatted content string
- cleaned_query: Processed query string
- raw_results: Raw retrieval results
"""
kb_config = await rag_config(state) kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '') user_rag_memory_id = state.get("user_rag_memory_id", '')
@@ -74,6 +112,18 @@ async def rag_knowledge(state, question):
async def summary_history(state: ReadState) -> ReadState: async def summary_history(state: ReadState) -> ReadState:
"""
Retrieve conversation history for summary context
Gets the conversation history for the current user to provide context
for summary generation operations.
Args:
state: ReadState containing end_user_id
Returns:
ReadState: Conversation history data
"""
end_user_id = state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history return history
@@ -82,11 +132,26 @@ async def summary_history(state: ReadState) -> ReadState:
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model, async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str: search_mode) -> str:
""" """
增强的summary_llm函数,包含更好的错误处理和数据验证 Enhanced summary_llm function with better error handling and data validation
Generates summaries using LLM with structured output. Includes fallback mechanisms
for handling LLM failures and provides robust error recovery.
Args:
state: ReadState containing current context
history: Conversation history for context
retrieve_info: Retrieved information to summarize
template_name: Jinja2 template name for prompt generation
operation_name: Type of operation (summary, input_summary, retrieve_summary)
response_model: Pydantic model for structured output
search_mode: Search mode flag ("0" for simple, "1" for complex)
Returns:
str: Generated summary text or fallback message
""" """
data = state.get("data", '') data = state.get("data", '')
# 构建系统提示词 # Build system prompt
if str(search_mode) == "0": if str(search_mode) == "0":
system_prompt = await summary_service.template_service.render_template( system_prompt = await summary_service.template_service.render_template(
template_name=template_name, template_name=template_name,
@@ -103,7 +168,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
retrieve_info=retrieve_info retrieve_info=retrieve_info
) )
try: try:
# 使用优化的LLM服务进行结构化输出 # Use optimized LLM service for structured output
with get_db_context() as db_session: with get_db_context() as db_session:
structured = await summary_service.call_llm_structured( structured = await summary_service.call_llm_structured(
state=state, state=state,
@@ -112,23 +177,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
response_model=response_model, response_model=response_model,
fallback_value=None fallback_value=None
) )
# 验证结构化响应 # Validate structured response
if structured is None: if structured is None:
logger.warning("LLM返回None使用默认回答") logger.warning("LLM返回None使用默认回答")
return "信息不足,无法回答" return "信息不足,无法回答"
# 根据操作类型提取答案 # Extract answer based on operation type
if operation_name == "summary": if operation_name == "summary":
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答" aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
else: else:
# 处理RetrieveSummaryResponse # Handle RetrieveSummaryResponse
if hasattr(structured, 'data') and structured.data: if hasattr(structured, 'data') and structured.data:
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答" aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
else: else:
logger.warning("结构化响应缺少data字段") logger.warning("结构化响应缺少data字段")
aimessages = "信息不足,无法回答" aimessages = "信息不足,无法回答"
# 验证答案不为空 # Validate answer is not empty
if not aimessages or aimessages.strip() == "": if not aimessages or aimessages.strip() == "":
aimessages = "信息不足,无法回答" aimessages = "信息不足,无法回答"
@@ -137,7 +202,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
except Exception as e: except Exception as e:
logger.error(f"结构化输出失败: {e}", exc_info=True) logger.error(f"结构化输出失败: {e}", exc_info=True)
# 尝试非结构化输出作为fallback # Try unstructured output as fallback
try: try:
logger.info("尝试非结构化输出作为fallback") logger.info("尝试非结构化输出作为fallback")
response = await summary_service.call_llm_simple( response = await summary_service.call_llm_simple(
@@ -148,9 +213,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
) )
if response and response.strip(): if response and response.strip():
# 简单清理响应 # Simple response cleaning
cleaned_response = response.strip() cleaned_response = response.strip()
# 移除可能的JSON标记 # Remove possible JSON markers
if cleaned_response.startswith('```'): if cleaned_response.startswith('```'):
lines = cleaned_response.split('\n') lines = cleaned_response.split('\n')
cleaned_response = '\n'.join(lines[1:-1]) cleaned_response = '\n'.join(lines[1:-1])
@@ -165,6 +230,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
async def summary_redis_save(state: ReadState, aimessages) -> ReadState: async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
"""
Save summary results to Redis session storage
Stores the generated summary and user query in Redis for session management
and conversation history tracking.
Args:
state: ReadState containing user and query information
aimessages: Generated summary message to save
Returns:
ReadState: Updated state after saving to Redis
"""
data = state.get("data", '') data = state.get("data", '')
end_user_id = state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session( await SessionService(store).save_session(
@@ -179,6 +257,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState: async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
"""
Format summary results for different output types
Creates structured output formats for both input summary and retrieval summary
operations, including metadata and intermediate results for frontend display.
Args:
state: ReadState containing storage and user information
aimessages: Generated summary message
raw_results: Raw search/retrieval results
Returns:
tuple: (input_summary, retrieve_summary) formatted result dictionaries
"""
storage_type = state.get("storage_type", '') storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '') user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '') data = state.get("data", '')
@@ -217,6 +309,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState
async def Input_Summary(state: ReadState) -> ReadState: async def Input_Summary(state: ReadState) -> ReadState:
"""
Generate quick input summary from retrieved information
Performs fast retrieval and generates a quick summary response for user queries.
This function prioritizes speed by only searching summary nodes and provides
immediate feedback to users.
Args:
state: ReadState containing user query, storage configuration, and context
Returns:
ReadState: Dictionary containing summary results with status and metadata
"""
start = time.time() start = time.time()
storage_type = state.get("storage_type", '') storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
@@ -229,13 +334,22 @@ async def Input_Summary(state: ReadState) -> ReadState:
"end_user_id": end_user_id, "end_user_id": end_user_id,
"question": data, "question": data,
"return_raw_results": True, "return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance "include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点
} }
try: try:
if storage_type != "rag": if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
memory_config=memory_config) **search_params,
memory_config=memory_config,
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
)
# 调试:打印 community 检索结果数量
if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {})
community_hits = reranked.get('communities', [])
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
f"summary 命中数: {len(reranked.get('summaries', []))}")
else: else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e: except Exception as e:
@@ -266,6 +380,19 @@ async def Input_Summary(state: ReadState) -> ReadState:
async def Retrieve_Summary(state: ReadState) -> ReadState: async def Retrieve_Summary(state: ReadState) -> ReadState:
"""
Generate comprehensive summary from retrieved expansion issues
Processes retrieved expansion issues and generates a detailed summary using LLM.
This function handles complex retrieval results and provides comprehensive answers
based on expanded query results.
Args:
state: ReadState containing retrieve data with expansion issues
Returns:
ReadState: Dictionary containing comprehensive summary results
"""
retrieve = state.get("retrieve", '') retrieve = state.get("retrieve", '')
history = await summary_history(state) history = await summary_history(state)
import json import json
@@ -299,13 +426,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
duration = 0.0 duration = 0.0
log_time('Retrieval summary', duration) log_time('Retrieval summary', duration)
# 修复协程调用 - await,然后访问返回值 # Fixed coroutine call - await first, then access return value
summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1] summary = summary_result[1]
return {"summary": summary} return {"summary": summary}
async def Summary(state: ReadState) -> ReadState: async def Summary(state: ReadState) -> ReadState:
"""
Generate final comprehensive summary from verified data
Creates the final summary using verified expansion issues and conversation history.
This function processes verified data to generate the most comprehensive and
accurate response to user queries.
Args:
state: ReadState containing verified data and query information
Returns:
ReadState: Dictionary containing final summary results
"""
start = time.time() start = time.time()
query = state.get("data", '') query = state.get("data", '')
verify = state.get("verify", '') verify = state.get("verify", '')
@@ -336,13 +476,26 @@ async def Summary(state: ReadState) -> ReadState:
duration = 0.0 duration = 0.0
log_time('Retrieval summary', duration) log_time('Retrieval summary', duration)
# 修复协程调用 - await,然后访问返回值 # Fixed coroutine call - await first, then access return value
summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1] summary = summary_result[1]
return {"summary": summary} return {"summary": summary}
async def Summary_fails(state: ReadState) -> ReadState: async def Summary_fails(state: ReadState) -> ReadState:
"""
Generate fallback summary when normal summary process fails
Provides a fallback summary generation mechanism when the standard summary
process encounters errors or fails to produce satisfactory results. Uses
a specialized failure template to handle edge cases.
Args:
state: ReadState containing verified data and failure context
Returns:
ReadState: Dictionary containing fallback summary results
"""
storage_type = state.get("storage_type", '') storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '') user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state) history = await summary_history(state)

View File

@@ -18,24 +18,46 @@ logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin): class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类""" """
Verification node service class
Handles data verification operations using LLM services. Inherits from
LLMServiceMixin to provide structured LLM calling capabilities for
verifying and validating retrieved information.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.template_service = TemplateService(template_root) self.template_service = TemplateService(template_root)
# 创建全局服务实例 # Create global service instance
verification_service = VerificationNodeService() verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式""" """
Process verification results and generate output format
Transforms VerificationResult objects into structured output format suitable
for frontend consumption. Handles conversion of VerificationItem objects to
dictionary format and adds metadata for tracking.
Args:
state: ReadState containing storage and user configuration
messages_deal: VerificationResult containing verification outcomes
Returns:
dict: Formatted verification result with status and metadata
"""
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '') data = state.get('data', '')
# VerificationItem 对象转换为字典列表 # Convert VerificationItem objects to dictionary list
verified_data = [] verified_data = []
if messages_deal.expansion_issue: if messages_deal.expansion_issue:
for item in messages_deal.expansion_issue: for item in messages_deal.expansion_issue:
@@ -89,7 +111,7 @@ async def Verify(state: ReadState):
logger.info("Verify: 开始渲染模板") logger.info("Verify: 开始渲染模板")
# 生成 JSON schema 以指导 LLM 输出正确格式 # Generate JSON schema to guide LLM output format
json_schema = VerificationResult.model_json_schema() json_schema = VerificationResult.model_json_schema()
system_prompt = await verification_service.template_service.render_template( system_prompt = await verification_service.template_service.render_template(
@@ -104,8 +126,8 @@ async def Verify(state: ReadState):
# 使用优化的LLM服务添加超时保护 # 使用优化的LLM服务添加超时保护
logger.info("Verify: 开始调用 LLM") logger.info("Verify: 开始调用 LLM")
try: try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待 # Add asyncio.wait_for timeout wrapper to prevent infinite waiting
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长) # Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
with get_db_context() as db_session: with get_db_context() as db_session:
structured = await asyncio.wait_for( structured = await asyncio.wait_for(
@@ -122,7 +144,7 @@ async def Verify(state: ReadState):
"reason": "验证失败或超时" "reason": "验证失败或超时"
} }
), ),
timeout=150.0 # 150秒超时 timeout=150.0 # 150 second timeout
) )
logger.info(f"Verify: LLM 调用完成result={structured}") logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError: except asyncio.TimeoutError:

View File

@@ -33,7 +33,19 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
@asynccontextmanager @asynccontextmanager
async def make_read_graph(): async def make_read_graph():
"""创建并返回 LangGraph 工作流""" """
Create and return a LangGraph workflow for memory reading operations
Builds a state graph workflow that handles memory retrieval, problem analysis,
verification, and summarization. The workflow includes nodes for content input,
problem splitting, retrieval, verification, and various summary operations.
Yields:
StateGraph: Compiled LangGraph workflow for memory reading
Raises:
Exception: If workflow creation fails
"""
try: try:
# Build workflow graph # Build workflow graph
workflow = StateGraph(ReadState) workflow = StateGraph(ReadState)
@@ -48,7 +60,7 @@ async def make_read_graph():
workflow.add_node("Summary", Summary) workflow.add_node("Summary", Summary)
workflow.add_node("Summary_fails", Summary_fails) workflow.add_node("Summary_fails", Summary_fails)
# 添加边 # Add edges to define workflow flow
workflow.add_edge(START, "content_input") workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue) workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END) workflow.add_edge("Input_Summary", END)
@@ -63,7 +75,7 @@ async def make_read_graph():
'''-----''' '''-----'''
# workflow.add_edge("Retrieve", END) # workflow.add_edge("Retrieve", END)
# 编译工作流 # Compile workflow
graph = workflow.compile() graph = workflow.compile()
yield graph yield graph
@@ -72,108 +84,3 @@ async def make_read_graph():
raise raise
finally: finally:
print("工作流创建完成") print("工作流创建完成")
async def main():
"""主函数 - 运行工作流"""
message = "昨天有什么好看的电影"
end_user_id = '88a459f5_text09' # 组ID
storage_type = 'neo4j' # 存储类型
search_switch = '1' # 搜索开关
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=17, # 改为整数
service_name="MemoryAgentService"
)
import time
start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
_intermediate_outputs = []
summary = ''
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
print(f"处理节点: {node_name}")
# 处理不同Summary节点的返回结构
if 'Summary' in node_name:
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
summary = node_data['InputSummary']['summary_result']
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
summary = node_data['RetrieveSummary']['summary_result']
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
summary = node_data['summary']['summary_result']
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
summary = node_data['SummaryFails']['summary_result']
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
if spit_data and spit_data != [] and spit_data != {}:
_intermediate_outputs.append(spit_data)
# Problem_Extension 节点
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
if problem_extension and problem_extension != [] and problem_extension != {}:
_intermediate_outputs.append(problem_extension)
# Retrieve 节点
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node)
# Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n)
# Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}:
_intermediate_outputs.append(summary_n)
# # 过滤掉空值
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
#
# # 优化搜索结果
# print("=== 开始优化搜索结果 ===")
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
# result=reorder_output_results(optimized_outputs)
# # 保存优化后的结果到文件
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
# import json
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
#
print(f"=== 最终摘要 ===")
print(summary)
except Exception as e:
import traceback
traceback.print_exc()
finally:
db_session.close()
end = time.time()
print(100 * 'y')
print(f"总耗时: {end - start}s")
print(100 * 'y')
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -1,13 +1,13 @@
from typing import Literal from typing import Literal
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
counter = COUNTState(limit=3) counter = COUNTState(limit=3)
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
""" """
Determine routing based on search_switch value. Determine routing based on search_switch value.
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
return 'Input_Summary' return 'Input_Summary'
return 'Split_The_Problem' # 默认情况 return 'Split_The_Problem' # 默认情况
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
""" """
Determine routing based on search_switch value. Determine routing based on search_switch value.
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
elif search_switch == '1': elif search_switch == '1':
return 'Retrieve_Summary' return 'Retrieve_Summary'
return 'Retrieve_Summary' # Default based on business logic return 'Retrieve_Summary' # Default based on business logic
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
status=state.get('verify', '')['status'] status = state.get('verify', '')['status']
# loop_count = counter.get_total() # loop_count = counter.get_total()
if "success" in status: if "success" in status:
# counter.reset() # counter.reset()
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
# if loop_count < 2: # Maximum loop count is 3 # if loop_count < 2: # Maximum loop count is 3
# return "content_input" # return "content_input"
# else: # else:
# counter.reset() # counter.reset()
return "Summary_fails" return "Summary_fails"
else: else:
# Add default return value to avoid returning None # Add default return value to avoid returning None

View File

@@ -2,77 +2,104 @@ import json
import os import os
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context, get_db from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
# RAG 模式:组合消息为字符串格式(保持原有逻辑) """
Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
ai_message: AI's response message content
user_rag_memory_id: RAG memory identifier for storage location
"""
# RAG mode: combine messages into string format (maintain original logic)
combined_message = f"user: {user_message}\nassistant: {ai_message}" combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id) await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
async def write(
storage_type,
end_user_id,
user_message,
ai_message,
user_rag_memory_id,
actual_end_user_id,
actual_config_id,
long_term_messages=None
):
""" """
写入记忆(支持结构化消息) Write memory with structured message support
Handles memory writing operations for different storage types (Neo4j/RAG).
Supports both individual message pairs and batch long-term message processing.
Args: Args:
storage_type: 存储类型 (neo4j/rag) storage_type: Storage type identifier ("neo4j" or "rag")
end_user_id: 终端用户ID end_user_id: Terminal user identifier
user_message: 用户消息内容 user_message: User message content
ai_message: AI 回复内容 ai_message: AI response content
user_rag_memory_id: RAG 记忆ID user_rag_memory_id: RAG memory identifier
actual_end_user_id: 实际用户ID actual_end_user_id: Actual user identifier for storage
actual_config_id: 配置ID actual_config_id: Configuration identifier
long_term_messages: Optional list of structured messages for batch processing
逻辑说明: Logic explanation:
- RAG 模式:组合 user_message ai_message 为字符串格式,保持原有逻辑不变 - RAG mode: Combines user_message and ai_message into string format, maintains original logic
- Neo4j 模式:使用结构化消息列表 - Neo4j mode: Uses structured message lists
1. 如果 user_message ai_message 都不为空:创建配对消息 [user, assistant] 1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) 2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段 3. Each message is converted to independent Chunk, preserving speaker field
""" """
db = next(get_db()) if long_term_messages is None:
try: long_term_messages = []
with get_db_context() as db:
actual_config_id = resolve_config_id(actual_config_id, db) actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j 模式:使用结构化消息列表 # Neo4j mode: Use structured message lists
structured_messages = [] structured_messages = []
# 始终添加用户消息(如果不为空) # Always add user message (if not empty)
if isinstance(user_message, str) and user_message.strip() != "": if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message}) structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息 # Only add assistant message when AI reply is not empty
if isinstance(ai_message, str) and ai_message.strip() != "": if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message}) structured_messages.append({"role": "assistant", "content": ai_message})
# 如果提供了 long_term_messages,使用它替代 structured_messages # If long_term_messages provided, use it to replace structured_messages
if long_term_messages and isinstance(long_term_messages, list): if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str): elif long_term_messages and isinstance(long_term_messages, str):
# 如果是 JSON 字符串,先解析 # If it's a JSON string, parse it first
try: try:
structured_messages = json.loads(long_term_messages) structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}") logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# 如果没有消息,直接返回 # If no messages, return directly
if not structured_messages: if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}") logger.warning(f"No messages to write for user {actual_end_user_id}")
return return
@@ -80,29 +107,41 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
logger.info( logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay( write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID actual_end_user_id, # end_user_id: User ID
structured_messages, # message: JSON 字符串格式的消息列表 structured_messages, # message: JSON string format message list
str(actual_config_id), # config_id: 配置ID字符串 str(actual_config_id), # config_id: Configuration ID string
storage_type, # storage_type: "neo4j" storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用 user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
) )
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id)) write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
"""
Save long-term memory data to database
Handles the storage of long-term memory data based on different strategies
(chunk-based or aggregate-based) and manages the transition from short-term
to long-term memory storage.
Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
end_user_id: User identifier for memory association
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
scope: Scope/window size for memory processing
"""
with get_db_context() as db_session: with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session) repo = LongTermMemoryRepository(db_session)
from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id) result = write_store.get_session_by_userid(end_user_id)
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict") data = await format_parsing(result, "dict")
chunk_data = data[:scope] chunk_data = data[:scope]
if len(chunk_data)==scope: if len(chunk_data) == scope:
repo.upsert(end_user_id, chunk_data) repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------') logger.info(f'---------写入短长期-----------')
else: else:
@@ -112,18 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
logger.info(f'写入短长期:') logger.info(f'写入短长期:')
"""Window-based dialogue processing"""
'''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
''' """
根据窗口获取redis数据,写入neo4j Process dialogue based on window size and write to Neo4j
Args:
end_user_id: 终端用户ID Manages conversation data based on a sliding window approach. When the window
memory_config: 内存配置对象 reaches the specified scope size, it triggers long-term memory storage to Neo4j.
langchain_messages原始数据LIST
scope窗口大小 Args:
''' end_user_id: Terminal user identifier
scope=scope memory_config: Memory configuration object containing settings
langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage
"""
scope = scope
is_end_user_id = count_store.get_sessions_count(end_user_id) is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False: if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0] is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
@@ -135,50 +179,72 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
elif int(is_end_user_id) == int(scope): elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J') logger.info('写入长期记忆NEO4J')
formatted_messages = (redis_messages) formatted_messages = (redis_messages)
# 获取 config_id(如果 memory_config 是对象,提取 config_id否则直接使用 # Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'): if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id config_id = memory_config.config_id
else: else:
config_id = memory_config config_id = memory_config
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, await write(
config_id, formatted_messages) AgentMemory_Long_Term.STORAGE_NEO4J,
end_user_id,
"",
"",
None,
end_user_id,
config_id,
formatted_messages
)
count_store.update_sessions_count(end_user_id, 1, langchain_messages) count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else: else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages) count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""根据时间""" """Time-based memory processing"""
async def memory_long_term_storage(end_user_id,memory_config,time):
'''
根据时间获取redis数据,写入neo4j async def memory_long_term_storage(end_user_id, memory_config, time):
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = (long_time_data)
messages=[]
memory_config=memory_config.config_id
for i in format_messages:
message=json.loads(i['Query'])
messages+= message
if format_messages!=[]:
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
'''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
""" """
聚合判断函数:判断输入句子和历史消息是否描述同一事件 Process memory storage based on time intervals and write to Neo4j
Retrieves Redis data based on time intervals and writes it to Neo4j for
long-term storage. This function handles time-based memory consolidation.
Args: Args:
end_user_id: 终端用户ID end_user_id: Terminal user identifier
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] memory_config: Memory configuration object containing settings
memory_config: 内存配置对象 time: Time interval for data retrieval
""" """
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = long_time_data
messages = []
memory_config = memory_config.config_id
for i in format_messages:
message = json.loads(i['Query'])
messages += message
if format_messages:
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
Aggregation judgment function: determine if input sentence and historical messages describe the same event
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
historical data or stored as separate events. This helps optimize memory storage and retrieval.
Args:
end_user_id: Terminal user identifier
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: Memory configuration object containing LLM settings
Returns:
dict: Aggregation judgment result containing is_same_event flag and processed output
"""
history = None
try: try:
# 1. 获取历史会话数据(使用新方法) # 1. Get historical session data (using new method)
result = write_store.get_all_sessions_by_end_user_id(end_user_id) result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result) history = await format_parsing(result)
if not result: if not result:

View File

@@ -2,41 +2,53 @@ import asyncio
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from langchain.tools import tool from langchain.tools import tool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.core.memory.src.search import ( from app.core.memory.src.search import (
search_by_temporal, search_by_temporal,
search_by_keyword_temporal, search_by_keyword_temporal,
) )
def extract_tool_message_content(response): def extract_tool_message_content(response):
"""从agent响应中提取ToolMessage内容和工具名称""" """
Extract ToolMessage content and tool names from agent response
Parses agent response messages to extract tool execution results and metadata.
Handles JSON parsing and provides structured access to tool output data.
Args:
response: Agent response dictionary containing messages
Returns:
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
- tool_name: Name of the executed tool
- content: Parsed tool execution result (JSON or raw text)
"""
messages = response.get('messages', []) messages = response.get('messages', [])
for message in messages: for message in messages:
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'): if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
# 这是一个ToolMessage # This is a ToolMessage
tool_content = message.content tool_content = message.content
tool_name = None tool_name = None
# 尝试获取工具名称 # Try to get tool name
if hasattr(message, 'name'): if hasattr(message, 'name'):
tool_name = message.name tool_name = message.name
elif hasattr(message, 'tool_name'): elif hasattr(message, 'tool_name'):
tool_name = message.tool_name tool_name = message.tool_name
try: try:
# 解析JSON内容 # Parse JSON content
parsed_content = json.loads(tool_content) parsed_content = json.loads(tool_content)
return { return {
'tool_name': tool_name, 'tool_name': tool_name,
'content': parsed_content 'content': parsed_content
} }
except json.JSONDecodeError: except json.JSONDecodeError:
# 如果不是JSON格式直接返回内容 # If not JSON format, return content directly
return { return {
'tool_name': tool_name, 'tool_name': tool_name,
'content': tool_content 'content': tool_content
@@ -46,26 +58,49 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel): class TimeRetrievalInput(BaseModel):
"""时间检索工具的输入模式""" """
Input schema for time retrieval tool
Defines the expected input parameters for time-based retrieval operations.
Used for validation and documentation of tool parameters.
Attributes:
context: User input query content for search
end_user_id: Group ID for filtering search results, defaults to test user
"""
context: str = Field(description="用户输入的查询内容") context: str = Field(description="用户输入的查询内容")
end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果") end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(end_user_id: str): def create_time_retrieval_tool(end_user_id: str):
""" """
创建一个带有特定end_user_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements) Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
Creates a specialized time-based retrieval tool that searches for statements within
specified time ranges. Includes field cleaning functionality to remove unnecessary
metadata from search results.
Args:
end_user_id: User identifier for scoping search results
Returns:
function: Configured TimeRetrievalWithGroupId tool function
""" """
def clean_temporal_result_fields(data): def clean_temporal_result_fields(data):
""" """
清理时间搜索结果中不需要的字段,并修改结构 Clean unnecessary fields from temporal search results and modify structure
Removes metadata fields that are not needed for end-user consumption and
restructures the response format for better usability.
Args: Args:
data: 要清理的数据 data: Data to be cleaned (dict, list, or other types)
Returns: Returns:
清理后的数据 Cleaned data with unnecessary fields removed
""" """
# 需要过滤的字段列表 # List of fields to filter out
fields_to_remove = { fields_to_remove = {
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at', 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
'valid_at', 'invalid_at', 'statement_ids' 'valid_at', 'invalid_at', 'statement_ids'
@@ -75,9 +110,9 @@ def create_time_retrieval_tool(end_user_id: str):
cleaned = {} cleaned = {}
for key, value in data.items(): for key, value in data.items():
if key == 'statements' and isinstance(value, dict) and 'statements' in value: if key == 'statements' and isinstance(value, dict) and 'statements' in value:
# statements: {"statements": [...]} 改为 time_search: {"statements": [...]} # Change statements: {"statements": [...]} to time_search: {"statements": [...]}
cleaned_value = clean_temporal_result_fields(value) cleaned_value = clean_temporal_result_fields(value)
# 进一步将内部的 statements 改为 time_search # Further change internal statements to time_search
if 'statements' in cleaned_value: if 'statements' in cleaned_value:
cleaned['results'] = { cleaned['results'] = {
'time_search': cleaned_value['statements'] 'time_search': cleaned_value['statements']
@@ -93,24 +128,33 @@ def create_time_retrieval_tool(end_user_id: str):
return data return data
@tool @tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str: def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
end_user_id_param: str = None, clean_output: bool = True) -> str:
""" """
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
显式接收参数:
- context: 查询上下文内容 Performs time-based search operations with automatic metadata filtering. Supports
- start_date: 开始时间可选格式YYYY-MM-DD flexible date range specification and provides clean, user-friendly output.
- end_date: 结束时间可选格式YYYY-MM-DD
- end_user_id_param: 组ID可选用于覆盖默认组ID Explicit parameters:
- clean_output: 是否清理输出中的元数据字段 - context: Query context content
-end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d") - start_date: Start time (optional, format: YYYY-MM-DD)
- end_date: End time (optional, format: YYYY-MM-DD)
- end_user_id_param: Group ID (optional, overrides default group ID)
- clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
Returns:
str: JSON formatted search results with temporal data
""" """
async def _async_search(): async def _async_search():
# 使用传入的参数或默认值 # Use passed parameters or default values
actual_end_user_id = end_user_id_param or end_user_id actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# 基本时间搜索 # Basic time search
results = await search_by_temporal( results = await search_by_temporal(
end_user_id=actual_end_user_id, end_user_id=actual_end_user_id,
start_date=actual_start_date, start_date=actual_start_date,
@@ -118,7 +162,7 @@ def create_time_retrieval_tool(end_user_id: str):
limit=10 limit=10
) )
# 清理结果中不需要的字段 # Clean unnecessary fields from results
if clean_output: if clean_output:
cleaned_results = clean_temporal_result_fields(results) cleaned_results = clean_temporal_result_fields(results)
else: else:
@@ -129,22 +173,32 @@ def create_time_retrieval_tool(end_user_id: str):
return asyncio.run(_async_search()) return asyncio.run(_async_search())
@tool @tool
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str: def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
clean_output: bool = True) -> str:
""" """
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段 Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
显式接收参数:
- context: 查询内容 Performs combined keyword and temporal search operations with automatic metadata
- days_back: 向前搜索的天数默认7天 filtering. Provides more targeted search results by combining content relevance
- start_date: 开始时间可选格式YYYY-MM-DD with time-based filtering.
- end_date: 结束时间可选格式YYYY-MM-DD
- clean_output: 是否清理输出中的元数据字段 Explicit parameters:
- end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d") - context: Query content for keyword matching
- days_back: Number of days to search backwards, default 7 days
- start_date: Start time (optional, format: YYYY-MM-DD)
- end_date: End time (optional, format: YYYY-MM-DD)
- clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
Returns:
str: JSON formatted search results combining keyword and temporal data
""" """
async def _async_search(): async def _async_search():
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
# 关键词时间搜索 # Keyword time search
results = await search_by_keyword_temporal( results = await search_by_keyword_temporal(
query_text=context, query_text=context,
end_user_id=end_user_id, end_user_id=end_user_id,
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
limit=15 limit=15
) )
# 清理结果中不需要的字段 # Clean unnecessary fields from results
if clean_output: if clean_output:
cleaned_results = clean_temporal_result_fields(results) cleaned_results = clean_temporal_result_fields(results)
else: else:
@@ -168,43 +222,53 @@ def create_time_retrieval_tool(end_user_id: str):
def create_hybrid_retrieval_tool_async(memory_config, **search_params): def create_hybrid_retrieval_tool_async(memory_config, **search_params):
""" """
创建混合检索工具使用run_hybrid_search进行混合检索优化输出格式并过滤不需要的字段 Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
Creates an advanced hybrid search tool that combines multiple search strategies
(keyword, vector, hybrid) with automatic result cleaning and formatting.
Args: Args:
memory_config: 内存配置对象 memory_config: Memory configuration object containing LLM and search settings
**search_params: 搜索参数,包含end_user_id, limit, include **search_params: Search parameters including end_user_id, limit, include, etc.
Returns:
function: Configured HybridSearch tool function with async capabilities
""" """
def clean_result_fields(data): def clean_result_fields(data):
""" """
递归清理结果中不需要的字段 Recursively clean unnecessary fields from results
Removes metadata fields that are not needed for end-user consumption,
improving readability and reducing response size.
Args: Args:
data: 要清理的数据(可能是字典、列表或其他类型) data: Data to be cleaned (can be dict, list, or other types)
Returns: Returns:
清理后的数据 Cleaned data with unnecessary fields removed
""" """
# 需要过滤的字段列表 # List of fields to filter out
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # TODO: fact_summary functionality temporarily disabled, will be enabled after future development
fields_to_remove = { fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', 'expired_at', 'created_at', 'chunk_id', 'apply_id',
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary" 'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
} }
# 注意:'id' 字段保留community 展开时需要用 community id 查询成员 statements
if isinstance(data, dict): if isinstance(data, dict):
# 对字典进行清理 # Clean dictionary
cleaned = {} cleaned = {}
for key, value in data.items(): for key, value in data.items():
if key not in fields_to_remove: if key not in fields_to_remove:
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据 cleaned[key] = clean_result_fields(value) # Recursively clean nested data
return cleaned return cleaned
elif isinstance(data, list): elif isinstance(data, list):
# 对列表中的每个元素进行清理 # Clean each element in list
return [clean_result_fields(item) for item in data] return [clean_result_fields(item) for item in data]
else: else:
# 其他类型直接返回 # Return other types directly
return data return data
@tool @tool
@@ -216,49 +280,55 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
rerank_alpha: float = 0.6, rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False, use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False, use_llm_rerank: bool = False,
clean_output: bool = True # 新增:是否清理输出字段 clean_output: bool = True # New: whether to clean output fields
) -> str: ) -> str:
""" """
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段 Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
Provides comprehensive search capabilities combining multiple search strategies
with intelligent result ranking and automatic metadata filtering for clean output.
Args: Args:
context: 查询内容 context: Query content for search
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') search_type: Search type ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制 limit: Result quantity limit
end_user_id: 组ID用于过滤搜索结果 end_user_id: Group ID for filtering search results
rerank_alpha: 重排序权重参数 rerank_alpha: Reranking weight parameter for result scoring
use_forgetting_rerank: 是否使用遗忘重排序 use_forgetting_rerank: Whether to use forgetting-based reranking
use_llm_rerank: 是否使用LLM重排序 use_llm_rerank: Whether to use LLM-based reranking
clean_output: 是否清理输出中的元数据字段 clean_output: Whether to clean metadata fields from output
Returns:
str: JSON formatted comprehensive search results
""" """
try: try:
# 导入run_hybrid_search函数 # Import run_hybrid_search function
from app.core.memory.src.search import run_hybrid_search from app.core.memory.src.search import run_hybrid_search
# 合并参数,优先使用传入的参数 # Merge parameters, prioritize passed parameters
final_params = { final_params = {
"query_text": context, "query_text": context,
"search_type": search_type, "search_type": search_type,
"end_user_id": end_user_id or search_params.get("end_user_id"), "end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10), "limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), "include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
"output_path": None, # 不保存到文件 "output_path": None, # Don't save to file
"memory_config": memory_config, "memory_config": memory_config,
"rerank_alpha": rerank_alpha, "rerank_alpha": rerank_alpha,
"use_forgetting_rerank": use_forgetting_rerank, "use_forgetting_rerank": use_forgetting_rerank,
"use_llm_rerank": use_llm_rerank "use_llm_rerank": use_llm_rerank
} }
# 执行混合检索 # Execute hybrid retrieval
raw_results = await run_hybrid_search(**final_params) raw_results = await run_hybrid_search(**final_params)
# 清理结果中不需要的字段 # Clean unnecessary fields from results
if clean_output: if clean_output:
cleaned_results = clean_result_fields(raw_results) cleaned_results = clean_result_fields(raw_results)
else: else:
cleaned_results = raw_results cleaned_results = raw_results
# 格式化返回结果 # Format return results
formatted_results = { formatted_results = {
"search_query": context, "search_query": context,
"search_type": search_type, "search_type": search_type,
@@ -281,32 +351,46 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
def create_hybrid_retrieval_tool_sync(memory_config, **search_params): def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
""" """
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段 Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
Creates a synchronous wrapper around the async hybrid search functionality,
making it compatible with synchronous tool execution environments.
Args: Args:
memory_config: 内存配置对象 memory_config: Memory configuration object containing search settings
**search_params: 搜索参数 **search_params: Search parameters for configuration
Returns:
function: Configured HybridSearchSync tool function
""" """
@tool @tool
def HybridSearchSync( def HybridSearchSync(
context: str, context: str,
search_type: str = "hybrid", search_type: str = "hybrid",
limit: int = 10, limit: int = 10,
end_user_id: str = None, end_user_id: str = None,
clean_output: bool = True clean_output: bool = True
) -> str: ) -> str:
""" """
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段 Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
Provides the same hybrid search capabilities as the async version but in a
synchronous execution context. Automatically handles async-to-sync conversion.
Args: Args:
context: 查询内容 context: Query content for search
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') search_type: Search type ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制 limit: Result quantity limit
end_user_id: 组ID用于过滤搜索结果 end_user_id: Group ID for filtering search results
clean_output: 是否清理输出中的元数据字段 clean_output: Whether to clean metadata fields from output
Returns:
str: JSON formatted search results
""" """
async def _async_search(): async def _async_search():
# 创建异步工具并执行 # Create async tool and execute
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params) async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
return await async_tool.ainvoke({ return await async_tool.ainvoke({
"context": context, "context": context,

View File

@@ -1,20 +1,28 @@
import json import json
from langchain_core.messages import HumanMessage, AIMessage from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'):
async def format_parsing(messages: list, type: str = 'string'):
""" """
格式化解析消息列表 Format and parse message lists into different output types
Processes message lists from storage and converts them into either string format
or dictionary format based on the specified type parameter. Handles JSON parsing
and role-based message organization.
Args: Args:
messages: 消息列表 messages: List of message objects from storage containing message data
type: 返回类型 ('string''dict') type: Return type specification ('string' for text format, 'dict' for key-value pairs)
Returns: Returns:
格式化后的消息列表 list: Formatted message list in the specified format
- 'string': List of formatted text messages with role prefixes
- 'dict': List of dictionaries mapping user messages to AI responses
""" """
result = [] result = []
user=[] user = []
ai=[] ai = []
for message in messages: for message in messages:
hstory_messages = message['messages'] hstory_messages = message['messages']
@@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
role = content['role'] role = content['role']
content = content['content'] content = content['content']
if type == "string": if type == "string":
if role == 'human' or role=="user": if role == 'human' or role == "user":
content = '用户:' + content content = '用户:' + content
else: else:
content = 'AI:' + content content = 'AI:' + content
result.append(content) result.append(content)
if type == "dict" : if type == "dict":
if role == 'human' or role=="user": if role == 'human' or role == "user":
user.append( content) user.append(content)
else: else:
ai.append(content) ai.append(content)
if type == "dict": if type == "dict":
for key,values in zip(user,ai): for key, values in zip(user, ai):
result.append({key:values}) result.append({key: values})
return result return result
async def messages_parse(messages: list | dict): async def messages_parse(messages: list | dict):
user=[] """
ai=[] Parse messages from storage format into user-AI conversation pairs
database=[]
Extracts and organizes conversation data from stored message format,
separating user and AI messages and pairing them for database storage.
Args:
messages: List or dictionary containing stored message data with Query fields
Returns:
list: List of dictionaries containing user-AI message pairs for database storage
"""
user = []
ai = []
database = []
for message in messages: for message in messages:
Query = message['Query'] Query = message['Query']
Query = json.loads(Query) Query = json.loads(Query)
@@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict):
ai.append(data['content']) ai.append(data['content'])
for key, values in zip(user, ai): for key, values in zip(user, ai):
database.append({key, values}) database.append({key, values})
return database return database
async def agent_chat_messages(user_content,ai_content): async def agent_chat_messages(user_content, ai_content):
"""
Create structured chat message format for agent conversations
Formats user and AI content into a standardized message structure suitable
for agent processing and storage. Creates role-based message objects.
Args:
user_content: User's message content string
ai_content: AI's response content string
Returns:
list: List of structured message dictionaries with role and content fields
"""
messages = [ messages = [
{ {
"role": "user", "role": "user",

View File

@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -42,10 +41,26 @@ async def make_write_graph():
yield graph yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
"""
Handle long-term memory storage with different strategies
Supports multiple storage strategies including chunk-based, time-based,
and aggregate judgment approaches for long-term memory persistence.
Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store
memory_config: Memory configuration identifier
end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6)
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, (langchain_messages)) write_store.save_session_write(end_user_id, langchain_messages)
# 获取数据库会话 # 获取数据库会话
with get_db_context() as db_session: with get_db_context() as db_session:
config_service = MemoryConfigService(db_session) config_service = MemoryConfigService(db_session)
@@ -53,26 +68,39 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
config_id=memory_config, # 改为整数 config_id=memory_config, # 改为整数
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
if long_term_type=='chunk': if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''方案一:对话窗口6轮对话''' '''Strategy 1: Dialogue window with 6 rounds of conversation'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope) await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type=='time': if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""时间""" """Time-based strategy"""
await memory_long_term_storage(end_user_id, memory_config,5) await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type=='aggregate': if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""方案三:聚合判断""" """Strategy 3: Aggregate judgment"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config) await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
"""
Write long-term memory with different storage types
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): Handles both RAG-based storage and traditional memory storage approaches.
For traditional storage, uses chunk-based strategy with paired user-AI messages.
Args:
storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier
message_chat: User message content
aimessages: AI response messages
user_rag_memory_id: RAG memory identifier
actual_config_id: Actual configuration ID
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG: if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else: else:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) # AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages) long_term_messages = await agent_chat_messages(message_chat, aimessages)

View File

@@ -13,6 +13,72 @@ from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
# 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化)
_EXPAND_FIELDS_TO_REMOVE = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
}
def _clean_expand_fields(obj):
"""递归过滤展开结果中不可序列化的字段DateTime 等)。"""
if isinstance(obj, dict):
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
if isinstance(obj, list):
return [_clean_expand_fields(i) for i in obj]
return obj
async def expand_communities_to_statements(
    community_results: List[dict],
    end_user_id: str,
    existing_content: str = "",
    limit: int = 10,
) -> Tuple[List[dict], List[str]]:
    """Expand matched communities into their associated statements.

    Args:
        community_results: Community hits; only entries with a truthy ``id``
            are expanded.
        end_user_id: User identifier passed to the graph query. An empty
            value short-circuits to an empty result.
        existing_content: Already collected text; any statement equal to one
            of its lines is not reported as new.
        limit: Limit forwarded to ``search_graph_community_expand``.

    Returns:
        ``(cleaned_expanded_stmts, new_texts)`` where
        - ``cleaned_expanded_stmts`` is the expanded statement list with the
          non-serializable fields filtered out, ready to be written back into
          raw results;
        - ``new_texts`` is the ordered list of statement texts that are not
          in ``existing_content``, with each text reported at most once.

        ``([], [])`` is returned when there is nothing to expand, when the
        graph query fails, or when it yields no statements.
    """
    community_ids = [r.get("id") for r in community_results if r.get("id")]
    if not community_ids or not end_user_id:
        return [], []

    from app.repositories.neo4j.graph_search import search_graph_community_expand
    from app.repositories.neo4j.neo4j_connector import Neo4jConnector

    connector = Neo4jConnector()
    try:
        result = await search_graph_community_expand(
            connector=connector,
            community_ids=community_ids,
            end_user_id=end_user_id,
            limit=limit,
        )
    except Exception as e:
        # Best effort: a failed expansion must not break the search itself.
        logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
        return [], []
    finally:
        await connector.close()

    expanded_stmts = result.get("expanded_statements", [])
    if not expanded_stmts:
        return [], []

    # Seed the seen-set with the existing lines, then keep adding to it so a
    # statement returned more than once is only reported once.
    seen_texts = set(existing_content.splitlines())
    new_texts = []
    for stmt in expanded_stmts:
        text = stmt.get("statement")
        if text and text not in seen_texts:
            seen_texts.add(text)
            new_texts.append(text)

    cleaned = _clean_expand_fields(expanded_stmts)
    logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
    return cleaned, new_texts
class SearchService: class SearchService:
"""Service for executing hybrid search and processing results.""" """Service for executing hybrid search and processing results."""
@@ -21,7 +87,7 @@ class SearchService:
"""Initialize the search service.""" """Initialize the search service."""
logger.info("SearchService initialized") logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict) -> str: def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
""" """
Extract only meaningful content from search results, dropping all metadata. Extract only meaningful content from search results, dropping all metadata.
@@ -30,9 +96,11 @@ class SearchService:
- Entities: extract 'name' and 'fact_summary' fields - Entities: extract 'name' and 'fact_summary' fields
- Summaries: extract 'content' field - Summaries: extract 'content' field
- Chunks: extract 'content' field - Chunks: extract 'content' field
- Communities: extract 'content' field (c.summary), prefixed with community name
Args: Args:
result: Search result dictionary result: Search result dictionary
node_type: Hint for node type ("community", "summary", etc.)
Returns: Returns:
Clean content string without metadata Clean content string without metadata
@@ -46,8 +114,21 @@ class SearchService:
if 'statement' in result and result['statement']: if 'statement' in result and result['statement']:
content_parts.append(result['statement']) content_parts.append(result['statement'])
# Summaries/Chunks: extract content field # Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
if 'content' in result and result['content']: # 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
is_community = (
node_type == "community"
or 'member_count' in result
or 'core_entities' in result
)
if is_community:
name = result.get('name', '')
content = result.get('content', '')
if content:
prefix = f"[主题:{name}] " if name else ""
content_parts.append(f"{prefix}{content}")
elif 'content' in result and result['content']:
# Summaries / Chunks
content_parts.append(result['content']) content_parts.append(result['content'])
# Entities: extract name and fact_summary (commented out in original) # Entities: extract name and fact_summary (commented out in original)
@@ -99,7 +180,8 @@ class SearchService:
rerank_alpha: float = 0.4, rerank_alpha: float = 0.4,
output_path: str = "search_results.json", output_path: str = "search_results.json",
return_raw_results: bool = False, return_raw_results: bool = False,
memory_config = None memory_config = None,
expand_communities: bool = True,
) -> Tuple[str, str, Optional[dict]]: ) -> Tuple[str, str, Optional[dict]]:
""" """
Execute hybrid search and return clean content. Execute hybrid search and return clean content.
@@ -114,13 +196,15 @@ class SearchService:
output_path: Path to save search results (default: "search_results.json") output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False) return_raw_results: If True, also return the raw search results as third element (default: False)
memory_config: Memory configuration object (required) memory_config: Memory configuration object (required)
expand_communities: If True, expand community hits to member statements (default: True).
Set to False for quick-summary paths that only need community-level text.
Returns: Returns:
Tuple of (clean_content, cleaned_query, raw_results) Tuple of (clean_content, cleaned_query, raw_results)
raw_results is None if return_raw_results=False raw_results is None if return_raw_results=False
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries"] include = ["statements", "chunks", "entities", "summaries", "communities"]
# Clean query # Clean query
cleaned_query = self.clean_query(question) cleaned_query = self.clean_query(question)
@@ -146,8 +230,8 @@ class SearchService:
if search_type == "hybrid": if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {}) reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then statements, chunks, entities # Priority order: summaries first (most contextual), then communities, statements, chunks, entities
priority_order = ['summaries', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in reranked_results: if category in include and category in reranked_results:
@@ -157,7 +241,7 @@ class SearchService:
else: else:
# For keyword or embedding search, results are directly in answer dict # For keyword or embedding search, results are directly in answer dict
# Apply same priority order # Apply same priority order
priority_order = ['summaries', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in answer: if category in include and category in answer:
@@ -165,11 +249,25 @@ class SearchService:
if isinstance(category_results, list): if isinstance(category_results, list):
answer_list.extend(category_results) answer_list.extend(category_results)
# Extract clean content from all results # 对命中的 community 节点展开其成员 statements路径 "0"/"1" 需要,路径 "2" 不需要)
content_list = [ if expand_communities and "communities" in include:
self.extract_content_from_result(ans) community_results = (
for ans in answer_list answer.get('reranked_results', {}).get('communities', [])
] if search_type == "hybrid"
else answer.get('communities', [])
)
cleaned_stmts, new_texts = await expand_communities_to_statements(
community_results=community_results,
end_user_id=end_user_id,
)
answer_list.extend(cleaned_stmts)
# Extract clean content from all results按类型传入 node_type 区分 community
content_list = []
for ans in answer_list:
# community 节点有 member_count 或 core_entities 字段
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines # Filter out empty strings and join with newlines

View File

@@ -84,7 +84,7 @@ async def get_chunked_dialogs(
pruning_scene=memory_config.pruning_scene or "education", pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=memory_config.pruning_threshold, pruning_threshold=memory_config.pruning_threshold,
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None, scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
ontology_classes=memory_config.ontology_classes, ontology_class_infos=memory_config.ontology_class_infos,
) )
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}") logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")

View File

@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3]) PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
class WriteState(TypedDict): class WriteState(TypedDict):
''' """
Langgrapg Writing TypedDict Langgrapg Writing TypedDict
''' """
messages: Annotated[list[AnyMessage], add_messages] messages: Annotated[list[AnyMessage], add_messages]
end_user_id: str end_user_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
data: str data: str
language: str # 语言类型 ("zh" 中文, "en" 英文) language: str # 语言类型 ("zh" 中文, "en" 英文)
class ReadState(TypedDict): class ReadState(TypedDict):
""" """
LangGraph 工作流状态定义 LangGraph 工作流状态定义
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
config_id: str config_id: str
data: str # 新增字段用于传递内容 data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果 spit_data: dict # 新增字段用于传递问题分解结果
problem_extension:dict problem_extension: dict
storage_type: str storage_type: str
user_rag_memory_id: str user_rag_memory_id: str
llm_id: str llm_id: str
embedding_id: str embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象 memory_config: object # 新增字段用于传递内存配置对象
retrieve:dict retrieve: dict
RetrieveSummary: dict RetrieveSummary: dict
InputSummary: dict InputSummary: dict
verify: dict verify: dict
SummaryFails: dict SummaryFails: dict
summary: dict summary: dict
class COUNTState: class COUNTState:
""" """
工作流对话检索内容计数器 工作流对话检索内容计数器
@@ -99,6 +103,7 @@ class COUNTState:
self.total = 0 self.total = 0
print("[COUNTState] 已重置为 0") print("[COUNTState] 已重置为 0")
def deduplicate_entries(entries): def deduplicate_entries(entries):
seen = set() seen = set()
deduped = [] deduped = []
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
deduped.append(entry) deduped.append(entry)
return deduped return deduped
def merge_to_key_value_pairs(data, query_key, result_key): def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list) grouped = defaultdict(list)
for item in data: for item in data:

View File

@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
@@ -165,10 +165,19 @@ async def write(
statement_chunk_edges=all_statement_chunk_edges, statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges, statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges, entity_edges=all_entity_entity_edges,
connector=neo4j_connector connector=neo4j_connector,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
) )
if success: if success:
logger.info("Successfully saved all data to Neo4j") logger.info("Successfully saved all data to Neo4j")
# 写入成功后,异步触发聚类(不阻塞写入响应)
schedule_clustering_after_write(
all_entity_nodes,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)
break break
else: else:
logger.warning("Failed to save some data to Neo4j") logger.warning("Failed to save some data to Neo4j")

View File

@@ -6,6 +6,7 @@ of the memory system including LLM, chunking, pruning, and search.
Classes: Classes:
LLMConfig: Configuration for LLM client LLMConfig: Configuration for LLM client
ChunkerConfig: Configuration for dialogue chunking ChunkerConfig: Configuration for dialogue chunking
OntologyClassInfo: Single ontology class with name and description
PruningConfig: Configuration for semantic pruning PruningConfig: Configuration for semantic pruning
TemporalSearchParams: Parameters for temporal search queries TemporalSearchParams: Parameters for temporal search queries
""" """
@@ -50,30 +51,41 @@ class ChunkerConfig(BaseModel):
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.") min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
class OntologyClassInfo(BaseModel):
    """Name and semantic description of an ontology class, used for pruning-prompt injection.

    Attributes:
        class_name: Ontology class name (e.g. "patient", "course")
        class_description: Semantic description of the ontology class, telling
            the LLM what the class means in the current scene
    """
    class_name: str = Field(..., description="本体类型名称")
    class_description: str = Field(default="", description="本体类型语义描述")
class PruningConfig(BaseModel): class PruningConfig(BaseModel):
"""Configuration for semantic pruning of dialogue content. """Configuration for semantic pruning of dialogue content.
Attributes: Attributes:
pruning_switch: Enable or disable semantic pruning pruning_switch: Enable or disable semantic pruning
pruning_scene: Scene name for pruning, either a built-in key pruning_scene: Scene name for pruning from ontology_scene table
('education', 'online_service', 'outbound') or a custom scene_name
from ontology_scene table
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal) pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
scene_id: Optional ontology scene UUID, used to load custom ontology classes scene_id: Optional ontology scene UUID
ontology_classes: List of class_name strings from ontology_class table, ontology_class_infos: Full ontology class info (name + description) from
injected into the prompt when pruning_scene is not a built-in scene ontology_class table, injected into the pruning prompt to drive
scene-aware preservation decisions
""" """
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.") pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
pruning_scene: str = Field( pruning_scene: str = Field(
"education", "education",
description="Scene for pruning: built-in key or custom scene_name from ontology_scene.", description="Scene name from ontology_scene table.",
) )
pruning_threshold: float = Field( pruning_threshold: float = Field(
0.5, ge=0.0, le=0.9, 0.5, ge=0.0, le=0.9,
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).") description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).") scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
ontology_classes: Optional[List[str]] = Field( ontology_class_infos: List[OntologyClassInfo] = Field(
None, description="Class names from ontology_class table for custom scenes." default_factory=list,
description="Full ontology class info (name + description) injected into pruning prompt."
) )

View File

@@ -238,7 +238,7 @@ def rerank_with_activation(
reranked: Dict[str, List[Dict[str, Any]]] = {} reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries"]: for category in ["statements", "chunks", "entities", "summaries", "communities"]:
keyword_items = keyword_results.get(category, []) keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, []) embedding_items = embedding_results.get(category, [])
@@ -281,21 +281,23 @@ def rerank_with_activation(
for item in items_list: for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items: if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0) combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
# 步骤 4: 计算基础分数和最终分数 # 步骤 4: 计算基础分数和最终分数
for item_id, item in combined_items.items(): for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0) bm25_norm = float(item.get("bm25_score", 0) or 0)
emb_norm = float(item.get("embedding_score", 0) or 0) emb_norm = float(item.get("embedding_score", 0) or 0)
act_norm = float(item.get("normalized_activation_value", 0) or 0) # normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
raw_act_norm = item.get("normalized_activation_value")
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
# 第一阶段只考虑内容相关性BM25 + Embedding # 第一阶段只考虑内容相关性BM25 + Embedding
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score # 第一阶段用内容分数 base_score = content_score # 第一阶段用内容分数
# 存储激活度分数供第二阶段使用 # 存储激活度分数供第二阶段使用None 表示无激活值,不参与激活值排序)
item["activation_score"] = act_norm item["activation_score"] = act_norm # 可能为 None
item["content_score"] = content_score item["content_score"] = content_score
item["base_score"] = base_score item["base_score"] = base_score
@@ -724,6 +726,8 @@ async def run_hybrid_search(
try: try:
keyword_task = None keyword_task = None
embedding_task = None embedding_task = None
keyword_results: Dict[str, List] = {}
embedding_results: Dict[str, List] = {}
if search_type in ["keyword", "hybrid"]: if search_type in ["keyword", "hybrid"]:
# Keyword-based search # Keyword-based search
@@ -746,35 +750,42 @@ async def run_hybrid_search(
# 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig # 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig
config_load_start = time.time() config_load_start = time.time()
with get_db_context() as db: try:
config_service = MemoryConfigService(db) with get_db_context() as db:
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id)) config_service = MemoryConfigService(db)
rb_config = RedBearModelConfig( embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
model_name=embedder_config_dict["model_name"], rb_config = RedBearModelConfig(
provider=embedder_config_dict["provider"], model_name=embedder_config_dict["model_name"],
api_key=embedder_config_dict["api_key"], provider=embedder_config_dict["provider"],
base_url=embedder_config_dict["base_url"], api_key=embedder_config_dict["api_key"],
type="llm" base_url=embedder_config_dict["base_url"],
) type="llm"
config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
# Init embedder
embedder_init_start = time.time()
embedder = OpenAIEmbedderClient(model_config=rb_config)
embedder_init_time = time.time() - embedder_init_start
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
embedding_task = asyncio.create_task(
search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=query_text,
end_user_id=end_user_id,
limit=limit,
include=include,
) )
) config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
# Init embedder
embedder_init_start = time.time()
embedder = OpenAIEmbedderClient(model_config=rb_config)
embedder_init_time = time.time() - embedder_init_start
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
embedding_task = asyncio.create_task(
search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=query_text,
end_user_id=end_user_id,
limit=limit,
include=include,
)
)
except Exception as emb_init_err:
logger.warning(
f"[PERF] Embedding search skipped due to init error "
f"(embedding_model_id={memory_config.embedding_model_id}): {emb_init_err}"
)
embedding_task = None
if keyword_task: if keyword_task:
keyword_results = await keyword_task keyword_results = await keyword_task

View File

@@ -0,0 +1,3 @@
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
__all__ = ["LabelPropagationEngine"]

View File

@@ -0,0 +1,559 @@
"""标签传播聚类引擎
基于 ZEP 论文的动态标签传播算法,对 Neo4j 中的 ExtractedEntity 节点进行社区聚类。
支持两种模式:
- 全量初始化full_clustering首次运行对所有实体做完整 LPA 迭代
- 增量更新incremental_update新实体到达时只处理新实体及其邻居
"""
import asyncio
import logging
import uuid
from math import sqrt
from typing import Dict, List, Optional
from app.repositories.neo4j.community_repository import CommunityRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
# Maximum number of rounds for a full iteration, to guard against non-convergence
MAX_ITERATIONS = 10
# Number of top-N entities taken as a community's core entities
CORE_ENTITY_LIMIT = 10
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
"""计算两个向量的余弦相似度,任一为空则返回 0。"""
if not v1 or not v2 or len(v1) != len(v2):
return 0.0
dot = sum(a * b for a, b in zip(v1, v2))
norm1 = sqrt(sum(a * a for a in v1))
norm2 = sqrt(sum(b * b for b in v2))
if norm1 == 0 or norm2 == 0:
return 0.0
return dot / (norm1 * norm2)
def _weighted_vote(
neighbors: List[Dict],
self_embedding: Optional[List[float]],
) -> Optional[str]:
"""
加权多数投票,选出得票最高的社区。
权重 = 语义相似度name_embedding 余弦)* activation_value 加成
没有 community_id 的邻居不参与投票。
"""
votes: Dict[str, float] = {}
for nb in neighbors:
cid = nb.get("community_id")
if not cid:
continue
sem = _cosine_similarity(self_embedding, nb.get("name_embedding"))
act = nb.get("activation_value") or 0.5
# 语义相似度权重 0.6,激活值权重 0.4
weight = 0.6 * sem + 0.4 * act
votes[cid] = votes.get(cid, 0.0) + weight
if not votes:
return None
return max(votes, key=votes.__getitem__)
class LabelPropagationEngine:
"""标签传播聚类引擎"""
def __init__(
    self,
    connector: Neo4jConnector,
    config_id: Optional[str] = None,
    llm_model_id: Optional[str] = None,
    embedding_model_id: Optional[str] = None,
):
    """Bind the engine to a Neo4j connector and remember the configured ids.

    Args:
        connector: Connector used for direct queries and handed to the
            community repository.
        config_id: Optional configuration id; stored as given.
        llm_model_id: Optional id passed to the client factory when an LLM
            client is requested for community naming.
        embedding_model_id: Optional id passed to the client factory when an
            embedder client is requested for summary embeddings.
    """
    self.config_id, self.llm_model_id, self.embedding_model_id = (
        config_id,
        llm_model_id,
        embedding_model_id,
    )
    self.connector = connector
    # All community reads and writes go through this repository.
    self.repo = CommunityRepository(connector)
# ──────────────────────────────────────────────────────────────────────────
# 公开接口
# ──────────────────────────────────────────────────────────────────────────
async def run(
    self,
    end_user_id: str,
    new_entity_ids: Optional[List[str]] = None,
) -> None:
    """Single entry point that chooses between a full and an incremental run.

    - The user has no community yet: run the full initialization.
    - Otherwise: run an incremental update for ``new_entity_ids``; nothing is
      done when that list is missing or empty.
    """
    if not await self.repo.has_communities(end_user_id):
        logger.info(f"[Clustering] 用户 {end_user_id} 首次聚类,执行全量初始化")
        await self.full_clustering(end_user_id)
        return

    if not new_entity_ids:
        return

    logger.info(
        f"[Clustering] 增量更新,新实体数: {len(new_entity_ids)}"
    )
    await self.incremental_update(new_entity_ids, end_user_id)
async def full_clustering(self, end_user_id: str) -> None:
    """Run the initial full label-propagation pass for one user, batch by batch.

    Strategy:
    - Only ``BATCH_SIZE`` entities and their neighbors are loaded at a time.
    - ``labels`` (entity id -> community id) is shared by all batches; every
      entity starts in a community named after its own id.
    - Each batch runs at most ``MAX_ITERATIONS`` rounds and stops early once a
      round changes no label; batches see each other's results via ``labels``.
    - After the last batch the labels are written to Neo4j, similar
      communities are merged, and metadata is generated for every community
      that still has members.

    Args:
        end_user_id: User whose entities are clustered.
    """
    BATCH_SIZE = 888  # entities per batch; tune as needed

    # Lightweight queries: only the count and the id list, no embeddings.
    total_count = await self.repo.get_entity_count(end_user_id)
    if not total_count:
        logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
        return
    all_entity_ids = await self.repo.get_all_entity_ids(end_user_id)
    logger.info(f"[Clustering] 用户 {end_user_id}{total_count} 个实体,"
                f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE}")

    labels: Dict[str, str] = {eid: eid for eid in all_entity_ids}
    del all_entity_ids  # the full records are loaded per batch below

    for batch_start in range(0, total_count, BATCH_SIZE):
        batch_no = batch_start // BATCH_SIZE + 1
        batch_entities = await self.repo.get_entities_page(
            end_user_id, skip=batch_start, limit=BATCH_SIZE
        )
        if not batch_entities:
            break
        batch_ids = [e["id"] for e in batch_entities]
        batch_embeddings: Dict[str, Optional[List[float]]] = {
            e["id"]: e.get("name_embedding") for e in batch_entities
        }
        logger.info(
            f"[Clustering] 批次 {batch_no}"
            f"加载 {len(batch_entities)} 个实体的邻居图..."
        )
        neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
            batch_ids, end_user_id
        )
        logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")

        for iteration in range(MAX_ITERATIONS):
            changed = 0
            for entity in batch_entities:
                eid = entity["id"]
                # The count, the id list and the pages come from separate
                # queries, so a page may contain an entity created after the
                # id list was read. Seed its label instead of raising KeyError.
                current_label = labels.setdefault(eid, eid)
                # Inject the latest labels: a neighbor may belong to another
                # batch, in which case ``labels`` holds its newest community.
                enriched = [
                    {**nb, "community_id": labels.get(nb["id"], nb.get("community_id"))}
                    for nb in neighbors_cache.get(eid, [])
                ]
                new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
                if new_label and new_label != current_label:
                    labels[eid] = new_label
                    changed += 1
            logger.info(
                f"[Clustering] 批次 {batch_no} "
                f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
            )
            if changed == 0:
                logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
                break

        # Drop this batch's large objects before the next batch is loaded.
        del neighbors_cache, batch_embeddings, batch_entities

    # All batches are done: persist the labels in one go.
    await self._flush_labels(labels, end_user_id)
    all_community_ids = list(set(labels.values()))
    pre_merge_count = len(all_community_ids)
    logger.info(
        f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
        f"{len(labels)} 个实体,开始后处理合并"
    )
    await self._evaluate_merge(all_community_ids, end_user_id)
    logger.info(
        f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
        f"{len(labels)} 个实体"
    )

    # Collect the communities that still have members and generate metadata.
    # NOTE(review): this reads every entity record only to collect community
    # ids; a dedicated id-only query would presumably be cheaper - confirm
    # what the repository offers.
    surviving_communities = await self.repo.get_all_entities(end_user_id)
    surviving_community_ids = list({
        e.get("community_id") for e in surviving_communities
        if e.get("community_id")
    })
    logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
    await self._generate_community_metadata(surviving_community_ids, end_user_id)
async def incremental_update(
    self, new_entity_ids: List[str], end_user_id: str
) -> None:
    """Place newly arrived entities without re-running the whole graph.

    The entities are handled one after another, in the order given; each one
    is delegated to ``_process_single_entity``.

    Args:
        new_entity_ids: Ids of the entities to place.
        end_user_id: User the entities belong to.
    """
    place_entity = self._process_single_entity
    for new_id in new_entity_ids:
        await place_entity(new_id, end_user_id)
# ──────────────────────────────────────────────────────────────────────────
# 内部方法
# ──────────────────────────────────────────────────────────────────────────
async def _process_single_entity(
    self, entity_id: str, end_user_id: str
) -> None:
    """Assign one new entity to a community.

    - No neighbors: the entity gets a new single-member community.
    - Neighbors but none with a community: the entity and all its neighbors
      are put into one new community, and metadata is generated for it.
    - Otherwise: the entity joins the community that wins the weighted vote;
      if the neighbors span several communities a merge is evaluated, then
      metadata is generated for the joined community.

    Args:
        entity_id: Id of the entity to place.
        end_user_id: User the entity belongs to.
    """
    neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
    if not neighbors:
        # Isolated entity: create a community that contains only this entity.
        new_cid = self._new_community_id()
        await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
        await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
        logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
        return

    # The neighbor query does not carry the entity's own embedding, so it is
    # fetched separately - and only now that a vote will actually take place.
    self_embedding = await self._get_entity_embedding(entity_id, end_user_id)

    # Communities represented among the neighbors.
    community_ids_in_neighbors = {
        nb["community_id"] for nb in neighbors if nb.get("community_id")
    }
    target_cid = _weighted_vote(neighbors, self_embedding)
    if target_cid is None:
        # No neighbor has a community: create one for the entity and all of
        # its neighbors.
        new_cid = self._new_community_id()
        await self.repo.upsert_community(new_cid, end_user_id)
        await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
        for nb in neighbors:
            await self.repo.assign_entity_to_community(
                nb["id"], new_cid, end_user_id
            )
        await self.repo.refresh_member_count(new_cid, end_user_id)
        logger.debug(
            f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
        )
        await self._generate_community_metadata([new_cid], end_user_id)
    else:
        # Join the community with the most votes.
        await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
        await self.repo.refresh_member_count(target_cid, end_user_id)
        logger.debug(f"[Clustering] 新实体 {entity_id} → 社区 {target_cid}")
        # Neighbors spread over several communities: check whether to merge.
        if len(community_ids_in_neighbors) > 1:
            await self._evaluate_merge(
                list(community_ids_in_neighbors), end_user_id
            )
        # NOTE(review): if the merge dissolved ``target_cid`` into another
        # community, the surviving community's metadata is presumably not
        # refreshed here - confirm.
        await self._generate_community_metadata([target_cid], end_user_id)
async def _evaluate_merge(
    self, community_ids: List[str], end_user_id: str
) -> None:
    """Merge communities whose mean member embeddings are similar enough.

    For every community the mean of its members' non-empty ``name_embedding``
    vectors is computed. Two communities are merge candidates when the cosine
    similarity of their means exceeds ``MERGE_THRESHOLD`` (0.85). On a merge
    the community with more members is kept and the members of the other one
    are reassigned to it.

    When more than ``BATCH_THRESHOLD`` (20) communities are passed in, all
    members are fetched with one batch query instead of one query each.

    Args:
        community_ids: Candidate community ids.
        end_user_id: User the communities belong to.
    """
    from itertools import combinations

    MERGE_THRESHOLD = 0.85
    BATCH_THRESHOLD = 20  # above this many communities, use the batch query

    def mean_embedding(members: List[Dict]) -> Optional[List[float]]:
        """Average the members' non-empty name embeddings; None if none exist."""
        vectors = [m["name_embedding"] for m in members if m.get("name_embedding")]
        if not vectors:
            return None
        dim = len(vectors[0])
        return [sum(vec[i] for vec in vectors) / len(vectors) for i in range(dim)]

    community_embeddings: Dict[str, Optional[List[float]]] = {}
    community_sizes: Dict[str, int] = {}
    if len(community_ids) > BATCH_THRESHOLD:
        # Batch path: one query returns the members of all communities.
        all_members = await self.repo.get_all_community_members_batch(
            community_ids, end_user_id
        )
        for cid in community_ids:
            members = all_members.get(cid, [])
            community_sizes[cid] = len(members)
            community_embeddings[cid] = mean_embedding(members)
    else:
        # Small input: query the communities one by one.
        for cid in community_ids:
            members = await self.repo.get_community_members(cid, end_user_id)
            community_sizes[cid] = len(members)
            community_embeddings[cid] = mean_embedding(members)

    # Collect the candidate pairs, in the same order an index loop would.
    to_merge: List[tuple] = [
        (left, right)
        for left, right in combinations(community_ids, 2)
        if _cosine_similarity(
            community_embeddings[left], community_embeddings[right]
        ) > MERGE_THRESHOLD
    ]
    logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")

    # Merge pair by pair and recompute the kept community's mean after each
    # merge. This avoids union-find style chaining: A~B and B~C does not imply
    # A~C, so transitivity alone must not pull A, B and C together.
    merged_into: Dict[str, str] = {}  # dissolved community -> kept community

    def get_root(x: str) -> str:
        """Follow ``merged_into`` with path compression to x's current root."""
        while x in merged_into:
            merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
            x = merged_into[x]
        return x

    for c1, c2 in to_merge:
        root1, root2 = get_root(c1), get_root(c2)
        if root1 == root2:
            continue
        # Re-check with the current means: an earlier merge may have moved a
        # root's mean, and the pair must still clear the threshold.
        current_sim = _cosine_similarity(
            community_embeddings.get(root1),
            community_embeddings.get(root2),
        )
        if current_sim <= MERGE_THRESHOLD:
            # The mean drifted after a merge; the pair no longer qualifies.
            logger.debug(
                f"[Clustering] 跳过合并 {root1}{root2}"
                f"当前相似度 {current_sim:.3f}{MERGE_THRESHOLD}"
            )
            continue
        keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
        dissolve = root2 if keep == root1 else root1
        merged_into[dissolve] = keep
        members = await self.repo.get_community_members(dissolve, end_user_id)
        for m in members:
            await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
        # Recompute the kept community's mean as a size-weighted average.
        keep_emb = community_embeddings.get(keep)
        dissolve_emb = community_embeddings.get(dissolve)
        keep_size = community_sizes.get(keep, 0)
        dissolve_size = community_sizes.get(dissolve, 0)
        total_size = keep_size + dissolve_size
        if keep_emb and dissolve_emb and total_size > 0:
            dim = len(keep_emb)
            community_embeddings[keep] = [
                (keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
                for i in range(dim)
            ]
        community_embeddings[dissolve] = None
        community_sizes[keep] = total_size
        community_sizes[dissolve] = 0
        # NOTE(review): only the kept community's member count is refreshed;
        # whether the dissolved community node is cleaned up elsewhere is not
        # visible here - confirm.
        await self.repo.refresh_member_count(keep, end_user_id)
        logger.info(
            f"[Clustering] 社区合并: {dissolve}{keep}"
            f"相似度={current_sim:.3f},迁移 {len(members)} 个成员"
        )
async def _flush_labels(
self, labels: Dict[str, str], end_user_id: str
) -> None:
"""将内存中的标签批量写入 Neo4j。"""
# 先创建所有唯一社区节点
unique_communities = set(labels.values())
for cid in unique_communities:
await self.repo.upsert_community(cid, end_user_id)
# 再批量分配实体
for entity_id, community_id in labels.items():
await self.repo.assign_entity_to_community(
entity_id, community_id, end_user_id
)
# 刷新成员数
for cid in unique_communities:
await self.repo.refresh_member_count(cid, end_user_id)
async def _get_entity_embedding(
self, entity_id: str, end_user_id: str
) -> Optional[List[float]]:
"""查询单个实体的 name_embedding。"""
try:
result = await self.connector.execute_query(
"MATCH (e:ExtractedEntity {id: $eid, end_user_id: $uid}) "
"RETURN e.name_embedding AS name_embedding",
eid=entity_id,
uid=end_user_id,
)
return result[0]["name_embedding"] if result else None
except Exception:
return None
@staticmethod
def _build_entity_lines(members: List[Dict]) -> List[str]:
"""将实体列表格式化为 prompt 行,包含 name、aliases、description。"""
lines = []
for m in members:
m_name = m.get("name", "")
aliases = m.get("aliases") or []
description = m.get("description") or ""
aliases_str = f"(别名:{''.join(aliases)}" if aliases else ""
desc_str = f"{description}" if description else ""
lines.append(f"- {m_name}{aliases_str}{desc_str}")
return lines
async def _generate_community_metadata(
    self, community_ids: List[str], end_user_id: str
) -> None:
    """
    Generate and persist metadata for one or more communities.

    Flow:
    1. For each community, ask the LLM for a name / summary. The per-community
       coroutines are run together with ``asyncio.gather``; a failure in one
       community is logged and does not stop the others.
    2. Collect all summaries and embed them with a single embedder call.
    3. Write with ``update_community_metadata`` when there is exactly one
       result, otherwise with ``batch_update_community_metadata``.

    Communities without members are skipped.
    """
    if not community_ids:
        return
    from app.db import get_db_context
    from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
    # --- Phase 1: call the LLM for each community's name / summary ---
    async def _build_one(cid: str):
        members = await self.repo.get_community_members(cid, end_user_id)
        if not members:
            return None
        # Highest activation first; a missing activation sorts as 0.
        sorted_members = sorted(
            members,
            key=lambda m: m.get("activation_value") or 0,
            reverse=True,
        )
        core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
        entity_list_str = "\n".join(self._build_entity_lines(members))
        prompt = (
            f"以下是一组语义相关的实体:\n{entity_list_str}\n\n"
            f"请为这组实体所代表的主题:\n"
            f"1. 起一个简洁的中文名称不超过10个字\n"
            f"2. 写一句话摘要不超过50个字\n\n"
            f"严格按以下格式输出,不要有其他内容:\n"
            f"名称:<名称>\n摘要:<摘要>"
        )
        with get_db_context() as db:
            llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
        response = await llm_client.chat([{"role": "user", "content": prompt}])
        text = response.content if hasattr(response, "content") else str(response)
        # Parse the two labelled lines; anything else in the reply is ignored
        # and a missing line leaves the corresponding field empty.
        name, summary = "", ""
        for line in text.strip().splitlines():
            if line.startswith("名称:"):
                name = line[3:].strip()
            elif line.startswith("摘要:"):
                summary = line[3:].strip()
        return {
            "community_id": cid,
            "end_user_id": end_user_id,
            "name": name,
            "summary": summary,
            "core_entities": core_entities,
            "summary_embedding": None,
        }
    results = await asyncio.gather(
        *[_build_one(cid) for cid in community_ids],
        return_exceptions=True,
    )
    # Keep the successful, non-empty results; log the failures.
    metadata_list = []
    for cid, res in zip(community_ids, results):
        if isinstance(res, Exception):
            logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res)
        elif res is not None:
            metadata_list.append(res)
    if not metadata_list:
        return
    # --- Phase 2: embed all summaries in one call ---
    summaries = [m["summary"] for m in metadata_list]
    with get_db_context() as db:
        embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
    embeddings = await embedder.response(summaries)
    # If fewer embeddings than summaries come back, the remainder stays None.
    for i, meta in enumerate(metadata_list):
        meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
    # --- Phase 3: write (single update or batch update) ---
    if len(metadata_list) == 1:
        m = metadata_list[0]
        result = await self.repo.update_community_metadata(
            community_id=m["community_id"],
            end_user_id=m["end_user_id"],
            name=m["name"],
            summary=m["summary"],
            core_entities=m["core_entities"],
            summary_embedding=m["summary_embedding"],
        )
        if result:
            logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
        else:
            logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
    else:
        ok = await self.repo.batch_update_community_metadata(metadata_list)
        if ok:
            logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
        else:
            logger.warning(f"[Clustering] 批量写入社区元数据失败")
@staticmethod
def _new_community_id() -> str:
return str(uuid.uuid4())

View File

@@ -20,7 +20,6 @@ from pydantic import BaseModel, Field
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
from app.core.memory.models.config_models import PruningConfig from app.core.memory.models.config_models import PruningConfig
from app.core.memory.utils.config.config_utils import get_pruning_config
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import ( from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
SceneConfigRegistry, SceneConfigRegistry,
@@ -33,6 +32,9 @@ class DialogExtractionResponse(BaseModel):
- is_related对话与场景的相关性判定。 - is_related对话与场景的相关性判定。
- times / ids / amounts / contacts / addresses / keywords重要信息片段用来在不相关对话中保留关键消息。 - times / ids / amounts / contacts / addresses / keywords重要信息片段用来在不相关对话中保留关键消息。
- preserve_keywords情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
- scene_unrelated_snippets与当前场景无关且无语义关联的消息片段原文截取
用于高阈值阶段精准删除跨场景内容。
""" """
is_related: bool = Field(...) is_related: bool = Field(...)
times: List[str] = Field(default_factory=list) times: List[str] = Field(default_factory=list)
@@ -41,6 +43,8 @@ class DialogExtractionResponse(BaseModel):
contacts: List[str] = Field(default_factory=list) contacts: List[str] = Field(default_factory=list)
addresses: List[str] = Field(default_factory=list) addresses: List[str] = Field(default_factory=list)
keywords: List[str] = Field(default_factory=list) keywords: List[str] = Field(default_factory=list)
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
scene_unrelated_snippets: List[str] = Field(default_factory=list,description="与当前场景无关且无语义关联的消息原文片段,高阈值阶段用于精准删除跨场景内容")
class MessageImportanceResponse(BaseModel): class MessageImportanceResponse(BaseModel):
@@ -86,26 +90,19 @@ class SemanticPruner:
self._detailed_prune_logging = True # 是否启用详细日志 self._detailed_prune_logging = True # 是否启用详细日志
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志 self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
# 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则) # 加载统一填充词库
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config( self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
self.config.pruning_scene,
fallback_to_generic=True
)
# 判断是否为内置专门场景 # 本体类型列表:直接使用 ontology_class_infosname + description
self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene) self._ontology_class_infos = getattr(self.config, "ontology_class_infos", None) or []
# _ontology_classes 仅用于日志统计
self._ontology_classes = [info.class_name for info in self._ontology_class_infos]
# 自定义场景的本体类型列表(用于注入提示词) self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
self._ontology_classes = getattr(self.config, "ontology_classes", None) or [] if self._ontology_class_infos:
self._log(f"[剪枝-初始化] 注入本体类型({len(self._ontology_class_infos)}个): {self._ontology_classes}")
if self._is_builtin_scene:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置")
else: else:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入") self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
if self._ontology_classes:
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
else:
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
# Load Jinja2 template # Load Jinja2 template
self.template = prompt_env.get_template("extracat_Pruning.jinja2") self.template = prompt_env.get_template("extracat_Pruning.jinja2")
@@ -117,98 +114,19 @@ class SemanticPruner:
# 运行日志:收集关键终端输出,便于写入 JSON # 运行日志:收集关键终端输出,便于写入 JSON
self.run_logs: List[str] = [] self.run_logs: List[str] = []
def _is_important_message(self, message: ConversationMessage) -> bool: # _is_important_message 和 _importance_score 已移除:
"""基于启发式规则识别重要信息消息,优先保留。 # 重要性判断完全由 extracat_Pruning.jinja2 提示词 + LLM 的 preserve_tokens 机制承担。
# LLM 根据注入的本体工程类型语义识别需要保护的内容,无需硬编码正则规则。
改进版:使用场景特定的模式进行识别
- 根据 pruning_scene 动态加载对应的识别规则
- 支持教育、在线服务、外呼三个场景的特定模式
"""
text = message.msg.strip()
if not text:
return False
# 使用场景特定的模式
all_patterns = (
self.scene_config.high_priority_patterns +
self.scene_config.medium_priority_patterns +
self.scene_config.low_priority_patterns
)
for pattern, _ in all_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
return True
# 检查是否为问句(以问号结尾或包含疑问词)
if text.endswith("") or text.endswith("?"):
return True
# 检查是否包含问句关键词
if any(keyword in text for keyword in self.scene_config.question_keywords):
return True
# 检查是否包含决策性关键词
if any(keyword in text for keyword in self.scene_config.decision_keywords):
return True
return False
def _importance_score(self, message: ConversationMessage) -> int:
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
改进版使用场景特定的权重体系0-10分
- 根据场景动态调整不同信息类型的权重
- 高优先级模式4-6分
- 中优先级模式2-3分
- 低优先级模式1分
"""
text = message.msg.strip()
score = 0
# 使用场景特定的权重
for pattern, weight in self.scene_config.high_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
for pattern, weight in self.scene_config.medium_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
for pattern, weight in self.scene_config.low_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
# 问句加分
if text.endswith("") or text.endswith("?"):
score += 2
# 包含问句关键词加分
if any(keyword in text for keyword in self.scene_config.question_keywords):
score += 1
# 包含决策性关键词加分
if any(keyword in text for keyword in self.scene_config.decision_keywords):
score += 2
# 长度加分(较长的消息通常包含更多信息)
if len(text) > 50:
score += 1
if len(text) > 100:
score += 1
return min(score, 10) # 最高10分
def _is_filler_message(self, message: ConversationMessage) -> bool: def _is_filler_message(self, message: ConversationMessage) -> bool:
"""检测典型寒暄/口头禅/确认类短消息。 """检测典型寒暄/口头禅/确认类短消息。
改进版:更严格的填充消息判断,避免误删场景相关内容 判断顺序:
满足以下之一视为填充消息 1. 空消息
- 纯标点或空白 2. 场景特定填充词库精确匹配
- 在场景特定填充词库中(精确匹配 3. 常见寒暄精确匹配
- 纯表情符号 4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢""同学你好""明白了"
- 常见寒暄(精确匹配短语) 5. 纯表情/标点
注意:不再使用长度判断,避免误删短但重要的消息
""" """
t = message.msg.strip() t = message.msg.strip()
if not t: if not t:
@@ -230,24 +148,59 @@ class SemanticPruner:
if t in common_greetings: if t in common_greetings:
return True return True
# 组合寒暄模式短消息≤15字且完全由寒暄成分构成
# 策略:将消息拆分后,每个片段都能在填充词库或常见寒暄中找到,则整体为填充
if len(t) <= 15:
# 确认+称呼/感谢组合,如"好的谢谢"、"明白了"、"知道了谢谢"
_confirm_prefixes = {"好的", "", "", "嗯嗯", "", "明白", "明白了", "知道了", "了解", "收到", "没问题"}
_thanks_suffixes = {"谢谢", "谢谢你", "谢谢您", "多谢", "感谢", "谢了"}
_greeting_suffixes = {"你好", "您好", "老师好", "同学好", "大家好"}
_greeting_prefixes = {"同学", "老师", "您好", "你好"}
_close_patterns = {
"没有了", "没事了", "没问题了", "好了", "行了", "可以了",
"不用了", "不需要了", "就这样", "就这样吧", "那就这样",
}
_polite_responses = {
"不客气", "不用谢", "没关系", "没事", "应该的", "这是我应该做的",
}
# 规则1确认词 + 感谢词(如"好的谢谢"、"嗯谢谢"
for cp in _confirm_prefixes:
for ts in _thanks_suffixes:
if t == cp + ts or t == cp + "" + ts or t == cp + "," + ts:
return True
# 规则2称呼前缀 + 问候(如"同学你好"、"老师好"
for gp in _greeting_prefixes:
for gs in _greeting_suffixes:
if t == gp + gs or t.startswith(gp) and t.endswith(""):
return True
# 规则3结束语 + 感谢(如"没有了,谢谢老师"、"没有了谢谢"
for cp in _close_patterns:
if t.startswith(cp):
remainder = t[len(cp):].lstrip(",、 ")
if not remainder or any(remainder.startswith(ts) for ts in _thanks_suffixes):
return True
# 规则4礼貌回应如"不客气,祝你考试顺利"——前缀是礼貌词,后半是祝福套话)
for pr in _polite_responses:
if t.startswith(pr):
remainder = t[len(pr):].lstrip(",、 ")
# 后半是祝福/套话(不含实质信息)
if not remainder or re.match(r"^(祝|希望|期待|加油|顺利|好好|保重)", remainder):
return True
# 规则5纯确认词加"了"后缀(如"明白了"、"知道了"、"好了"
_confirm_base = {"明白", "知道", "了解", "收到", "", "", "可以", "没问题"}
for cb in _confirm_base:
if t == cb + "" or t == cb + "了。" or t == cb + "了!":
return True
# 检查是否为纯表情符号(方括号包裹) # 检查是否为纯表情符号(方括号包裹)
if re.fullmatch(r"(\[[^\]]+\])+", t): if re.fullmatch(r"(\[[^\]]+\])+", t):
return True return True
# 检查是否为纯emojiUnicode表情
emoji_pattern = re.compile(
"["
"\U0001F600-\U0001F64F" # 表情符号
"\U0001F300-\U0001F5FF" # 符号和象形文字
"\U0001F680-\U0001F6FF" # 交通和地图符号
"\U0001F1E0-\U0001F1FF" # 旗帜
"\U00002702-\U000027B0"
"\U000024C2-\U0001F251"
"]+", flags=re.UNICODE
)
if emoji_pattern.fullmatch(t):
return True
# 纯标点符号 # 纯标点符号
if re.fullmatch(r"[。!?,.!?…·\s]+", t): if re.fullmatch(r"[。!?,.!?…·\s]+", t):
return True return True
@@ -432,15 +385,13 @@ class SemanticPruner:
rendered = self.template.render( rendered = self.template.render(
pruning_scene=self.config.pruning_scene, pruning_scene=self.config.pruning_scene,
is_builtin_scene=self._is_builtin_scene, ontology_class_infos=self._ontology_class_infos,
ontology_classes=self._ontology_classes,
dialog_text=dialog_text, dialog_text=dialog_text,
language=self.language language=self.language
) )
log_template_rendering("extracat_Pruning.jinja2", { log_template_rendering("extracat_Pruning.jinja2", {
"pruning_scene": self.config.pruning_scene, "pruning_scene": self.config.pruning_scene,
"is_builtin_scene": self._is_builtin_scene, "ontology_class_infos_count": len(self._ontology_class_infos),
"ontology_classes_count": len(self._ontology_classes),
"language": self.language "language": self.language
}) })
log_prompt_rendering("pruning-extract", rendered) log_prompt_rendering("pruning-extract", rendered)
@@ -480,6 +431,183 @@ class SemanticPruner:
) )
return fallback_response return fallback_response
def _get_pruning_mode(self) -> str:
"""根据 pruning_threshold 返回当前剪枝阶段。
- 低阈值 [0.0, 0.3)conservative 只删填充,保留所有实质内容
- 中阈值 [0.3, 0.6)semantic 保留场景相关 + 有语义关联的内容,删除无关联内容
- 高阈值 [0.6, 0.9]strict 只保留场景相关内容,跨场景内容可被删除
"""
t = float(self.config.pruning_threshold)
if t < 0.3:
return "conservative"
elif t < 0.6:
return "semantic"
else:
return "strict"
def _apply_related_dialog_pruning(
    self,
    msgs: List[ConversationMessage],
    extraction: "DialogExtractionResponse",
    dialog_label: str,
    pruning_mode: str,
) -> List[ConversationMessage]:
    """Single entry point for pruning a dialog judged related to the scene.

    - ``"conservative"``: only filler messages are removed.
    - any other mode: scene-aware pruning is applied.

    Returns:
        The messages that are kept.
    """
    if pruning_mode != "conservative":
        return self._prune_with_scene_filter(msgs, extraction, dialog_label, pruning_mode)

    protected_tokens = self._build_preserve_tokens(extraction)
    return self._prune_fillers_only(msgs, protected_tokens, dialog_label)
def _prune_fillers_only(
    self,
    msgs: List[ConversationMessage],
    preserve_tokens: List[str],
    dialog_label: str,
) -> List[ConversationMessage]:
    """Remove filler messages only; everything else is kept.

    The number of removals is not limited by ``pruning_threshold``: every
    filler message goes. At least one message is always kept.

    The filler check runs before the token check on purpose: a filler message
    is removed even when it contains one of ``preserve_tokens``.
    """
    survivors = []
    for message in msgs:
        # Filler first; token protection is only reported for non-fillers.
        if self._is_filler_message(message):
            self._log(f" [填充] '{message.msg[:40]}' → 删除")
            continue
        survivors.append(message)
        if self._msg_matches_tokens(message, preserve_tokens):
            self._log(f" [保护] '{message.msg[:40]}' → LLM保护跳过")

    # Never return an empty dialog: fall back to the first message.
    if not survivors and msgs:
        survivors = [msgs[0]]

    removed = len(msgs) - len(survivors)
    self._log(
        f"[剪枝-相关] {dialog_label} 总消息={len(msgs)} "
        f"填充删除={removed} 保留={len(survivors)}"
    )
    return survivors
def _prune_with_scene_filter(
    self,
    msgs: List[ConversationMessage],
    extraction: "DialogExtractionResponse",
    dialog_label: str,
    mode: str,
) -> List[ConversationMessage]:
    """Scene-aware pruning used for the ``semantic`` and ``strict`` modes.

    Per message, in this order:

    1. Filler messages are always removed.
    2. A message counts as scene-unrelated when it contains, or is contained
       in, one of ``extraction.scene_unrelated_snippets``.
       - ``strict``: a scene-unrelated message is removed, no exemption.
       - ``semantic``: a scene-unrelated message is kept if it matches one of
         the preserve tokens or contains an emotion word; otherwise removed.
    3. All other messages are kept.

    Preserve tokens: in ``strict`` mode only the structured fields (times,
    ids, amounts, contacts, addresses); otherwise everything returned by
    ``_build_preserve_tokens``.

    At least one message is always kept (fallback: the first one).
    """
    # strict narrows the protected set to structured information, leaving out
    # keywords / preserve_keywords so the scene filter can remove more.
    # semantic protects everything the LLM extracted.
    if mode == "strict":
        scene_preserve_tokens = (
            extraction.times + extraction.ids + extraction.amounts +
            extraction.contacts + extraction.addresses
        )
    else:
        scene_preserve_tokens = self._build_preserve_tokens(extraction)
    unrelated_snippets = extraction.scene_unrelated_snippets or []

    # Built once instead of per message. Empty entries are skipped: "" is a
    # substring of every string, so an empty entry would mark every message as
    # emotional and the semantic mode would never remove anything.
    emotion_words = [
        word
        for word in ["感觉", "觉得", "心情", "开心", "难过", "高兴", "沮丧",
                     "喜欢", "讨厌", "", "", "担心", "害怕", "兴奋",
                     "压力", "", "疲惫", "", "焦虑", "委屈", "感动"]
        if word
    ]

    to_delete_ids: set = set()
    for m in msgs:
        msg_text = m.msg.strip()
        # Highest priority: fillers are removed in every mode and take no
        # part in the scene decision.
        if self._is_filler_message(m):
            to_delete_ids.add(id(m))
            self._log(f" [填充] '{msg_text[:40]}' → 删除")
            continue
        # Containment is checked in both directions because the snippet
        # returned by the LLM may be shorter or longer than the message text.
        is_scene_unrelated = any(
            snip and (snip in msg_text or msg_text in snip)
            for snip in unrelated_snippets
        )
        if is_scene_unrelated:
            if mode == "strict":
                # strict: scene relevance is the only criterion.
                to_delete_ids.add(id(m))
                self._log(f" [场景无关-严格] '{msg_text[:40]}' → 删除")
            elif mode == "semantic":
                # First exemption: the message matches a preserve token.
                if self._msg_matches_tokens(m, scene_preserve_tokens):
                    self._log(f" [保护] '{msg_text[:40]}' → 场景关键词保护,保留")
                else:
                    # Second exemption: the message contains an emotion word.
                    has_contextual_emotion = any(
                        word in msg_text for word in emotion_words
                    )
                    if not has_contextual_emotion:
                        to_delete_ids.add(id(m))
                        self._log(f" [场景无关-语义] '{msg_text[:40]}' → 删除(无情感关联)")
                    else:
                        self._log(f" [场景关联-保留] '{msg_text[:40]}' → 有情感关联,保留")
        else:
            # Not listed as scene-unrelated: kept. Only token hits are logged.
            if self._msg_matches_tokens(m, scene_preserve_tokens):
                self._log(f" [保护] '{msg_text[:40]}' → LLM保护跳过")

    kept = [m for m in msgs if id(m) not in to_delete_ids]
    if not kept and msgs:
        kept = [msgs[0]]
    deleted = len(msgs) - len(kept)
    self._log(
        f"[剪枝-{mode}] {dialog_label} 总消息={len(msgs)} "
        f"删除={deleted} 保留={len(kept)}"
    )
    return kept
def _build_preserve_tokens(self, extraction: "DialogExtractionResponse") -> List[str]:
"""统一构建 preserve_tokens合并 LLM 抽取的所有重要片段。"""
return (
extraction.times + extraction.ids + extraction.amounts +
extraction.contacts + extraction.addresses + extraction.keywords +
extraction.preserve_keywords
)
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool: def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
"""判断消息是否包含任意抽取到的重要片段。""" """判断消息是否包含任意抽取到的重要片段。"""
if not tokens: if not tokens:
@@ -500,66 +628,62 @@ class SemanticPruner:
proportion = float(self.config.pruning_threshold) proportion = float(self.config.pruning_threshold)
extraction = await self._extract_dialog_important(dialog.content) extraction = await self._extract_dialog_important(dialog.content)
pruning_mode = self._get_pruning_mode()
self._log(f"[剪枝-模式] 阈值={proportion} → 模式={pruning_mode}")
if extraction.is_related: if extraction.is_related:
# 相关对话不剪枝 kept = self._apply_related_dialog_pruning(
dialog.context.msgs, extraction, f"对话ID={dialog.id}", pruning_mode
)
dialog.context = ConversationContext(msgs=kept)
return dialog return dialog
# 在不相关对话中,识别重要/不重要消息 # 在不相关对话中,LLM 已通过 preserve_tokens 标记需要保护的内容
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords preserve_tokens = self._build_preserve_tokens(extraction)
msgs = dialog.context.msgs msgs = dialog.context.msgs
imp_unrel_msgs: List[ConversationMessage] = []
unimp_unrel_msgs: List[ConversationMessage] = [] # 分类:填充 / 其他可删LLM保护消息通过不加入任何桶来隐式保护
filler_ids: set = set()
deletable: List[ConversationMessage] = []
for m in msgs: for m in msgs:
if self._msg_matches_tokens(m, tokens) or self._is_important_message(m): if self._msg_matches_tokens(m, preserve_tokens):
imp_unrel_msgs.append(m) pass # 保护消息:不加入任何桶,不会被删除
elif self._is_filler_message(m):
filler_ids.add(id(m))
else: else:
unimp_unrel_msgs.append(m) deletable.append(m)
# 计算总删除目标数量
# 计算删除目标
total_unrel = len(msgs) total_unrel = len(msgs)
delete_target = int(total_unrel * proportion) delete_target = int(total_unrel * proportion)
if proportion > 0 and total_unrel > 0 and delete_target == 0: if proportion > 0 and total_unrel > 0 and delete_target == 0:
delete_target = 1 delete_target = 1
imp_del_cap = min(int(len(imp_unrel_msgs) * proportion), len(imp_unrel_msgs)) max_deletable = min(len(filler_ids) + len(deletable), max(0, total_unrel - 1))
unimp_del_cap = len(unimp_unrel_msgs)
max_capacity = max(0, len(msgs) - 1)
max_deletable = min(imp_del_cap + unimp_del_cap, max_capacity)
delete_target = min(delete_target, max_deletable) delete_target = min(delete_target, max_deletable)
# 删除配额分配
del_unimp = min(delete_target, unimp_del_cap)
rem = delete_target - del_unimp
del_imp = min(rem, imp_del_cap)
# 选取删除集合 # 优先删填充,再删其他可删消息(按出现顺序)
unimp_delete_ids = [] to_delete_ids: set = set()
imp_delete_ids = []
if del_unimp > 0:
# 按出现顺序选取前 del_unimp 条不重要消息进行删除(确定性、可复现)
unimp_delete_ids = [id(m) for m in unimp_unrel_msgs[:del_unimp]]
if del_imp > 0:
imp_sorted = sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))
imp_delete_ids = [id(m) for m in imp_sorted[:del_imp]]
# 统计实际删除数量(重要/不重要)
actual_unimp_deleted = 0
actual_imp_deleted = 0
kept_msgs = []
delete_targets = set(unimp_delete_ids) | set(imp_delete_ids)
for m in msgs: for m in msgs:
mid = id(m) if len(to_delete_ids) >= delete_target:
if mid in delete_targets: break
if mid in set(unimp_delete_ids) and actual_unimp_deleted < del_unimp: if id(m) in filler_ids:
actual_unimp_deleted += 1 to_delete_ids.add(id(m))
continue for m in deletable:
if mid in set(imp_delete_ids) and actual_imp_deleted < del_imp: if len(to_delete_ids) >= delete_target:
actual_imp_deleted += 1 break
continue to_delete_ids.add(id(m))
kept_msgs.append(m)
kept_msgs = [m for m in msgs if id(m) not in to_delete_ids]
if not kept_msgs and msgs: if not kept_msgs and msgs:
kept_msgs = [msgs[0]] kept_msgs = [msgs[0]]
deleted_total = actual_unimp_deleted + actual_imp_deleted deleted_total = len(msgs) - len(kept_msgs)
protected_count = len(msgs) - len(filler_ids) - len(deletable)
self._log( self._log(
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} 删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}" f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} "
f"(保护={protected_count} 填充={len(filler_ids)} 可删={len(deletable)}) "
f"删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
) )
dialog.context = ConversationContext(msgs=kept_msgs) dialog.context = ConversationContext(msgs=kept_msgs)
@@ -591,54 +715,105 @@ class SemanticPruner:
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断" f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
) )
pruning_mode = self._get_pruning_mode()
self._log(f"[剪枝-数据集] 阈值={proportion} → 剪枝阶段={pruning_mode}")
result: List[DialogData] = [] result: List[DialogData] = []
total_original_msgs = 0 total_original_msgs = 0
total_deleted_msgs = 0 total_deleted_msgs = 0
for d_idx, dd in enumerate(dialogs): # 统计对象:直接收集结构化数据,无需事后正则解析
stats = {
"scene": self.config.pruning_scene,
"dialog_total": len(dialogs),
"deletion_ratio": proportion,
"enabled": self.config.pruning_switch,
"pruning_mode": pruning_mode,
"related_count": 0,
"unrelated_count": 0,
"related_indices": [],
"unrelated_indices": [],
"total_deleted_messages": 0,
"remaining_dialogs": 0,
"dialogs": [],
}
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
semaphore = asyncio.Semaphore(self.max_concurrent)
async def extract_with_semaphore(dd: DialogData) -> DialogExtractionResponse:
async with semaphore:
try:
return await self._extract_dialog_important(dd.content)
except Exception as e:
self._log(f"[剪枝-LLM] 对话抽取失败,使用降级策略: {str(e)[:100]}")
return DialogExtractionResponse(is_related=True)
extraction_tasks = [extract_with_semaphore(dd) for dd in dialogs]
extraction_results: List[DialogExtractionResponse] = await asyncio.gather(*extraction_tasks)
for d_idx, (dd, extraction) in enumerate(zip(dialogs, extraction_results)):
msgs = dd.context.msgs msgs = dd.context.msgs
original_count = len(msgs) original_count = len(msgs)
total_original_msgs += original_count total_original_msgs += original_count
# ========== 问答对保护(已注释,暂不启用,留作观察) ========== # 相关对话:根据阶段决定处理力度
# qa_pairs = self._identify_qa_pairs(msgs) if extraction.is_related:
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0) stats["related_count"] += 1
# ======================================================== stats["related_indices"].append(d_idx + 1)
kept = self._apply_related_dialog_pruning(
msgs, extraction, f"对话 {d_idx+1}", pruning_mode
)
deleted_count = original_count - len(kept)
total_deleted_msgs += deleted_count
dd.context.msgs = kept
result.append(dd)
stats["dialogs"].append({
"index": d_idx + 1,
"is_related": True,
"total_messages": original_count,
"deleted": deleted_count,
"kept": len(kept),
})
continue
# 消息级分类:每条消息独立判断 stats["unrelated_count"] += 1
important_msgs = [] # 重要消息(保留) stats["unrelated_indices"].append(d_idx + 1)
unimportant_msgs = [] # 不重要消息(可删除)
filler_msgs = [] # 填充消息(优先删除)
# 判断是否需要详细日志仅对前N条消息记录 # 从 LLM 抽取结果中获取所有需要保留的 token
preserve_tokens = self._build_preserve_tokens(extraction)
# 判断是否需要详细日志
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog: if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志") self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
if extraction.preserve_keywords:
self._log(f" 对话[{d_idx}] LLM抽取到情绪/兴趣保护词: {extraction.preserve_keywords}")
# 消息级分类LLM保护 / 填充 / 其他可删
llm_protected_msgs = [] # LLM 保护消息preserve_tokens 命中):绝对不可删除
filler_msgs = [] # 填充消息(优先删除)
deletable_msgs = [] # 其余消息(按比例删除)
for idx, m in enumerate(msgs): for idx, m in enumerate(msgs):
msg_text = m.msg.strip() msg_text = m.msg.strip()
# ========== 问答对保护判断(已注释) ========== if self._msg_matches_tokens(m, preserve_tokens):
# if idx in protected_indices: llm_protected_msgs.append((idx, m))
# important_msgs.append((idx, m)) if should_log_details or idx < self._max_debug_msgs_per_dialog:
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)") self._log(f" [{idx}] '{msg_text[:30]}...' → 保护LLM不可删")
# ========================================== elif self._is_filler_message(m):
# 填充消息(寒暄、表情等)
if self._is_filler_message(m):
filler_msgs.append((idx, m)) filler_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog: if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充") self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
# 重要信息(学号、成绩、时间、金额等)
elif self._is_important_message(m):
important_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
# 其他消息
else: else:
unimportant_msgs.append((idx, m)) deletable_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog: if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...'不重要") self._log(f" [{idx}] '{msg_text[:30]}...'可删")
# important_msgs 仅用于日志统计
important_msgs = llm_protected_msgs
# 计算删除配额 # 计算删除配额
delete_target = int(original_count * proportion) delete_target = int(original_count * proportion)
@@ -649,37 +824,23 @@ class SemanticPruner:
max_deletable = max(0, original_count - 1) max_deletable = max(0, original_count - 1)
delete_target = min(delete_target, max_deletable) delete_target = min(delete_target, max_deletable)
# 删除策略:优先删填充消息,再删除不重要消息 # 删除策略:优先删填充消息,再按出现顺序删其余可删消息
to_delete_indices = set() to_delete_indices = set()
deleted_details = [] # 记录删除的消息详情 deleted_details = []
# 第一步:删除填充消息 # 第一步:删除填充消息
filler_to_delete = min(len(filler_msgs), delete_target) for idx, msg in filler_msgs:
for i in range(filler_to_delete): if len(to_delete_indices) >= delete_target:
idx, msg = filler_msgs[i] break
to_delete_indices.add(idx) to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'") deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
# 第二步:如果还需要删除,删除不重要消息 # 第二步:如果还需要删除,按出现顺序删可删消息
remaining_quota = delete_target - len(to_delete_indices) for idx, msg in deletable_msgs:
if remaining_quota > 0: if len(to_delete_indices) >= delete_target:
unimp_to_delete = min(len(unimportant_msgs), remaining_quota) break
for i in range(unimp_to_delete): to_delete_indices.add(idx)
idx, msg = unimportant_msgs[i] deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'")
to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
# 第三步:如果还需要删除,按重要性分数删除重要消息
remaining_quota = delete_target - len(to_delete_indices)
if remaining_quota > 0 and important_msgs:
# 按重要性分数排序(分数低的优先删除)
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
imp_to_delete = min(len(imp_sorted), remaining_quota)
for i in range(imp_to_delete):
idx, msg = imp_sorted[i]
to_delete_indices.add(idx)
score = self._importance_score(msg)
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
# 执行删除 # 执行删除
kept_msgs = [] kept_msgs = []
@@ -707,23 +868,38 @@ class SemanticPruner:
self._log( self._log(
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} " f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) " f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) "
f"删除={deleted_count} 保留={len(kept_msgs)}" f"删除={deleted_count} 保留={len(kept_msgs)}"
) )
stats["dialogs"].append({
"index": d_idx + 1,
"is_related": False,
"total_messages": original_count,
"protected": len(important_msgs),
"fillers": len(filler_msgs),
"deletable": len(deletable_msgs),
"deleted": deleted_count,
"kept": len(kept_msgs),
})
result.append(dd) result.append(dd)
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}") # 补全统计对象
stats["total_deleted_messages"] = total_deleted_msgs
stats["remaining_dialogs"] = len(result)
# 保存日志 self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
self._log(f"[剪枝-数据集] 相关对话数={stats['related_count']} 不相关对话数={stats['unrelated_count']}")
self._log(f"[剪枝-数据集] 总删除 {total_deleted_msgs}")
# 直接序列化统计对象,无需正则解析
try: try:
from app.core.config import settings from app.core.config import settings
settings.ensure_memory_output_dir() settings.ensure_memory_output_dir()
log_output_path = settings.get_memory_output_path("pruned_terminal.json") log_output_path = settings.get_memory_output_path("pruned_terminal.json")
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
payload = self._parse_logs_to_structured(sanitized_logs)
with open(log_output_path, "w", encoding="utf-8") as f: with open(log_output_path, "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2) json.dump(stats, f, ensure_ascii=False, indent=2)
except Exception as e: except Exception as e:
self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}") self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}")
@@ -743,114 +919,4 @@ class SemanticPruner:
pass pass
print(msg) print(msg)
def _sanitize_log_line(self, line: str) -> str:
"""移除行首的方括号标签前缀,例如 [剪枝-数据集] 或 [剪枝-对话]。"""
try:
return re.sub(r"^\[[^\]]+\]\s*", "", line)
except Exception:
return line
def _parse_logs_to_structured(self, logs: List[str]) -> dict:
"""将已去前缀的日志列表解析为结构化 JSON便于数据对接。"""
summary = {
"scene": self.config.pruning_scene,
"dialog_total": None,
"deletion_ratio": None,
"enabled": None,
"related_count": None,
"unrelated_count": None,
"related_indices": [],
"unrelated_indices": [],
"total_deleted_messages": None,
"remaining_dialogs": None,
}
dialogs = []
# 解析函数
def parse_int(value: str) -> Optional[int]:
try:
return int(value)
except Exception:
return None
def parse_float(value: str) -> Optional[float]:
try:
return float(value)
except Exception:
return None
def parse_indices(s: str) -> List[int]:
s = s.strip()
if not s:
return []
parts = [p.strip() for p in s.split(",") if p.strip()]
out: List[int] = []
for p in parts:
try:
out.append(int(p))
except Exception:
pass
return out
# 正则
re_header = re.compile(r"对话总数=(\d+)\s+场景=([^\s]+)\s+删除比例=([0-9.]+)\s+开关=(True|False)")
re_counts = re.compile(r"相关对话数=(\d+)\s+不相关对话数=(\d+)")
re_indices = re.compile(r"相关对话:第\[(.*?)\]段;不相关对话:第\[(.*?)\]段")
re_dialog = re.compile(r"对话\s+(\d+)\s+总消息=(\d+)\s+分配删除=(\d+)\s+实删=(\d+)\s+保留=(\d+)")
re_total_del = re.compile(r"总删除\s+(\d+)\s+条")
re_remaining = re.compile(r"剩余对话数=(\d+)")
for line in logs:
# 第一行:总览
m = re_header.search(line)
if m:
summary["dialog_total"] = parse_int(m.group(1))
# 顶层 scene 依配置,这里不覆盖,但也可校验 m.group(2)
summary["deletion_ratio"] = parse_float(m.group(3))
summary["enabled"] = True if m.group(4) == "True" else False
continue
# 第二行:相关/不相关数量
m = re_counts.search(line)
if m:
summary["related_count"] = parse_int(m.group(1))
summary["unrelated_count"] = parse_int(m.group(2))
continue
# 第三行:相关/不相关索引
m = re_indices.search(line)
if m:
summary["related_indices"] = parse_indices(m.group(1))
summary["unrelated_indices"] = parse_indices(m.group(2))
continue
# 对话级统计
m = re_dialog.search(line)
if m:
dialogs.append({
"index": parse_int(m.group(1)),
"total_messages": parse_int(m.group(2)),
"quota_delete": parse_int(m.group(3)),
"actual_deleted": parse_int(m.group(4)),
"kept": parse_int(m.group(5)),
})
continue
# 全局删除总数
m = re_total_del.search(line)
if m:
summary["total_deleted_messages"] = parse_int(m.group(1))
continue
# 剩余对话数
m = re_remaining.search(line)
if m:
summary["remaining_dialogs"] = parse_int(m.group(1))
continue
return {
"scene": summary["scene"],
"timestamp": datetime.now().isoformat(),
"summary": {k: v for k, v in summary.items() if k != "scene"},
"dialogs": dialogs,
}

View File

@@ -1,66 +1,25 @@
""" """
场景特定配置 - 为不同场景提供定制化的剪枝规则 场景特定配置 - 统一填充词库
功能: 重要性判断已完全交由 extracat_Pruning.jinja2 提示词 + LLM preserve_tokens 机制承担。
- 场景特定的重要信息识别模式 本模块仅保留统一填充词库filler_phrases用于识别无意义寒暄/表情/口头禅。
- 场景特定的重要性评分权重 所有场景共用同一份词库,场景差异由 LLM 语义判断处理。
- 场景特定的填充词库
- 场景特定的问答对识别规则
""" """
from typing import Dict, List, Set, Tuple from typing import List, Set
from dataclasses import dataclass, field from dataclasses import dataclass, field
@dataclass @dataclass
class ScenePatterns: class ScenePatterns:
"""场景特定的识别模式""" """场景特定的识别模式(仅保留填充词库)"""
# 重要信息的正则模式(优先级从高到低)
high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight)
medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
# 填充词库(无意义对话)
filler_phrases: Set[str] = field(default_factory=set) filler_phrases: Set[str] = field(default_factory=set)
# 问句关键词(用于识别问答对)
question_keywords: Set[str] = field(default_factory=set)
# 决策性/承诺性关键词
decision_keywords: Set[str] = field(default_factory=set)
class SceneConfigRegistry: class SceneConfigRegistry:
"""场景配置注册表 - 管理所有场景的特定配置""" """场景配置注册表 - 所有场景共用统一填充词库"""
# 基础通用模式(所有场景共享) BASE_FILLERS: Set[str] = {
BASE_HIGH_PRIORITY = [
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
(r"金额|费用|价格|¥|¥|\d+元", 5),
(r"\d{11}", 4), # 手机号
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
]
BASE_MEDIUM_PRIORITY = [
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期
(r"\d{4}\d{1,2}月\d{1,2}日", 3),
(r"电话|手机号|微信|QQ|联系方式", 3),
(r"地址|地点|位置", 2),
(r"时间|日期|有效期|截止", 2),
(r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重)
(r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
(r"今年|去年|明年", 3),
]
BASE_LOW_PRIORITY = [
(r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM
(r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点
(r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充)
(r"AM|PM|am|pm", 1),
]
BASE_FILLERS = {
# 基础寒暄 # 基础寒暄
"你好", "您好", "在吗", "在的", "在呢", "", "嗯嗯", "", "哦哦", "你好", "您好", "在吗", "在的", "在呢", "", "嗯嗯", "", "哦哦",
"好的", "", "", "可以", "不可以", "谢谢", "多谢", "感谢", "好的", "", "", "可以", "不可以", "谢谢", "多谢", "感谢",
@@ -69,7 +28,26 @@ class SceneConfigRegistry:
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia", "哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
"", "", "", "", "", "", "嗯哼", "", "", "", "", "", "", "嗯哼",
# 确认词 # 确认词
"是的", "", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了", "是的", "", "对的", "没错", "好嘞", "收到", "明白", "了解", "知道了",
# 服务类套话
"请问", "请稍等", "稍等", "马上", "立即",
"正在查询", "正在处理", "正在为您", "帮您查一下",
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
"感谢您的耐心等待", "抱歉让您久等了",
"已记录", "已反馈", "已转接", "已升级",
"祝您生活愉快", "欢迎下次咨询",
# 外呼套话
"", "hello", "打扰了", "不好意思",
"方便接电话吗", "现在方便吗", "占用您一点时间",
"我是", "我们是", "我们公司", "我们这边",
"了解一下", "介绍一下", "简单说一下",
"考虑考虑", "想一想", "再说", "再看看",
"不需要", "不感兴趣", "没兴趣", "不用了",
"没问题", "那就这样", "再联系", "回头聊", "有需要再说",
# 教育场景套话
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
"举手", "请坐", "很好", "不错", "继续",
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
# 标点和符号 # 标点和符号
"。。。", "...", "???", "", "!!!", "", "。。。", "...", "???", "", "!!!", "",
# 表情符号 # 表情符号
@@ -82,245 +60,7 @@ class SceneConfigRegistry:
"emmm", "emm", "em", "mmp", "wtf", "omg", "emmm", "emm", "em", "mmp", "wtf", "omg",
} }
BASE_QUESTION_KEYWORDS = {
"什么", "为什么", "怎么", "如何", "哪里", "哪个", "", "多少", "几点", "何时", ""
}
BASE_DECISION_KEYWORDS = {
"必须", "一定", "务必", "需要", "要求", "规定", "应该",
"承诺", "保证", "确保", "负责", "同意", "答应"
}
@classmethod @classmethod
def get_education_config(cls) -> ScenePatterns: def get_config(cls, scene: str = "") -> ScenePatterns:
"""教育场景配置""" """所有场景统一返回同一份填充词库"""
return ScenePatterns( return ScenePatterns(filler_phrases=cls.BASE_FILLERS)
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
# 成绩相关(最高优先级)
(r"成绩|分数|得分|满分|及格|不及格", 6),
(r"GPA|绩点|学分|平均分", 6),
(r"\d+分|\d+\.?\d*分", 5), # 具体分数
(r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等
# 学籍信息
(r"学号|学生证|教师工号|工号", 5),
(r"班级|年级|专业|院系", 4),
# 课程相关
(r"课程|科目|学科|必修|选修", 4),
(r"教材|课本|教科书|参考书", 4),
(r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等
# 学科内容(新增)
(r"微积分|导数|积分|函数|极限|微分", 4),
(r"代数|几何|三角|概率|统计", 4),
(r"物理|化学|生物|历史|地理", 4),
(r"英语|语文|数学|政治|哲学", 4),
(r"定义|定理|公式|概念|原理|法则", 3),
(r"例题|解题|证明|推导|计算", 3),
],
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
# 教学活动
(r"作业|练习|习题|题目", 3),
(r"考试|测验|测试|考核|期中|期末", 3),
(r"上课|下课|课堂|讲课", 2),
(r"提问|回答|发言|讨论", 2),
(r"问一下|请教|咨询|询问", 2), # 新增:问询相关
(r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态
# 时间安排
(r"课表|课程表|时间表", 3),
(r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等
],
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
(r"老师|教师|同学|学生", 1),
(r"教室|实验室|图书馆", 1),
],
filler_phrases=cls.BASE_FILLERS | {
# 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义)
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
"举手", "请坐", "很好", "不错", "继续",
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
},
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
"为啥", "", "咋办", "怎样", "如何做",
"能不能", "可不可以", "行不行", "对不对", "是不是",
},
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
"必考", "重点", "考点", "难点", "关键",
"记住", "背诵", "掌握", "理解", "复习",
}
)
@classmethod
def get_online_service_config(cls) -> ScenePatterns:
"""在线服务场景配置"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
# 工单相关(最高优先级)
(r"工单号|工单编号|ticket|TK\d+", 6),
(r"工单状态|处理中|已解决|已关闭|待处理", 5),
(r"优先级|紧急|高优先级|P0|P1|P2", 5),
# 产品信息
(r"产品型号|型号|SKU|产品编号", 5),
(r"序列号|SN|设备号", 5),
(r"版本号|软件版本|固件版本", 4),
# 问题描述
(r"故障|错误|异常|bug|问题", 4),
(r"错误代码|故障代码|error code", 5),
(r"无法|不能|失败|报错", 3),
],
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
# 服务相关
(r"退款|退货|换货|补发", 4),
(r"发票|收据|凭证", 3),
(r"物流|快递|运单号", 3),
(r"保修|质保|售后", 3),
# 时效相关
(r"SLA|响应时间|处理时长", 4),
(r"超时|延迟|等待", 2),
],
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
(r"客服|工程师|技术支持", 1),
(r"用户|客户|会员", 1),
],
filler_phrases=cls.BASE_FILLERS | {
# 在线服务特有填充词
"您好", "请问", "请稍等", "稍等", "马上", "立即",
"正在查询", "正在处理", "正在为您", "帮您查一下",
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
"感谢您的耐心等待", "抱歉让您久等了",
"已记录", "已反馈", "已转接", "已升级",
"祝您生活愉快", "再见", "欢迎下次咨询",
},
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
"能否", "可否", "是否", "有没有", "能不能",
"怎么办", "如何处理", "怎么解决",
},
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
"立即处理", "马上解决", "尽快", "优先",
"升级", "转接", "派单", "跟进",
"补偿", "赔偿", "退款", "换货",
}
)
@classmethod
def get_outbound_config(cls) -> ScenePatterns:
"""外呼场景配置"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
# 意向相关(最高优先级)
(r"意向|意愿|兴趣|感兴趣", 6),
(r"A类|B类|C类|D类|高意向|低意向", 6),
(r"成交|签约|下单|购买|确认", 6),
# 联系信息(外呼场景中更重要)
(r"预约|约定|安排|确定时间", 5),
(r"下次联系|回访|跟进", 5),
(r"方便|有空|可以|时间", 4),
# 通话状态
(r"接通|未接通|占线|关机|停机", 4),
(r"通话时长|通话时间", 3),
],
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
# 客户信息
(r"姓名|称呼|先生|女士", 3),
(r"公司|单位|职位|职务", 3),
(r"需求|要求|期望", 3),
# 跟进状态
(r"跟进状态|进展|进度", 3),
(r"已联系|待联系|联系中", 2),
(r"拒绝|不感兴趣|考虑|再说", 3),
],
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
(r"销售|客户经理|业务员", 1),
(r"产品|服务|方案", 1),
],
filler_phrases=cls.BASE_FILLERS | {
# 外呼场景特有填充词
"您好", "", "hello", "打扰了", "不好意思",
"方便接电话吗", "现在方便吗", "占用您一点时间",
"我是", "我们是", "我们公司", "我们这边",
"了解一下", "介绍一下", "简单说一下",
"考虑考虑", "想一想", "再说", "再看看",
"不需要", "不感兴趣", "没兴趣", "不用了",
"好的", "", "可以", "没问题", "那就这样",
"再联系", "回头聊", "有需要再说",
},
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
"有没有", "需不需要", "要不要", "考虑不考虑",
"了解吗", "知道吗", "听说过吗",
"方便吗", "有空吗", "在吗",
},
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
"确定", "决定", "选择", "购买", "下单",
"预约", "安排", "约定", "确认",
"跟进", "回访", "联系", "沟通",
}
)
@classmethod
def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns:
"""根据场景名称获取配置
Args:
scene: 场景名称 ('education', 'online_service', 'outbound' 或其他)
fallback_to_generic: 如果场景不存在,是否降级到通用配置
Returns:
对应场景的配置,如果场景不存在:
- fallback_to_generic=True: 返回通用配置(仅基础规则)
- fallback_to_generic=False: 抛出异常
"""
scene_map = {
'education': cls.get_education_config,
'online_service': cls.get_online_service_config,
'outbound': cls.get_outbound_config,
}
if scene in scene_map:
return scene_map[scene]()
if fallback_to_generic:
# 返回通用配置(仅包含基础规则,不包含场景特定规则)
return cls.get_generic_config()
else:
raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}")
@classmethod
def get_generic_config(cls) -> ScenePatterns:
"""通用场景配置 - 仅包含基础规则,适用于未定义的场景
这是一个保守的配置,只使用最通用的规则,避免误删重要信息
"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY,
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY,
low_priority_patterns=cls.BASE_LOW_PRIORITY,
filler_phrases=cls.BASE_FILLERS,
question_keywords=cls.BASE_QUESTION_KEYWORDS,
decision_keywords=cls.BASE_DECISION_KEYWORDS
)
@classmethod
def get_all_scenes(cls) -> List[str]:
"""获取所有预定义场景的列表"""
return ['education', 'online_service', 'outbound']
@classmethod
def is_scene_supported(cls, scene: str) -> bool:
"""检查场景是否有专门的配置支持
Args:
scene: 场景名称
Returns:
True: 有专门配置
False: 将使用通用配置
"""
return scene in cls.get_all_scenes()

View File

@@ -384,6 +384,14 @@ class ExtractionOrchestrator:
logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句") logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句")
# 试运行模式下,所有分块提取完成后发送完成事件
if self.progress_callback and self.is_pilot_run:
await self.progress_callback(
"knowledge_extraction_complete",
f"陈述句提取完成,共提取 {len(all_statements)}",
{"total_statements": len(all_statements), "total_chunks": total_chunks}
)
return dialog_data_list return dialog_data_list
async def _extract_triplets( async def _extract_triplets(

View File

@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
from app.core.logging_config import get_memory_logger from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤 from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
from app.core.memory.models.triplet_models import TripletExtractionResponse from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.message_models import DialogData, Statement from app.core.memory.models.message_models import DialogData, Statement
from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.core.memory.models.ontology_extraction_models import OntologyTypeList
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
logger = get_memory_logger(__name__) logger = get_memory_logger(__name__)
class TripletExtractor: class TripletExtractor:
"""Extracts knowledge triplets and entities from statements using LLM""" """Extracts knowledge triplets and entities from statements using LLM"""
def __init__( def __init__(
self, self,
llm_client: OpenAIClient, llm_client: OpenAIClient,
ontology_types: Optional[OntologyTypeList] = None, ontology_types: Optional[OntologyTypeList] = None,
language: str = "zh"): language: str = "zh"
):
"""Initialize the TripletExtractor with an LLM client """Initialize the TripletExtractor with an LLM client
Args: Args:
@@ -65,7 +65,8 @@ class TripletExtractor:
# Create messages for LLM # Create messages for LLM
messages = [ messages = [
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."}, {"role": "system",
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
{"role": "user", "content": prompt_content} {"role": "user", "content": prompt_content}
] ]
@@ -116,7 +117,8 @@ class TripletExtractor:
logger.error(f"Error processing statement: {e}", exc_info=True) logger.error(f"Error processing statement: {e}", exc_info=True)
return TripletExtractionResponse(triplets=[], entities=[]) return TripletExtractionResponse(triplets=[], entities=[])
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]: async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
str, TripletExtractionResponse]:
"""Extract triplets and entities from statements """Extract triplets and entities from statements
Args: Args:

View File

@@ -1,11 +1,11 @@
""" """
自我反思引擎实现 Self-Reflection Engine Implementation
该模块实现了记忆系统的自我反思功能,包括: This module implements the self-reflection functionality of the memory system, including:
1. 基于时间的反思 - 根据时间周期触发反思 1. Time-based reflection - Triggers reflection based on time cycles
2. 基于事实的反思 - 检测记忆冲突并解决 2. Fact-based reflection - Detects and resolves memory conflicts
3. 综合反思 - 整合多种反思策略 3. Comprehensive reflection - Integrates multiple reflection strategies
4. 反思结果应用 - 更新记忆库 4. Reflection result application - Updates memory database
""" """
import asyncio import asyncio
@@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import (
) )
from pydantic import BaseModel from pydantic import BaseModel
# 配置日志 # Configure logging
_root_logger = logging.getLogger() _root_logger = logging.getLogger()
if not _root_logger.handlers: if not _root_logger.handlers:
logging.basicConfig( logging.basicConfig(
@@ -49,35 +49,62 @@ else:
_root_logger.setLevel(logging.INFO) _root_logger.setLevel(logging.INFO)
class TranslationResponse(BaseModel): class TranslationResponse(BaseModel):
"""翻译响应模型""" """Translation response model for language conversion"""
data: str data: str
class ReflectionRange(str, Enum): class ReflectionRange(str, Enum):
"""反思范围枚举""" """
PARTIAL = "partial" # 从检索结果中反思 Reflection range enumeration
ALL = "all" # 从整个数据库中反思
Defines the scope of data to be included in reflection operations.
"""
PARTIAL = "partial" # Reflect from retrieval results
ALL = "all" # Reflect from entire database
class ReflectionBaseline(str, Enum): class ReflectionBaseline(str, Enum):
"""反思基线枚举""" """
TIME = "TIME" # 基于时间的反思 Reflection baseline enumeration
FACT = "FACT" # 基于事实的反思
HYBRID = "HYBRID" # 混合反思 Defines the strategy or approach used for reflection operations.
"""
TIME = "TIME" # Time-based reflection
FACT = "FACT" # Fact-based reflection
HYBRID = "HYBRID" # Hybrid reflection combining multiple strategies
class ReflectionConfig(BaseModel): class ReflectionConfig(BaseModel):
"""反思引擎配置""" """
Reflection engine configuration
Defines all configuration parameters for the reflection engine including
operation modes, model settings, and evaluation criteria.
Attributes:
enabled: Whether reflection engine is enabled
iteration_period: Reflection cycle period (e.g., "3" hours)
reflexion_range: Scope of reflection (PARTIAL or ALL)
baseline: Reflection strategy (TIME, FACT, or HYBRID)
model_id: LLM model identifier for reflection operations
end_user_id: User identifier for scoped operations
output_example: Example output format for guidance
memory_verify: Enable memory verification checks
quality_assessment: Enable quality assessment evaluation
violation_handling_strategy: Strategy for handling violations
language_type: Language type for output ("zh" or "en")
"""
enabled: bool = False enabled: bool = False
iteration_period: str = "3" # 反思周期 iteration_period: str = "3" # Reflection cycle period
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
baseline: ReflectionBaseline = ReflectionBaseline.TIME baseline: ReflectionBaseline = ReflectionBaseline.TIME
model_id: Optional[str] = None # 模型ID model_id: Optional[str] = None # Model ID
end_user_id: Optional[str] = None end_user_id: Optional[str] = None
output_example: Optional[str] = None # 输出示例 output_example: Optional[str] = None # Output example
# 评估相关字段 # Evaluation related fields
memory_verify: bool = True # 记忆验证 memory_verify: bool = True # Memory verification
quality_assessment: bool = True # 质量评估 quality_assessment: bool = True # Quality assessment
violation_handling_strategy: str = "warn" # 违规处理策略 violation_handling_strategy: str = "warn" # Violation handling strategy
language_type: str = "zh" language_type: str = "zh"
class Config: class Config:
@@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel):
class ReflectionResult(BaseModel): class ReflectionResult(BaseModel):
"""反思结果""" """
Reflection operation result
Contains comprehensive information about the outcome of a reflection operation
including success status, metrics, and execution details.
Attributes:
success: Whether the reflection operation succeeded
message: Descriptive message about the operation result
conflicts_found: Number of conflicts detected during reflection
conflicts_resolved: Number of conflicts successfully resolved
memories_updated: Number of memory entries updated in database
execution_time: Total time taken for the reflection operation
details: Additional details about the operation (optional)
"""
success: bool success: bool
message: str message: str
conflicts_found: int = 0 conflicts_found: int = 0
@@ -97,9 +138,22 @@ class ReflectionResult(BaseModel):
class ReflectionEngine: class ReflectionEngine:
""" """
自我反思引擎 Self-Reflection Engine
负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。 Responsible for executing memory system self-reflection operations including
conflict detection, conflict resolution, and memory updates. Supports multiple
reflection strategies and provides comprehensive result tracking.
The engine can operate in different modes:
- Time-based: Reflects on memories within specific time periods
- Fact-based: Detects and resolves factual conflicts in memories
- Hybrid: Combines multiple reflection strategies
Attributes:
config: Reflection engine configuration
neo4j_connector: Neo4j database connector
llm_client: Language model client for analysis
Various function handlers for data processing and prompt rendering
""" """
def __init__( def __init__(
@@ -115,18 +169,21 @@ class ReflectionEngine:
update_query: Optional[str] = None update_query: Optional[str] = None
): ):
""" """
初始化反思引擎 Initialize reflection engine
Sets up the reflection engine with configuration and optional dependencies.
Uses lazy initialization to avoid circular imports and optimize startup time.
Args: Args:
config: 反思引擎配置 config: Reflection engine configuration object
neo4j_connector: Neo4j 连接器(可选) neo4j_connector: Neo4j connector instance (optional, will be created if not provided)
llm_client: LLM 客户端(可选) llm_client: LLM client instance (optional, will be created if not provided)
get_data_func: 获取数据的函数(可选) get_data_func: Function for retrieving data (optional, uses default if not provided)
render_evaluate_prompt_func: 渲染评估提示词的函数(可选) render_evaluate_prompt_func: Function for rendering evaluation prompts (optional)
render_reflexion_prompt_func: 渲染反思提示词的函数(可选) render_reflexion_prompt_func: Function for rendering reflection prompts (optional)
conflict_schema: 冲突结果 Schema(可选) conflict_schema: Schema for conflict result validation (optional)
reflexion_schema: 反思结果 Schema(可选) reflexion_schema: Schema for reflection result validation (optional)
update_query: 更新查询语句(可选) update_query: Query string for database updates (optional)
""" """
self.config = config self.config = config
self.neo4j_connector = neo4j_connector self.neo4j_connector = neo4j_connector
@@ -137,14 +194,20 @@ class ReflectionEngine:
self.conflict_schema = conflict_schema self.conflict_schema = conflict_schema
self.reflexion_schema = reflexion_schema self.reflexion_schema = reflexion_schema
self.update_query = update_query self.update_query = update_query
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5 self._semaphore = asyncio.Semaphore(5) # Default concurrency limit of 5
# 延迟导入以避免循环依赖 # Lazy import to avoid circular dependencies
self._lazy_init_done = False self._lazy_init_done = False
def _lazy_init(self): def _lazy_init(self):
"""延迟初始化,避免循环导入""" """
Lazy initialization to avoid circular imports
Initializes dependencies only when needed, preventing circular import issues
and optimizing startup performance. Sets up default implementations for
any components not provided during construction.
"""
if self._lazy_init_done: if self._lazy_init_done:
return return
@@ -158,7 +221,7 @@ class ReflectionEngine:
factory = MemoryClientFactory(db) factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(self.config.model_id) self.llm_client = factory.get_llm_client(self.config.model_id)
elif isinstance(self.llm_client, str): elif isinstance(self.llm_client, str):
# 如果 llm_client 是字符串model_id则用它初始化客户端 # If llm_client is a string (model_id), use it to initialize the client
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
@@ -172,10 +235,10 @@ class ReflectionEngine:
model_config = config_service.get_model_config(model_id) model_config = config_service.get_model_config(model_id)
extra_params={ extra_params={
"temperature": 0.2, # 降低温度提高响应速度和一致性 "temperature": 0.2, # Lower temperature for faster response and consistency
"max_tokens": 600, # 限制最大token数 "max_tokens": 600, # Limit maximum token count
"top_p": 0.8, # 优化采样参数 "top_p": 0.8, # Optimize sampling parameters
"stream": False, # 确保非流式输出以获得最快响应 "stream": False, # Ensure non-streaming output for fastest response
} }
self.llm_client = OpenAIClient(RedBearModelConfig( self.llm_client = OpenAIClient(RedBearModelConfig(
@@ -191,7 +254,7 @@ class ReflectionEngine:
if self.get_data_func is None: if self.get_data_func is None:
self.get_data_func = get_data self.get_data_func = get_data
# 导入get_data_statement函数 # Import get_data_statement function
if not hasattr(self, 'get_data_statement'): if not hasattr(self, 'get_data_statement'):
self.get_data_statement = get_data_statement self.get_data_statement = get_data_statement
@@ -223,13 +286,20 @@ class ReflectionEngine:
async def execute_reflection(self, host_id) -> ReflectionResult: async def execute_reflection(self, host_id) -> ReflectionResult:
""" """
执行完整的反思流程 Execute complete reflection workflow
Performs the full reflection process including data retrieval, conflict detection,
conflict resolution, and memory updates. This is the main entry point for
reflection operations.
Args: Args:
host_id: 主机ID host_id: Host identifier for scoping reflection operations
Returns: Returns:
ReflectionResult: 反思结果 ReflectionResult: Comprehensive result of the reflection operation including
success status, conflict metrics, and execution time
""" """
# 延迟初始化 # Lazy initialization
self._lazy_init() self._lazy_init()
if not self.config.enabled: if not self.config.enabled:
@@ -243,7 +313,7 @@ class ReflectionEngine:
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment) print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
try: try:
# 1. 获取反思数据 # 1. Get reflection data
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id) reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
if not reflexion_data: if not reflexion_data:
return ReflectionResult( return ReflectionResult(
@@ -252,7 +322,7 @@ class ReflectionEngine:
execution_time=asyncio.get_event_loop().time() - start_time execution_time=asyncio.get_event_loop().time() - start_time
) )
# 2. 检测冲突(基于事实的反思) # 2. Detect conflicts (fact-based reflection)
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets) conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
conflict_list=[] conflict_list=[]
for i in conflict_data: for i in conflict_data:
@@ -261,7 +331,7 @@ class ReflectionEngine:
conflicts_found=0 conflicts_found=0
# 3. 解决冲突 # 3. Resolve conflicts
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets) solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
if not solved_data: if not solved_data:
@@ -276,7 +346,7 @@ class ReflectionEngine:
logging.info(f"解决了 {conflicts_resolved} 个冲突") logging.info(f"解决了 {conflicts_resolved} 个冲突")
# 4. 应用反思结果(更新记忆库) # 4. Apply reflection results (update memory database)
memories_updated=await self._apply_reflection_results(solved_data) memories_updated=await self._apply_reflection_results(solved_data)
execution_time = asyncio.get_event_loop().time() - start_time execution_time = asyncio.get_event_loop().time() - start_time
@@ -302,7 +372,19 @@ class ReflectionEngine:
) )
async def Translate(self, text): async def Translate(self, text):
# 翻译中文为英文 """
Translate Chinese text to English
Uses the configured LLM to translate Chinese text to English with structured output.
Provides consistent translation format for reflection results.
Args:
text: Chinese text to be translated
Returns:
str: Translated English text
"""
# Translate Chinese to English
translation_messages = [ translation_messages = [
{ {
"role": "user", "role": "user",
@@ -316,6 +398,19 @@ class ReflectionEngine:
) )
return response.data return response.data
async def extract_translation(self,data): async def extract_translation(self,data):
"""
Extract and translate reflection data to English
Processes reflection data structure and translates all Chinese content to English.
Handles nested data structures including memory verifications, quality assessments,
and reflection data while preserving the original structure.
Args:
data: Dictionary containing reflection data with Chinese content
Returns:
dict: Translated data structure with English content
"""
end_datas={} end_datas={}
end_datas['source_data']=await self.Translate(data['source_data']) end_datas['source_data']=await self.Translate(data['source_data'])
quality_assessments = [] quality_assessments = []
@@ -350,6 +445,18 @@ class ReflectionEngine:
return end_datas return end_datas
async def reflection_run(self): async def reflection_run(self):
"""
Execute reflection workflow with comprehensive data processing
Performs a complete reflection operation including conflict detection, resolution,
and result formatting. Supports both Chinese and English output based on
configuration settings.
Returns:
dict: Comprehensive reflection results including source data, memory verifications,
quality assessments, and reflection data. Results are translated to English
if language_type is set to 'en'.
"""
self._lazy_init() self._lazy_init()
start_time = time.time() start_time = time.time()
memory_verifies_flag = self.config.memory_verify memory_verifies_flag = self.config.memory_verify
@@ -367,7 +474,7 @@ class ReflectionEngine:
result_data['source_data'] = "我是 2023 年春天去北京工作的后来基本一直都在北京上班也没怎么换过城市。不过后来公司调整2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X银行卡是 6222023847595898这些一直没变。对了其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合" result_data['source_data'] = "我是 2023 年春天去北京工作的后来基本一直都在北京上班也没怎么换过城市。不过后来公司调整2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X银行卡是 6222023847595898这些一直没变。对了其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
# 2. 检测冲突(基于事实的反思) # 2. 检测冲突(基于事实的反思)
conflict_data = await self._detect_conflicts(databasets, source_data) conflict_data = await self._detect_conflicts(databasets, source_data)
# 遍历数据提取字段 # Traverse data to extract fields
quality_assessments = [] quality_assessments = []
memory_verifies = [] memory_verifies = []
for item in conflict_data: for item in conflict_data:
@@ -375,9 +482,9 @@ class ReflectionEngine:
memory_verifies.append(item['memory_verify']) memory_verifies.append(item['memory_verify'])
result_data['memory_verifies'] = memory_verifies result_data['memory_verifies'] = memory_verifies
result_data['quality_assessments'] = quality_assessments result_data['quality_assessments'] = quality_assessments
conflicts_found = 0 # 初始化为整数0而不是空字符串 conflicts_found = 0 # Initialize as integer 0 instead of empty string
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"} REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
# Clearn conflict_dataAnd memory_verifyquality_assessment # Clean conflict_data, and memory_verify and quality_assessment
cleaned_conflict_data = [] cleaned_conflict_data = []
for item in conflict_data: for item in conflict_data:
cleaned_item = { cleaned_item = {
@@ -389,7 +496,7 @@ class ReflectionEngine:
for item in conflict_data: for item in conflict_data:
cleaned_data = [] cleaned_data = []
for row in item.get("data", []): for row in item.get("data", []):
# 删除 created_at / expired_at # Remove created_at / expired_at
cleaned_row = { cleaned_row = {
k: v k: v
for k, v in row.items() for k, v in row.items()
@@ -402,7 +509,7 @@ class ReflectionEngine:
} }
cleaned_conflict_data_.append(cleaned_item) cleaned_conflict_data_.append(cleaned_item)
print(cleaned_conflict_data_) print(cleaned_conflict_data_)
# 3. 解决冲突 # 3. Resolve conflicts
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data) solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
if not solved_data: if not solved_data:
return ReflectionResult( return ReflectionResult(
@@ -413,7 +520,7 @@ class ReflectionEngine:
) )
reflexion_data = [] reflexion_data = []
# 遍历数据提取reflexion字段 # Traverse data to extract reflexion fields
for item in solved_data: for item in solved_data:
if 'results' in item: if 'results' in item:
for result in item['results']: for result in item['results']:
@@ -431,15 +538,24 @@ class ReflectionEngine:
async def extract_fields_from_json(self): async def extract_fields_from_json(self):
"""从example.json中提取source_data和databasets字段""" """
Extract source_data and databasets fields from example.json
Reads reflection example data from the example.json file and extracts
the source data and database statements for testing and demonstration purposes.
Returns:
tuple: (source_data, databasets) extracted from the example file
Returns empty lists if file reading fails
"""
prompt_dir = os.path.join(os.path.dirname(__file__), "example") prompt_dir = os.path.join(os.path.dirname(__file__), "example")
try: try:
# 读取JSON文件 # Read JSON file
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f: with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
data = json.loads(f.read()) data = json.loads(f.read())
# 提取memory_verify下的字段 # Extract fields under memory_verify
memory_verify = data.get("memory_verify", {}) memory_verify = data.get("memory_verify", {})
source_data = memory_verify.get("source_data", []) source_data = memory_verify.get("source_data", [])
databasets = memory_verify.get("databasets", []) databasets = memory_verify.get("databasets", [])
@@ -451,15 +567,17 @@ class ReflectionEngine:
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]: async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
""" """
获取反思数据 Get reflection data from database
根据配置的反思范围获取需要反思的记忆数据。 Retrieves memory data for reflection based on the configured reflection range.
Supports both partial (from retrieval results) and full (entire database) modes.
Args: Args:
host_id: 主机ID host_id: Host UUID identifier for scoping data retrieval
Returns: Returns:
List[Any]: 反思数据列表 tuple: (reflexion_data, statement_data) containing memory data for reflection
Returns empty lists if query fails
""" """
print("=== 获取反思数据 ===") print("=== 获取反思数据 ===")
@@ -484,26 +602,29 @@ class ReflectionEngine:
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]: async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
""" """
检测冲突(基于事实的反思) Detect conflicts (fact-based reflection)
使用 LLM 分析记忆数据,检测其中的冲突。 Uses LLM to analyze memory data and detect conflicts within the memories.
Performs comprehensive conflict detection including memory verification and
quality assessment based on configuration settings.
Args: Args:
data: 待检测的记忆数据 data: Memory data to be analyzed for conflicts
statement_databasets: Statement database records for context
Returns: Returns:
List[Any]: 冲突记忆列表 List[Any]: List of detected conflicts with detailed analysis
""" """
if not data: if not data:
return [] return []
# 数据预处理:如果数据量太少,直接返回无冲突 # Data preprocessing: if data is too small, return no conflicts directly
if len(data) < 2: if len(data) < 2:
logging.info("数据量不足,无需检测冲突") logging.info("数据量不足,无需检测冲突")
return [] return []
# 使用转换后的数据 # Use converted data
# print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长 # print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs
memory_verify = self.config.memory_verify memory_verify = self.config.memory_verify
logging.info("====== 冲突检测开始 ======") logging.info("====== 冲突检测开始 ======")
@@ -512,7 +633,7 @@ class ReflectionEngine:
language_type=self.config.language_type language_type=self.config.language_type
try: try:
# 渲染冲突检测提示词 # Render conflict detection prompt
rendered_prompt = await self.render_evaluate_prompt_func( rendered_prompt = await self.render_evaluate_prompt_func(
data, data,
self.conflict_schema, self.conflict_schema,
@@ -526,7 +647,7 @@ class ReflectionEngine:
messages = [{"role": "user", "content": rendered_prompt}] messages = [{"role": "user", "content": rendered_prompt}]
logging.info(f"提示词长度: {len(rendered_prompt)}") logging.info(f"提示词长度: {len(rendered_prompt)}")
# 调用 LLM 进行冲突检测 # Call LLM for conflict detection
response = await self.llm_client.response_structured( response = await self.llm_client.response_structured(
messages, messages,
self.conflict_schema self.conflict_schema
@@ -539,7 +660,7 @@ class ReflectionEngine:
logging.error("LLM 冲突检测输出解析失败") logging.error("LLM 冲突检测输出解析失败")
return [] return []
# 标准化返回格式 # Standardize return format
if isinstance(response, BaseModel): if isinstance(response, BaseModel):
return [response.model_dump()] return [response.model_dump()]
elif hasattr(response, 'dict'): elif hasattr(response, 'dict'):
@@ -553,15 +674,17 @@ class ReflectionEngine:
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]: async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
""" """
解决冲突 Resolve detected conflicts
使用 LLM 对检测到的冲突进行反思和解决。 Uses LLM to perform reflection and resolution on detected conflicts.
Processes conflicts in parallel for efficiency while respecting concurrency limits.
Args: Args:
conflicts: 冲突列表 conflicts: List of conflicts to be resolved
statement_databasets: Statement database records for context
Returns: Returns:
List[Any]: 解决方案列表 List[Any]: List of resolution solutions with reflection results
""" """
if not conflicts: if not conflicts:
return [] return []
@@ -570,12 +693,12 @@ class ReflectionEngine:
baseline = self.config.baseline baseline = self.config.baseline
memory_verify = self.config.memory_verify memory_verify = self.config.memory_verify
# 并行处理每个冲突 # Process each conflict in parallel
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]: async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
"""解决单个冲突""" """Resolve a single conflict"""
async with self._semaphore: async with self._semaphore:
try: try:
# 渲染反思提示词 # Render reflection prompt
rendered_prompt = await self.render_reflexion_prompt_func( rendered_prompt = await self.render_reflexion_prompt_func(
[conflict], [conflict],
self.reflexion_schema, self.reflexion_schema,
@@ -587,7 +710,7 @@ class ReflectionEngine:
messages = [{"role": "user", "content": rendered_prompt}] messages = [{"role": "user", "content": rendered_prompt}]
# 调用 LLM 进行反思 # Call LLM for reflection
response = await self.llm_client.response_structured( response = await self.llm_client.response_structured(
messages, messages,
self.reflexion_schema self.reflexion_schema
@@ -596,7 +719,7 @@ class ReflectionEngine:
if not response: if not response:
return None return None
# 标准化返回格式 # Standardize return format
if isinstance(response, BaseModel): if isinstance(response, BaseModel):
return response.model_dump() return response.model_dump()
elif hasattr(response, 'dict'): elif hasattr(response, 'dict'):
@@ -610,11 +733,11 @@ class ReflectionEngine:
logging.warning(f"解决单个冲突失败: {e}") logging.warning(f"解决单个冲突失败: {e}")
return None return None
# 并发执行所有冲突解决任务 # Execute all conflict resolution tasks concurrently
tasks = [_resolve_one(conflict) for conflict in conflicts] tasks = [_resolve_one(conflict) for conflict in conflicts]
results = await asyncio.gather(*tasks, return_exceptions=False) results = await asyncio.gather(*tasks, return_exceptions=False)
# 过滤掉失败的结果 # Filter out failed results
solved = [r for r in results if r is not None] solved = [r for r in results if r is not None]
logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突") logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突")
@@ -626,15 +749,16 @@ class ReflectionEngine:
solved_data: List[Dict[str, Any]] solved_data: List[Dict[str, Any]]
) -> int: ) -> int:
""" """
应用反思结果(更新记忆库) Apply reflection results (update memory database)
将解决冲突后的记忆更新到 Neo4j 数据库中。 Updates the Neo4j database with resolved conflicts and reflection results.
Processes the solved data and applies changes to the memory storage system.
Args: Args:
solved_data: 解决方案列表 solved_data: List of resolved conflict solutions with reflection data
Returns: Returns:
int: 成功更新的记忆数量 int: Number of successfully updated memory entries
""" """
changes = extract_and_process_changes(solved_data) changes = extract_and_process_changes(solved_data)
success_count = await neo4j_data(changes) success_count = await neo4j_data(changes)
@@ -642,80 +766,86 @@ class ReflectionEngine:
# 基于时间的反思方法 # Time-based reflection methods
async def time_based_reflection( async def time_based_reflection(
self, self,
host_id: uuid.UUID, host_id: uuid.UUID,
time_period: Optional[str] = None time_period: Optional[str] = None
) -> ReflectionResult: ) -> ReflectionResult:
""" """
基于时间的反思 Time-based reflection
根据时间周期触发反思,检查在指定时间段内的记忆。 Triggers reflection based on time cycles, checking memories within
specified time periods. Uses the configured iteration period if
no specific time period is provided.
Args: Args:
host_id: 主机ID host_id: Host UUID identifier for scoping reflection
time_period: 时间周期(如"三小时"),如果不提供则使用配置中的值 time_period: Time period (e.g., "three hours"), uses config value if not provided
Returns: Returns:
ReflectionResult: 反思结果 ReflectionResult: Comprehensive reflection operation result
""" """
period = time_period or self.config.iteration_period period = time_period or self.config.iteration_period
logging.info(f"执行基于时间的反思,周期: {period}") logging.info(f"执行基于时间的反思,周期: {period}")
# 使用标准反思流程 # Use standard reflection workflow
return await self.execute_reflection(host_id) return await self.execute_reflection(host_id)
# 基于事实的反思方法 # Fact-based reflection methods
async def fact_based_reflection( async def fact_based_reflection(
self, self,
host_id: uuid.UUID host_id: uuid.UUID
) -> ReflectionResult: ) -> ReflectionResult:
""" """
基于事实的反思 Fact-based reflection
检测记忆中的事实冲突并解决。 Detects and resolves factual conflicts within memories. Analyzes
memory data for inconsistencies and contradictions that need resolution.
Args: Args:
host_id: 主机ID host_id: Host UUID identifier for scoping reflection
Returns: Returns:
ReflectionResult: 反思结果 ReflectionResult: Comprehensive reflection operation result
""" """
logging.info("执行基于事实的反思") logging.info("执行基于事实的反思")
# 使用标准反思流程 # Use standard reflection workflow
return await self.execute_reflection(host_id) return await self.execute_reflection(host_id)
# 综合反思方法 # Comprehensive reflection methods
async def comprehensive_reflection( async def comprehensive_reflection(
self, self,
host_id: uuid.UUID host_id: uuid.UUID
) -> ReflectionResult: ) -> ReflectionResult:
""" """
综合反思 Comprehensive reflection
整合基于时间和基于事实的反思策略。 Integrates time-based and fact-based reflection strategies based on
the configured baseline. Supports hybrid approaches that combine
multiple reflection methodologies.
Args: Args:
host_id: 主机ID host_id: Host UUID identifier for scoping reflection
Returns: Returns:
ReflectionResult: 反思结果 ReflectionResult: Comprehensive reflection operation result combining
multiple strategies if using hybrid baseline
""" """
logging.info("执行综合反思") logging.info("执行综合反思")
# 根据配置的基线选择反思策略 # Choose reflection strategy based on configured baseline
if self.config.baseline == ReflectionBaseline.TIME: if self.config.baseline == ReflectionBaseline.TIME:
return await self.time_based_reflection(host_id) return await self.time_based_reflection(host_id)
elif self.config.baseline == ReflectionBaseline.FACT: elif self.config.baseline == ReflectionBaseline.FACT:
return await self.fact_based_reflection(host_id) return await self.fact_based_reflection(host_id)
elif self.config.baseline == ReflectionBaseline.HYBRID: elif self.config.baseline == ReflectionBaseline.HYBRID:
# 混合策略:先执行基于时间的反思,再执行基于事实的反思 # Hybrid strategy: execute time-based reflection first, then fact-based reflection
time_result = await self.time_based_reflection(host_id) time_result = await self.time_based_reflection(host_id)
fact_result = await self.fact_based_reflection(host_id) fact_result = await self.fact_based_reflection(host_id)
# 合并结果 # Merge results
return ReflectionResult( return ReflectionResult(
success=time_result.success and fact_result.success, success=time_result.success and fact_result.success,
message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}", message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}",

View File

@@ -2,9 +2,17 @@ import json
def escape_lucene_query(query: str) -> str: def escape_lucene_query(query: str) -> str:
"""Escape Lucene special characters in a free-text query. """
Escape special characters in Lucene queries
This prevents ParseException when using Neo4j full-text procedures. Prevents ParseException when using Neo4j full-text search procedures.
Escapes all Lucene reserved special characters and operators.
Args:
query: Original query string
Returns:
str: Escaped query string safe for Lucene search
""" """
if query is None: if query is None:
return "" return ""
@@ -22,11 +30,21 @@ def escape_lucene_query(query: str) -> str:
return s return s
def extract_plain_query(query_input: str) -> str: def extract_plain_query(query_input: str) -> str:
"""Extract clean, plain-text query from various input forms. """
Extract clean plain-text query from various input forms
Handles the following cases:
- Strips surrounding quotes and whitespace - Strips surrounding quotes and whitespace
- If input looks like JSON, prefers the 'original' field - If input looks like JSON, prefers the 'original' field
- Fallbacks to the raw string when parsing fails - Falls back to raw string when parsing fails
- Handles dictionary-type input
- Best-effort unescape common escape characters
Args:
query_input: Query input in various forms (string, dict, etc.)
Returns:
str: Extracted plain-text query string
""" """
if query_input is None: if query_input is None:
return "" return ""

View File

@@ -4,7 +4,13 @@ from datetime import datetime
def validate_date_format(date_str: str) -> bool: def validate_date_format(date_str: str) -> bool:
""" """
Validate if the date string is in the format YYYY-MM-DD. Validate if date string conforms to YYYY-MM-DD format
Args:
date_str: Date string to validate
Returns:
bool: True if format is correct, False otherwise
""" """
pattern = r"^\d{4}-\d{1,2}-\d{1,2}$" pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
return bool(re.match(pattern, date_str)) return bool(re.match(pattern, date_str))
@@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str:
def preprocess_date_string(date_str: str) -> str: def preprocess_date_string(date_str: str) -> str:
"""预处理日期字符串,处理特殊格式""" """
预处理日期字符串,处理特殊格式
处理以下特殊格式:
- 年份后直接跟月份没有分隔符的格式(如 "20259/28"
- 无分隔符的纯数字格式(如 "20251028", "251028"
- 混合分隔符,统一为 "-"
Args:
date_str: 原始日期字符串
Returns:
str: 预处理后的日期字符串,格式为 "YYYY-MM-DD""YYYY-MM"
"""
# 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔) # 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔)
match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str) match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
@@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str:
def fallback_parse(date_str: str) -> str: def fallback_parse(date_str: str) -> str:
"""备选解析方案""" """
备选日期解析方案
当智能解析失败时,尝试使用预定义的日期格式进行解析。
支持多种常见的日期格式,包括:
- YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD
- YYYYMMDD, YYMMDD
- MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY
- DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY
- YYYY-MM, YYYY/MM, YYYY.MM
Args:
date_str: 待解析的日期字符串
Returns:
str: 标准化后的日期字符串YYYY-MM-DD格式解析失败时返回原字符串
"""
# 尝试常见的日期格式[citation:4][citation:5] # 尝试常见的日期格式[citation:4][citation:5]
formats_to_try = [ formats_to_try = [

View File

@@ -1,6 +1,7 @@
{# {#
对话级抽取与相关性判定模板(用于剪枝加速) 对话级抽取与相关性判定模板(用于剪枝加速)
输入pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language 输入pruning_scene, ontology_class_infos, dialog_text, language
- ontology_class_infos: List[{class_name: str, class_description: str}]
输出:严格 JSON不要包含任何多余文本字段 输出:严格 JSON不要包含任何多余文本字段
- is_related: bool是否与所选场景相关 - is_related: bool是否与所选场景相关
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等) - times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
@@ -9,64 +10,103 @@
- contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等 - contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等
- addresses: [string],地址/地点相关文本 - addresses: [string],地址/地点相关文本
- keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语) - keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语)
- preserve_keywords: [string],必须保留的情绪/兴趣/爱好/个人偏好相关词或短语片段
要求: 要求:
- 必须只输出上述 JSON且键名一致不得输出解释、前后缀不得包含注释。 - 必须只输出上述 JSON且键名一致不得输出解释、前后缀不得包含注释。
- times/ids/amounts/contacts/addresses/keywords 仅抽取原文片段或规范化后的简单字符串。 - times/ids/amounts/contacts/addresses/keywords/preserve_keywords 仅抽取原文片段或规范化后的简单字符串。
- 仅输出上述键;避免多余解释或字段。 - 仅输出上述键;避免多余解释或字段。
#} #}
{# ── 内置场景的固定说明 ── #} {# ── 确定场景说明 ── #}
{% set builtin_scene_instructions = { {% if ontology_class_infos and ontology_class_infos | length > 0 %}
'education': { {% if language == 'en' %}
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。', {% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is relevant if it involves any of the following entity types.' %}
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
},
'online_service': {
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
},
'outbound': {
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
}
} %}
{# ── 确定最终使用的场景说明 ── #}
{% if is_builtin_scene %}
{# 内置专门场景:使用固定说明 #}
{% set scene_key = pruning_scene %}
{% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %}
{% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %}
{% set custom_types_str = '' %}
{% else %}
{# 自定义场景:使用场景名称 + 本体类型列表构建说明 #}
{% if ontology_classes and ontology_classes | length > 0 %}
{% if language == 'en' %}
{% set custom_types_str = ontology_classes | join(', ') %}
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
{% else %}
{% set custom_types_str = ontology_classes | join('、') %}
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
{% endif %}
{% else %} {% else %}
{# 无本体类型时退化为通用说明 #} {% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关。' %}
{% if language == 'en' %} {% endif %}
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %} {% else %}
{% else %} {% if language == 'en' %}
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %} {% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
{% endif %} {% else %}
{% set custom_types_str = '' %} {% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
{% endif %} {% endif %}
{% endif %} {% endif %}
{% if language == "zh" %} {% if language == "zh" %}
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性 你是一个对话内容分析助手。请对下方对话全文进行一次性分析,完成两项任务
1. 判断对话是否与指定场景相关;
2. 从对话中抽取所有需要保留的重要信息片段。
场景说明:{{ instruction }} 场景说明:{{ instruction }}
{% if not is_builtin_scene and custom_types_str %}
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }}相关的内容即判定为相关is_related=true {% if ontology_class_infos and ontology_class_infos | length > 0 %}
【本场景实体类型定义】
以下实体类型定义了本场景中哪些内容是重要的。
凡是与以下任意类型相关的内容,都必须保留,并将关键词/短语提取到 keywords 字段:
{% for info in ontology_class_infos %}
- {{ info.class_name }}{{ info.class_description }}
{% endfor %}
重要提示只要对话中出现与上述任意实体类型相关的内容即判定为相关is_related=true
{% endif %} {% endif %}
---
【必须保留的内容(不可删除)】
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
- 时间信息:日期、时间点、时间段、有效期 → times 字段
- 编号信息学号、工号、订单号、申请号、账号、ID → ids 字段
- 金额信息:价格、费用、金额(含货币符号或单位,如"100元"、"¥200")→ amounts 字段(注意:考试分数、成绩分数不属于金额,不要放入此字段)
- 联系方式电话、手机号、邮箱、微信、QQ → contacts 字段
- 地址信息:地点、地址、位置 → addresses 字段
- 场景关键词:与**当前场景**强相关的专业术语、事件名称 → keywords 字段(注意:只放与当前场景直接相关的词,跨场景的内容不要放入此字段)
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
- **个人情感态度**:对人际关系、情感状态的明确表达(如"我跟室友闹矛盾了"、"我都快抑郁了")→ preserve_keywords 字段
- 注意:学业目标(如"我想考研")、成绩(如"87分")、学科偏好(如"喜欢数学")属于学业信息,不属于情绪/情感,不要放入 preserve_keywords 字段
【场景无关内容标记】
请从对话中识别出与当前场景({{ pruning_scene }}**既不相关、也无语义关联**的消息片段,将其原文(或关键片段)提取到 scene_unrelated_snippets 字段。
判断标准:
- 与场景实体类型完全无关
- 与场景话题没有因果/时间/情境上的关联(例如:不是"因为上课所以累"这种关联)
- 纯粹是另一个话题的内容(如在教育场景中讨论购物、娱乐等)
注意:有情绪/感受表达的消息即使话题不同,也可能有语义关联,请谨慎标记。
**重要scene_unrelated_snippets 必须认真填写,不能为空数组。**
如果对话中存在与场景无关的内容,必须将其原文片段提取出来。
示例(场景=在线教育):
- "我最近心情很差,跟室友闹矛盾了" → 与教育场景无关,加入 scene_unrelated_snippets
- "她总是很晚回来吵到我睡觉" → 与教育场景无关,加入 scene_unrelated_snippets
- "对,我都快抑郁了" → 与教育场景无关,加入 scene_unrelated_snippets
- "期末考试12月25日" → 与教育场景相关,不加入 scene_unrelated_snippets
- "我上次高数作业87分" → 与教育场景相关,不加入 scene_unrelated_snippets
- "我的目标是考研" → 与教育场景相关,不加入 scene_unrelated_snippets
示例(场景=情感陪伴):
- "我最近心情很差,跟室友闹矛盾了" → 与情感陪伴场景相关(情绪+关系),不加入 scene_unrelated_snippets
- "对,我都快抑郁了" → 与情感陪伴场景相关(情绪),不加入 scene_unrelated_snippets
- "期末考试12月25日3号教学楼201室" → 与情感陪伴场景无关(教育信息),加入 scene_unrelated_snippets
- "我上次高数作业87分这次能考好吗" → 与情感陪伴场景无关(学业信息),加入 scene_unrelated_snippets
- "我的目标是考研,想读应用数学" → 与情感陪伴场景无关(学业目标),加入 scene_unrelated_snippets
【可以删除的内容】
以下类型的内容属于低价值信息,可以在剪枝时删除:
- 纯寒暄问候:如"你好"、"在吗"、"拜拜"、"嗯"、"好的"、"哦"等无实质内容的短语
- 纯表情/符号:如"[微笑]"、"😊"、"哈哈"等
- 重复确认:如"对对对"、"是的是的"、"嗯嗯嗯"等无新增信息的重复
- 无意义填充:如"啊"、"呢"、"嘛"等语气词单独成句
**注意:即使消息很短,只要包含情绪、兴趣、爱好、个人观点等有价值信息,就必须保留,不得删除。**
例如:
- "我好开心呀" → 包含情绪开心必须保留preserve_keywords 中加入"开心"
- "好喜欢打羽毛球呀" → 包含兴趣爱好喜欢打羽毛球必须保留preserve_keywords 中加入"喜欢打羽毛球"
- "我好难过" → 包含情绪难过必须保留preserve_keywords 中加入"难过"
- "太好啦!看到你开心,我也跟着心情亮起来" → 包含情绪必须保留preserve_keywords 中加入"开心"
---
对话全文: 对话全文:
""" """
{{ dialog_text }} {{ dialog_text }}
@@ -80,15 +120,65 @@
"amounts": [<string>...], "amounts": [<string>...],
"contacts": [<string>...], "contacts": [<string>...],
"addresses": [<string>...], "addresses": [<string>...],
"keywords": [<string>...] "keywords": [<string>...],
"preserve_keywords": [<string>...],
"scene_unrelated_snippets": [<string>...]
} }
{% else %} {% else %}
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario: You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
1. Determine whether the dialogue is relevant to the specified scene;
2. Extract all important information fragments that must be preserved.
Scenario Description: {{ instruction }} Scenario Description: {{ instruction }}
{% if not is_builtin_scene and custom_types_str %}
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true). {% if ontology_class_infos and ontology_class_infos | length > 0 %}
[Scene Entity Type Definitions]
The following entity types define what content is important in this scene.
Content related to ANY of these types must be preserved and extracted into the keywords field:
{% for info in ontology_class_infos %}
- {{ info.class_name }}: {{ info.class_description }}
{% endfor %}
Important: If the dialogue contains content related to any of the entity types above, mark it as relevant (is_related=true).
{% endif %} {% endif %}
---
[MUST PRESERVE (cannot be deleted)]
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
- Time information: dates, time points, durations, expiry dates → times field
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
- Amount information: prices, fees, amounts (with currency symbols or units, e.g., "$100", "¥200") → amounts field (Note: exam scores and grades are NOT amounts, do not put them here)
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
- Address information: locations, addresses, places → addresses field
- Scene keywords: professional terms and event names strongly related to **the current scene** → keywords field (Note: only put terms directly related to the current scene; cross-scene content should not be placed here)
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, feeling upset, feeling wronged, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
- **Personal emotional attitudes**: clear expressions about interpersonal relationships or emotional states (e.g., "I had a fight with my roommate", "I'm almost depressed") → preserve_keywords field
- Note: Academic goals (e.g., "I want to pursue a master's degree"), grades (e.g., "87 points"), and subject preferences (e.g., "I like math") are academic information, NOT emotions/feelings — do not put them in preserve_keywords
[Scene-Unrelated Content Marking]
Please identify message snippets in the dialogue that are **neither relevant to nor semantically associated with** the current scene ({{ pruning_scene }}), and extract their original text (or key fragments) into the scene_unrelated_snippets field.
Criteria:
- Completely unrelated to the scene's entity types
- No causal/temporal/contextual association with the scene topic (e.g., "feeling tired because of class" IS associated)
- Purely belongs to a different topic (e.g., discussing shopping or entertainment in an education scene)
Note: Messages with emotional/feeling expressions may still have semantic association even if the topic differs — mark carefully.
[CAN BE DELETED]
The following types of content are low-value and can be removed during pruning:
- Pure greetings: e.g., "hello", "are you there", "bye", "ok", "yeah" — short phrases with no substantive content
- Pure emojis/symbols: e.g., "[smile]", "😊", "haha"
- Repetitive confirmations: e.g., "yes yes yes", "right right", "uh huh" — repetitions with no new information
- Meaningless fillers: standalone interjections like "ah", "well", "hmm"
**Note: Even if a message is short, if it contains emotions, interests, hobbies, or personal opinions, it MUST be preserved.**
Examples:
- "I'm so happy!" → contains emotion (happy), must preserve; add "happy" to preserve_keywords
- "I love playing badminton!" → contains interest (love playing badminton), must preserve; add "love playing badminton" to preserve_keywords
- "I feel so sad" → contains emotion (sad), must preserve; add "sad" to preserve_keywords
---
Full Dialogue: Full Dialogue:
""" """
{{ dialog_text }} {{ dialog_text }}
@@ -102,6 +192,8 @@ Output strict JSON only (fixed keys, order doesn't matter):
"amounts": [<string>...], "amounts": [<string>...],
"contacts": [<string>...], "contacts": [<string>...],
"addresses": [<string>...], "addresses": [<string>...],
"keywords": [<string>...] "keywords": [<string>...],
"preserve_keywords": [<string>...],
"scene_unrelated_snippets": [<string>...]
} }
{% endif %} {% endif %}

View File

@@ -2,15 +2,15 @@ import os
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from typing import List, Dict, Any from typing import List, Dict, Any
# Setup Jinja2 environment # Setup Jinja2 environment
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any, async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
baseline: str = "TIME", baseline: str = "TIME",
memory_verify: bool = False,quality_assessment:bool = False, memory_verify: bool = False, quality_assessment: bool = False,
statement_databasets: List[str] = [],language_type:str = "zh") -> str: statement_databasets=None, language_type: str = "zh") -> str:
""" """
Renders the evaluate prompt using the evaluate_optimized.jinja2 template. Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
Returns: Returns:
Rendered prompt content as string Rendered prompt content as string
""" """
if statement_databasets is None:
statement_databasets = []
template = prompt_env.get_template("evaluate.jinja2") template = prompt_env.get_template("evaluate.jinja2")
# Convert Pydantic model to JSON schema if needed # Convert Pydantic model to JSON schema if needed
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False, async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
statement_databasets: List[str] = [],language_type:str = "zh") -> str: statement_databasets=None, language_type: str = "zh") -> str:
""" """
Renders the reflexion prompt using the reflexion_optimized.jinja2 template. Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
Returns: Returns:
Rendered prompt content as a string. Rendered prompt content as a string.
""" """
if statement_databasets is None:
statement_databasets = []
template = prompt_env.get_template("reflexion.jinja2") template = prompt_env.get_template("reflexion.jinja2")
# Convert Pydantic model to JSON schema if needed # Convert Pydantic model to JSON schema if needed
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
json_schema = schema json_schema = schema
rendered_prompt = template.render(data=data, json_schema=json_schema, rendered_prompt = template.render(data=data, json_schema=json_schema,
baseline=baseline,memory_verify=memory_verify, baseline=baseline, memory_verify=memory_verify,
statement_databasets=statement_databasets,language_type=language_type) statement_databasets=statement_databasets, language_type=language_type)
return rendered_prompt return rendered_prompt

View File

@@ -1,23 +1,19 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import os import os
import time from typing import Any, Dict, Optional, TypeVar
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, TypeVar from langchain_aws import ChatBedrock
from langchain_community.chat_models import ChatTongyi
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLLM
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI, OpenAI
from pydantic import BaseModel, Field
import httpx
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.models.models_model import ModelProvider, ModelType from app.models.models_model import ModelProvider, ModelType
from langchain_community.document_compressors import JinaRerank
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel, BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableSerializable
from pydantic import BaseModel, Field
T = TypeVar("T") T = TypeVar("T")
@@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式 # dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni: if provider == ModelProvider.DASHSCOPE and config.is_omni:
from langchain_openai import ChatOpenAI
return ChatOpenAI return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
if type == ModelType.LLM: if type == ModelType.LLM:
from langchain_openai import OpenAI
return OpenAI return OpenAI
elif type == ModelType.CHAT: elif type == ModelType.CHAT:
from langchain_openai import ChatOpenAI
return ChatOpenAI return ChatOpenAI
elif provider == ModelProvider.DASHSCOPE: elif provider == ModelProvider.DASHSCOPE:
from langchain_community.chat_models import ChatTongyi
return ChatTongyi return ChatTongyi
elif provider == ModelProvider.OLLAMA: elif provider == ModelProvider.OLLAMA:
from langchain_ollama import OllamaLLM
return OllamaLLM return OllamaLLM
elif provider == ModelProvider.BEDROCK: elif provider == ModelProvider.BEDROCK:
from langchain_aws import ChatBedrock, ChatBedrockConverse
return ChatBedrock return ChatBedrock
else: else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)

View File

@@ -94,72 +94,16 @@ def knowledge_retrieval(
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id) db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1: if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
# Process shared knowledge base # Process shared knowledge base
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share: rs, chat_model, embedding_model = _retrieve_for_knowledge(
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, db=db,
knowledgeshare_id=db_knowledge.id) db_knowledge=db_knowledge,
if knowledgeshare: kb_config={**kb_config, "query": query}, # 或改为单独参数
db_knowledge = knowledge_repository.get_knowledge_by_id(db, file_names_filter=file_names_filter,
knowledge_id=knowledgeshare.source_kb_id) chat_model=chat_model,
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1): embedding_model=embedding_model,
continue kb_ids=kb_ids,
else: workspace_ids=workspace_ids,
continue )
if str(db_knowledge.id) not in kb_ids:
kb_ids.append(str(db_knowledge.id))
if str(db_knowledge.workspace_id) not in workspace_ids:
workspace_ids.append(str(db_knowledge.workspace_id))
if not chat_model:
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
if not embedding_model:
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# Retrieve according to the configured retrieval type
match kb_config["retrieve_type"]:
case "participle":
rs = vector_service.search_by_full_text(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["similarity_threshold"],
file_names_filter=file_names_filter
)
case "semantic":
rs = vector_service.search_by_vector(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["vector_similarity_weight"],
file_names_filter=file_names_filter
)
case _: # hybrid
rs1 = vector_service.search_by_vector(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["vector_similarity_weight"],
file_names_filter=file_names_filter
)
rs2 = vector_service.search_by_full_text(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["similarity_threshold"],
file_names_filter=file_names_filter
)
# Deduplication of merge results
seen_ids = set()
unique_rs = []
for doc in rs1 + rs2:
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = unique_rs
all_results.extend(rs) all_results.extend(rs)
except Exception as e: except Exception as e:
@@ -199,6 +143,115 @@ def knowledge_retrieval(
finally: finally:
db.close() db.close()
def _retrieve_for_knowledge(
        db: Session,
        db_knowledge,
        kb_config: Dict[str, Any],
        file_names_filter: list[str],
        chat_model: Base | None,
        embedding_model: OpenAIEmbed | None,
        kb_ids: list[str],
        workspace_ids: list[str],
) -> tuple[list[DocumentChunk], Base | None, OpenAIEmbed | None]:
    """Run retrieval against one knowledge base entry.

    Handles three shapes of entry:

    * shared entry - resolved to the source knowledge base it points at; an
      unresolved or unusable source yields an empty result;
    * folder - every usable child is processed recursively and the child
      results are concatenated;
    * plain knowledge base - searched once according to
      ``kb_config["retrieve_type"]``.

    Args:
        db: Database session used for the repository lookups.
        db_knowledge: Knowledge base record to search.
        kb_config: Retrieval settings. Read keys: ``retrieve_type``, ``query``,
            ``top_k``, ``similarity_threshold`` and ``vector_similarity_weight``.
        file_names_filter: Passed through to the vector service searches.
        chat_model: Chat model built so far, or a falsy value to have one built
            from the first plain knowledge base reached.
        embedding_model: Same contract as ``chat_model`` for the embedding model.
        kb_ids: Mutated in place; receives the id of every plain knowledge base
            searched (no duplicates).
        workspace_ids: Mutated in place; receives the workspace id of every
            plain knowledge base searched (no duplicates).

    Returns:
        ``(results, chat_model, embedding_model)`` where the two models may
        have been created during this call.
    """
    # Shared entry: continue with the source knowledge base it references.
    if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
        share = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=db_knowledge.id)
        if not share:
            return [], chat_model, embedding_model
        db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=share.source_kb_id)
        source_usable = db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1
        if not source_usable:
            return [], chat_model, embedding_model

    # Folder: descend into each usable child (a child folder recurses again).
    if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
        collected: list[DocumentChunk] = []
        for child in knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id):
            child_usable = child and child.chunk_num > 0 and child.status == 1
            if child_usable:
                child_hits, chat_model, embedding_model = _retrieve_for_knowledge(
                    db=db,
                    db_knowledge=child,
                    kb_config=kb_config,
                    file_names_filter=file_names_filter,
                    chat_model=chat_model,
                    embedding_model=embedding_model,
                    kb_ids=kb_ids,
                    workspace_ids=workspace_ids,
                )
                collected.extend(child_hits)
        return collected, chat_model, embedding_model

    # Plain knowledge base: record its ids for the caller, then search once.
    kb_id = str(db_knowledge.id)
    if kb_id not in kb_ids:
        kb_ids.append(kb_id)
    workspace_id = str(db_knowledge.workspace_id)
    if workspace_id not in workspace_ids:
        workspace_ids.append(workspace_id)

    # Models are created lazily from the first knowledge base that is searched.
    if not chat_model:
        llm_key = db_knowledge.llm.api_keys[0]
        chat_model = Base(
            key=llm_key.api_key,
            model_name=llm_key.model_name,
            base_url=llm_key.api_base,
        )
    if not embedding_model:
        embed_key = db_knowledge.embedding.api_keys[0]
        embedding_model = OpenAIEmbed(
            key=embed_key.api_key,
            model_name=embed_key.model_name,
            base_url=embed_key.api_base,
        )

    vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)

    def _full_text_hits():
        # Full-text search thresholded by "similarity_threshold".
        return vector_service.search_by_full_text(
            query=kb_config["query"],
            top_k=kb_config["top_k"],
            score_threshold=kb_config["similarity_threshold"],
            file_names_filter=file_names_filter,
        )

    def _vector_hits():
        # Vector search thresholded by "vector_similarity_weight".
        return vector_service.search_by_vector(
            query=kb_config["query"],
            top_k=kb_config["top_k"],
            score_threshold=kb_config["vector_similarity_weight"],
            file_names_filter=file_names_filter,
        )

    mode = kb_config["retrieve_type"]
    if mode == "participle":
        hits = _full_text_hits()
    elif mode == "semantic":
        hits = _vector_hits()
    else:
        # Hybrid: vector hits first, then full-text hits; keep the first
        # occurrence of each doc_id.
        vector_part = _vector_hits()
        text_part = _full_text_hits()
        first_by_id = {}
        for doc in vector_part + text_part:
            first_by_id.setdefault(doc.metadata["doc_id"], doc)
        hits = list(first_by_id.values())

    return list(hits), chat_model, embedding_model
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]: def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
""" """

View File

@@ -4,11 +4,12 @@ RAG chunk analysis utilities.
from .chunk_summary import generate_chunk_summary from .chunk_summary import generate_chunk_summary
from .chunk_tags import extract_chunk_tags, extract_chunk_persona from .chunk_tags import extract_chunk_tags, extract_chunk_persona
from .chunk_insight import generate_chunk_insight from .chunk_insight import generate_chunk_insight, generate_chunk_insight_sections
__all__ = [ __all__ = [
"generate_chunk_summary", "generate_chunk_summary",
"extract_chunk_tags", "extract_chunk_tags",
"extract_chunk_persona", "extract_chunk_persona",
"generate_chunk_insight", "generate_chunk_insight",
"generate_chunk_insight_sections",
] ]

View File

@@ -1,213 +1,207 @@
""" """
Generate insights from RAG chunks. Generate memory insight report for RAG chunks using memory_insight.jinja2 prompt template.
This module provides functionality to analyze chunk content and generate insights using LLM. The memory_insight.jinja2 template produces a four-section report:
【总体概述】 → memory_insight
【行为模式】 → behavior_pattern
【关键发现】 → key_findings
【成长轨迹】 → growth_trajectory
generate_chunk_insight() returns the full raw text (stored in end_user.memory_insight).
generate_chunk_insight_sections() returns a dict with all four fields for richer storage.
""" """
import asyncio import asyncio
import os
import re
from collections import Counter from collections import Counter
from typing import Any, Dict, List from typing import Dict, List, Optional
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from pydantic import BaseModel, Field
business_logger = get_business_logger() business_logger = get_business_logger()
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
def _get_llm_client():
"""Get LLM client using db context.""" # ── LLM client helper ────────────────────────────────────────────────────────
def _get_llm_client(end_user_id: Optional[str] = None):
"""Get LLM client, preferring user-connected config with fallback to default."""
with get_db_context() as db: with get_db_context() as db:
try:
if end_user_id:
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.memory_config_service import MemoryConfigService
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id or workspace_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
except Exception as e:
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
factory = MemoryClientFactory(db) factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM return factory.get_llm_client(DEFAULT_LLM_ID)
class ChunkInsight(BaseModel): # ── Domain analysis helpers (kept for building prompt inputs) ─────────────────
"""Pydantic model for chunk insight."""
insight: str = Field(..., description="对chunk内容的深度洞察分析")
async def _classify_domain(chunk: str, llm_client) -> str:
"""Classify a single chunk into a domain category."""
from pydantic import BaseModel, Field
class DomainClassification(BaseModel): class _Domain(BaseModel):
"""Pydantic model for domain classification.""" domain: str = Field(..., description="领域分类")
domain: str = Field(
...,
description="内容所属的领域分类",
examples=["技术", "商业", "教育", "生活", "娱乐", "健康", "其他"]
)
async def classify_chunk_domain(chunk: str) -> str:
"""
Classify a chunk into a specific domain.
Args:
chunk: Chunk content string
Returns:
Domain name
"""
try: try:
llm_client = _get_llm_client() prompt = (
"请将以下文本归类到最合适的领域(技术/商业/教育/生活/娱乐/健康/其他)。\n\n"
prompt = f"""请将以下文本内容归类到最合适的领域中。 f"文本: {chunk[:500]}\n\n直接返回领域名称。"
可选领域及其关键词:
- 技术:编程、软件、硬件、算法、数据、网络、系统、开发、工程等
- 商业:市场、销售、管理、财务、投资、创业、营销、战略等
- 教育:学习、课程、培训、教学、知识、技能、考试、研究等
- 生活:日常、家庭、饮食、购物、旅行、休闲、娱乐等
- 娱乐:游戏、电影、音乐、体育、艺术、文化等
- 健康:医疗、养生、运动、心理、保健、疾病等
- 其他:无法归入以上类别的内容
文本内容: {chunk[:500]}...
请直接返回最合适的领域名称。"""
messages = [
{"role": "system", "content": "你是一个专业的文本分类助手。请仔细分析文本内容,选择最合适的领域分类。"},
{"role": "user", "content": prompt}
]
classification = await llm_client.response_structured(
messages=messages,
response_model=DomainClassification
) )
result = await llm_client.response_structured(
return classification.domain if classification else "其他" messages=[{"role": "user", "content": prompt}],
response_model=_Domain,
except Exception as e: )
business_logger.error(f"分类chunk领域失败: {str(e)}") return result.domain if result else "其他"
except Exception:
return "其他" return "其他"
async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]: async def _build_insight_inputs(
chunks: List[str],
max_chunks: int,
end_user_id: Optional[str],
) -> Dict[str, Optional[str]]:
""" """
Analyze the domain distribution of chunks. Derive domain_distribution, active_periods, social_connections strings
to feed into the memory_insight.jinja2 template.
Args:
chunks: List of chunk content strings
max_chunks: Maximum number of chunks to analyze
Returns:
Dictionary of domain -> percentage
""" """
if not chunks: llm_client = _get_llm_client(end_user_id)
return {} chunks_sample = chunks[:max_chunks]
try: # Domain distribution
# 限制分析的chunk数量 domain_counts: Counter = Counter()
chunks_to_analyze = chunks[:max_chunks] for chunk in chunks_sample:
domain = await _classify_domain(chunk, llm_client)
domain_counts[domain] += 1
# 为每个chunk分类 total = sum(domain_counts.values()) or 1
domain_counts = Counter() domain_distribution = ", ".join(
for chunk in chunks_to_analyze: f"{d}({c / total:.0%})" for d, c in domain_counts.most_common(3)
domain = await classify_chunk_domain(chunk) )
domain_counts[domain] += 1
# 计算百分比 return {
total = sum(domain_counts.values()) "domain_distribution": domain_distribution,
domain_distribution = { "active_periods": None, # RAG模式暂无时间维度数据
domain: count / total "social_connections": None, # RAG模式暂无社交关联数据
for domain, count in domain_counts.items() }
}
# 按百分比降序排序
return dict(sorted(domain_distribution.items(), key=lambda x: x[1], reverse=True))
except Exception as e:
business_logger.error(f"分析领域分布失败: {str(e)}")
return {}
async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str: # ── Section parser ────────────────────────────────────────────────────────────
_ZH_SECTIONS = {
"memory_insight": r"【总体概述】(.*?)(?=【|$)",
"behavior_pattern": r"【行为模式】(.*?)(?=【|$)",
"key_findings": r"【关键发现】(.*?)(?=【|$)",
"growth_trajectory": r"【成长轨迹】(.*?)(?=【|$)",
}
_EN_SECTIONS = {
"memory_insight": r"【Overview】(.*?)(?=【|$)",
"behavior_pattern": r"【Behavior Pattern】(.*?)(?=【|$)",
"key_findings": r"【Key Findings】(.*?)(?=【|$)",
"growth_trajectory": r"【Growth Trajectory】(.*?)(?=【|$)",
}
def _parse_sections(text: str, language: str = "zh") -> Dict[str, str]:
"""Extract the four sections from the LLM output."""
patterns = _ZH_SECTIONS if language == "zh" else _EN_SECTIONS
result = {}
for key, pattern in patterns.items():
match = re.search(pattern, text, re.DOTALL)
result[key] = match.group(1).strip() if match else ""
return result
# ── Public API ────────────────────────────────────────────────────────────────
async def generate_chunk_insight(
chunks: List[str],
max_chunks: int = 15,
end_user_id: Optional[str] = None,
language: str = "zh",
) -> str:
""" """
Generate insights from the given chunks. Generate a memory insight report from RAG chunks.
Args: Returns the full raw report text (suitable for end_user.memory_insight).
chunks: List of chunk content strings Use generate_chunk_insight_sections() when you need all four dimensions.
max_chunks: Maximum number of chunks to analyze """
sections = await generate_chunk_insight_sections(
chunks=chunks,
max_chunks=max_chunks,
end_user_id=end_user_id,
language=language,
)
return sections.get("memory_insight") or sections.get("_raw", "洞察生成失败")
Returns:
A comprehensive insight report async def generate_chunk_insight_sections(
chunks: List[str],
max_chunks: int = 15,
end_user_id: Optional[str] = None,
language: str = "zh",
) -> Dict[str, str]:
"""
Generate a four-section memory insight report from RAG chunks.
Returns a dict with keys:
memory_insight, behavior_pattern, key_findings, growth_trajectory
(plus '_raw' containing the full LLM output for debugging)
""" """
if not chunks: if not chunks:
business_logger.warning("没有提供chunk内容用于生成洞察") business_logger.warning("没有提供chunk内容用于生成洞察")
return "暂无足够数据生成洞察报告" empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
empty["_raw"] = "暂无足够数据生成洞察报告"
return empty
try: try:
# 1. 分析领域分布 from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks)
# 2. 统计基本信息 # Build template inputs from chunk analysis
total_chunks = len(chunks) inputs = await _build_insight_inputs(chunks, max_chunks, end_user_id)
avg_length = sum(len(chunk) for chunk in chunks) / total_chunks if total_chunks > 0 else 0
# 3. 构建洞察prompt rendered_prompt = await render_memory_insight_prompt(
prompt_parts = [] domain_distribution=inputs["domain_distribution"],
active_periods=inputs["active_periods"],
social_connections=inputs["social_connections"],
language=language,
)
if domain_dist: messages = [{"role": "user", "content": rendered_prompt}]
top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]]) llm_client = _get_llm_client(end_user_id)
prompt_parts.append(f"- 内容领域分布: {top_domains}")
prompt_parts.append(f"- 内容规模: 共{total_chunks}个知识片段,平均长度{avg_length:.0f}")
# 添加部分chunk内容作为参考
sample_chunks = chunks[:5]
sample_content = "\n".join([f"示例{i+1}: {chunk[:200]}..." for i, chunk in enumerate(sample_chunks)])
prompt_parts.append(f"\n内容示例:\n{sample_content}")
system_prompt = """你是一位专业的知识内容分析师。你的任务是根据提供的信息,生成一段简洁、有洞察力的分析报告。
重要规则:
1. 报告需要将所有要点流畅地串联成一个段落
2. 语言风格要专业、客观,同时易于理解
3. 不要添加任何额外的解释或标题,直接输出报告内容
4. 基于提供的数据和示例内容进行分析,不要编造信息
5. 重点关注内容的主题、特点和价值
6. 报告长度控制在150-200字
例如,如果输入是:
- 内容领域分布: 技术(60%), 商业(25%), 教育(15%)
- 内容规模: 共50个知识片段平均长度320字
内容示例: [示例内容...]
你的输出应该类似:
"该知识库主要聚焦于技术领域(60%),涵盖商业(25%)和教育(15%)相关内容。共包含50个知识片段平均每个片段约320字内容详实。从示例来看内容涉及[具体主题],体现了[特点],对[目标用户]具有较高的参考价值。"
"""
user_prompt = "\n".join(prompt_parts)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 调用LLM生成洞察
llm_client = _get_llm_client()
response = await llm_client.chat(messages=messages) response = await llm_client.chat(messages=messages)
raw_text = response.content.strip() if response and response.content else ""
insight = response.content.strip() sections = _parse_sections(raw_text, language=language)
business_logger.info(f"成功生成chunk洞察分析了 {min(len(chunks), max_chunks)} 个片段") sections["_raw"] = raw_text
return insight business_logger.info(
f"成功生成chunk洞察四维度分析了 {min(len(chunks), max_chunks)} 个片段"
)
return sections
except Exception as e: except Exception as e:
business_logger.error(f"生成chunk洞察失败: {str(e)}") business_logger.error(f"生成chunk洞察失败: {str(e)}")
return "洞察生成失败" empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
empty["_raw"] = "洞察生成失败"
return empty
if __name__ == "__main__":
# 测试代码
test_chunks = [
"Python是一种高级编程语言以其简洁的语法和强大的功能而闻名。它广泛应用于Web开发、数据分析、人工智能等领域。",
"机器学习算法可以从数据中自动学习模式,无需显式编程。常见的算法包括决策树、随机森林、神经网络等。",
"深度学习是机器学习的一个分支,使用多层神经网络来学习数据的层次化表示。它在图像识别、语音识别等任务中表现出色。",
"自然语言处理技术使计算机能够理解和生成人类语言。应用包括机器翻译、情感分析、文本摘要等。",
"数据科学结合了统计学、计算机科学和领域知识,用于从数据中提取有价值的洞察。"
]
print("开始生成chunk洞察...")
insight = asyncio.run(generate_chunk_insight(test_chunks))
print(f"\n生成的洞察:\n{insight}")

View File

@@ -1,11 +1,10 @@
""" """
Generate summary for RAG chunks. Generate summary for RAG chunks using memory_summary.jinja2 prompt template.
This module provides functionality to summarize chunk content using LLM.
""" """
import asyncio import asyncio
from typing import Any, Dict, List import os
from typing import List, Optional
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -14,94 +13,135 @@ from pydantic import BaseModel, Field
business_logger = get_business_logger() business_logger = get_business_logger()
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
def _get_llm_client():
"""Get LLM client using db context."""
with get_db_context() as db:
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
class ChunkSummary(BaseModel): # ── Schema ──────────────────────────────────────────────────────────────────
"""Pydantic model for chunk summary."""
summary: str = Field(..., description="简洁的chunk内容摘要") class MemorySummaryStatement(BaseModel):
"""Single labelled statement extracted by memory_summary.jinja2."""
statement: str = Field(..., description="提取的陈述内容")
label: Optional[str] = Field(None, description="陈述标签")
async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str: class MemorySummaryResponse(BaseModel):
""" """
Generate a summary for the given chunks. Structured output expected from memory_summary.jinja2.
The template asks for a JSON array of labelled statements;
we wrap it in an object so response_structured can parse it.
"""
statements: List[MemorySummaryStatement] = Field(
default_factory=list,
description="从chunk中提取的陈述列表"
)
summary: Optional[str] = Field(None, description="整体摘要文本(可选)")
# ── LLM client helper ────────────────────────────────────────────────────────
def _get_llm_client(end_user_id: Optional[str] = None):
    """Return an LLM client, preferring the end user's connected memory config.

    When ``end_user_id`` is truthy, the user's connected config is looked up and,
    if it carries a ``memory_config_id`` or ``workspace_id``, the matching memory
    config is loaded and its ``llm_model_id`` selects the client.

    The client for ``DEFAULT_LLM_ID`` is returned when no ID is given, when the
    connected config has neither identifier, or when any step of the lookup
    raises (the failure is logged as a warning, not propagated).
    """
    with get_db_context() as db:
        if end_user_id:
            try:
                # Imported at call time, exactly as the original does
                # (presumably to avoid a circular import — TODO confirm).
                from app.services.memory_agent_service import get_end_user_connected_config
                from app.services.memory_config_service import MemoryConfigService

                connected = get_end_user_connected_config(end_user_id, db)
                config_id = connected.get("memory_config_id")
                workspace_id = connected.get("workspace_id")
                if config_id or workspace_id:
                    memory_config = MemoryConfigService(db).load_memory_config(
                        config_id=config_id,
                        workspace_id=workspace_id
                    )
                    return MemoryClientFactory(db).get_llm_client(memory_config.llm_model_id)
            except Exception as e:
                business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")

        # Fallback path: no user, no usable config, or the lookup failed.
        return MemoryClientFactory(db).get_llm_client(DEFAULT_LLM_ID)
# ── Core function ─────────────────────────────────────────────────────────────
async def generate_chunk_summary(
chunks: List[str],
max_chunks: int = 10,
end_user_id: Optional[str] = None,
language: str = "zh",
) -> str:
"""
Generate a user summary from RAG chunks using the memory_summary.jinja2 template.
The template extracts labelled statements from the chunks; we then join them
into a coherent summary string that can be stored in end_user.user_summary.
Args: Args:
chunks: List of chunk content strings chunks: List of chunk content strings
max_chunks: Maximum number of chunks to process (default: 10) max_chunks: Maximum number of chunks to process
end_user_id: Optional end-user ID for model selection
language: Output language ("zh" or "en")
Returns: Returns:
A concise summary of the chunks Summary string (joined statements or fallback text)
""" """
if not chunks: if not chunks:
business_logger.warning("没有提供chunk内容用于生成摘要") business_logger.warning("没有提供chunk内容用于生成摘要")
return "暂无内容" return "暂无内容"
try: try:
# 限制处理的chunk数量避免token过多 from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
chunks_to_process = chunks[:max_chunks] chunks_to_process = chunks[:max_chunks]
chunk_texts = "\n\n".join(
# 合并chunk内容 [f"片段{i + 1}: {chunk}" for i, chunk in enumerate(chunks_to_process)]
combined_content = "\n\n".join([f"片段{i+1}: {chunk}" for i, chunk in enumerate(chunks_to_process)])
# 构建prompt
system_prompt = (
"你是一位专业的文本摘要助手。请基于提供的文本片段,生成简洁的摘要。要求:\n"
"- 摘要长度控制在100-150字\n"
"- 提取核心信息和关键要点;\n"
"- 使用客观、清晰的语言;\n"
"- 避免冗余和重复;\n"
"- 如果内容涉及多个主题,按重要性排序呈现。"
) )
user_prompt = f"请为以下文本片段生成摘要:\n\n{combined_content}" json_schema = MemorySummaryResponse.model_json_schema()
messages = [ rendered_prompt = await render_memory_summary_prompt(
{"role": "system", "content": system_prompt}, chunk_texts=chunk_texts,
{"role": "user", "content": user_prompt}, json_schema=json_schema,
] max_words=200,
language=language,
)
# 调用LLM生成摘要 messages = [{"role": "user", "content": rendered_prompt}]
llm_client = _get_llm_client()
response = await llm_client.chat(messages=messages)
summary = response.content.strip() llm_client = _get_llm_client(end_user_id)
business_logger.info(f"成功生成chunk摘要处理了 {len(chunks_to_process)} 个片段")
# Try structured output; fall back to plain chat only for LLMClientException
# (indicates the model/provider doesn't support structured output).
# All other exceptions are re-raised so config/schema errors stay visible.
try:
response: MemorySummaryResponse = await llm_client.response_structured(
messages=messages,
response_model=MemorySummaryResponse,
)
if response.summary:
summary = response.summary.strip()
elif response.statements:
summary = "".join(s.statement for s in response.statements)
else:
summary = "暂无内容"
except Exception as e:
from app.core.memory.llm_tools.llm_client import LLMClientException
if isinstance(e, LLMClientException):
business_logger.warning(
f"结构化输出不可用,降级为普通对话: end_user_id={end_user_id}, reason={e}"
)
raw = await llm_client.chat(messages=messages)
summary = raw.content.strip() if raw and raw.content else "暂无内容"
else:
business_logger.error(f"生成摘要时发生非预期异常: {e}")
raise
business_logger.info(
f"成功生成chunk摘要处理了 {len(chunks_to_process)} 个片段"
)
return summary return summary
except Exception as e: except Exception as e:
business_logger.error(f"生成chunk摘要失败: {str(e)}") business_logger.error(f"生成chunk摘要失败: {str(e)}")
return "摘要生成失败" return "摘要生成失败"
async def generate_chunk_summary_batch(chunks_list: List[List[str]]) -> List[str]:
"""
Generate summaries for multiple chunk lists in batch.
Args:
chunks_list: List of chunk lists
Returns:
List of summaries
"""
tasks = [generate_chunk_summary(chunks) for chunks in chunks_list]
return await asyncio.gather(*tasks)
if __name__ == "__main__":
# 测试代码
test_chunks = [
"这是第一段测试内容,讲述了关于机器学习的基础知识。",
"第二段内容介绍了深度学习的应用场景和发展历史。",
"第三段讨论了自然语言处理技术的最新进展。"
]
print("开始生成chunk摘要...")
summary = asyncio.run(generate_chunk_summary(test_chunks))
print(f"\n生成的摘要:\n{summary}")

View File

@@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content
""" """
import asyncio import asyncio
import os
from collections import Counter from collections import Counter
from typing import List, Tuple from typing import List, Optional, Tuple
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
business_logger = get_business_logger() business_logger = get_business_logger()
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
def _get_llm_client():
"""Get LLM client using db context.""" def _get_llm_client(end_user_id: Optional[str] = None):
"""Get LLM client, preferring user-connected config with fallback to default."""
with get_db_context() as db: with get_db_context() as db:
try:
if end_user_id:
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.memory_config_service import MemoryConfigService
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id or workspace_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
except Exception as e:
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
factory = MemoryClientFactory(db) factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM return factory.get_llm_client(DEFAULT_LLM_ID)
class ExtractedTags(BaseModel): class ExtractedTags(BaseModel):
@@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel):
personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师''旅行爱好者'") personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师''旅行爱好者'")
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]: async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]:
""" """
Extract meaningful tags from the given chunks. Extract meaningful tags from the given chunks.
@@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
"标签应该是名词或名词短语,能够准确概括文本的核心内容。" "标签应该是名词或名词短语,能够准确概括文本的核心内容。"
) )
llm_client = _get_llm_client() llm_client = _get_llm_client(end_user_id)
# 为每个chunk单独提取标签然后统计频率 # 为每个chunk单独提取标签然后统计频率
all_tags = [] all_tags = []
@@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1
return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks)) return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]: async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]:
""" """
Extract persona (人物形象) from the given chunks. Extract persona (人物形象) from the given chunks.
@@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
] ]
# 调用LLM提取人物形象 # 调用LLM提取人物形象
llm_client = _get_llm_client() llm_client = _get_llm_client(end_user_id)
structured_response = await llm_client.response_structured( structured_response = await llm_client.response_structured(
messages=messages, messages=messages,
response_model=ExtractedPersona response_model=ExtractedPersona

View File

@@ -7,7 +7,7 @@ file operations across different storage backends.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import AsyncIterator, Optional
class StorageBackend(ABC): class StorageBackend(ABC):
@@ -42,6 +42,26 @@ class StorageBackend(ABC):
""" """
pass pass
@abstractmethod
async def upload_stream(
    self,
    file_key: str,
    stream: AsyncIterator[bytes],
    content_type: Optional[str] = None,
) -> int:
    """
    Upload a file from an async byte stream.

    Abstract: every concrete storage backend must provide its own
    implementation. The declared contract is that the chunks yielded by
    ``stream`` are stored under ``file_key`` and the number of bytes
    written is returned.

    Args:
        file_key: Unique identifier for the file.
        stream: Async iterator yielding bytes chunks.
        content_type: Optional MIME type of the file.

    Returns:
        Total bytes written.
    """
    pass
@abstractmethod @abstractmethod
async def download(self, file_key: str) -> bytes: async def download(self, file_key: str) -> bytes:
""" """

View File

@@ -85,6 +85,7 @@ class StorageFactory:
access_key_id=settings.S3_ACCESS_KEY_ID, access_key_id=settings.S3_ACCESS_KEY_ID,
secret_access_key=settings.S3_SECRET_ACCESS_KEY, secret_access_key=settings.S3_SECRET_ACCESS_KEY,
bucket_name=settings.S3_BUCKET_NAME, bucket_name=settings.S3_BUCKET_NAME,
endpoint_url=settings.S3_ENDPOINT_URL,
) )
else: else:

View File

@@ -11,6 +11,7 @@ from typing import Optional
import aiofiles import aiofiles
import aiofiles.os import aiofiles.os
from typing import AsyncIterator
from app.core.storage.base import StorageBackend from app.core.storage.base import StorageBackend
from app.core.storage_exceptions import ( from app.core.storage_exceptions import (
@@ -179,6 +180,36 @@ class LocalStorage(StorageBackend):
full_path = self._get_full_path(file_key) full_path = self._get_full_path(file_key)
return full_path.exists() return full_path.exists()
async def upload_stream(
    self,
    file_key: str,
    stream: AsyncIterator[bytes],
    content_type: Optional[str] = None,
) -> int:
    """
    Upload a file from an async byte stream to the local file system.

    Chunks are written to disk as they arrive, so the whole payload is never
    held in memory. If the upload fails after the target file has been opened,
    the partially written file is removed so that a truncated file is not left
    behind under ``file_key``.

    Args:
        file_key: Unique identifier for the file; resolved to a path via
            ``self._get_full_path``.
        stream: Async iterator yielding bytes chunks.
        content_type: Accepted for interface compatibility; not used by this
            backend.

    Returns:
        Total bytes written.

    Raises:
        StorageUploadError: If creating the directory, opening the file,
            reading the stream or writing a chunk fails.
    """
    full_path = self._get_full_path(file_key)
    # Only clean up when *this* call created/truncated the file; if opening
    # failed, whatever was at the path before must be left untouched.
    opened = False
    try:
        full_path.parent.mkdir(parents=True, exist_ok=True)
        total = 0
        async with aiofiles.open(full_path, "wb") as f:
            opened = True
            async for chunk in stream:
                await f.write(chunk)
                total += len(chunk)
        logger.info(f"File stream uploaded successfully: {file_key}")
        return total
    except Exception as e:
        logger.error(f"Failed to stream upload file {file_key}: {e}")
        if opened:
            # Best-effort removal of the partial file; a cleanup failure must
            # not hide the original upload error.
            try:
                await aiofiles.os.remove(full_path)
            except OSError as cleanup_error:
                logger.warning(
                    f"Failed to remove partial file {file_key}: {cleanup_error}"
                )
        raise StorageUploadError(
            message=f"Failed to stream upload file: {e}",
            file_key=file_key,
            cause=e,
        ) from e
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(self, file_key: str, expires: int = 3600) -> str:
""" """
Get an access URL for the file. Get an access URL for the file.

View File

@@ -5,8 +5,9 @@ This module provides a storage backend that stores files on Aliyun Object
Storage Service (OSS) using the oss2 SDK. Storage Service (OSS) using the oss2 SDK.
""" """
import io
import logging import logging
from typing import Optional from typing import AsyncIterator, Optional
import oss2 import oss2
from oss2.exceptions import NoSuchKey, OssError from oss2.exceptions import NoSuchKey, OssError
@@ -125,10 +126,39 @@ class OSSStorage(StorageBackend):
cause=e, cause=e,
) )
async def upload_stream(
    self,
    file_key: str,
    stream: AsyncIterator[bytes],
    content_type: Optional[str] = None,
) -> int:
    """Upload from async stream to OSS. Returns total bytes written.

    The stream is drained completely into an in-memory buffer and then sent
    with a single ``put_object`` call, so the whole payload is held in memory.

    Args:
        file_key: Object key to write in the bucket.
        stream: Async iterator yielding bytes chunks.
        content_type: Optional MIME type, sent as the ``Content-Type`` header.

    Returns:
        Number of bytes uploaded.

    Raises:
        StorageUploadError: On an ``OssError`` or any other failure while
            reading the stream or uploading.
    """
    staging = io.BytesIO()
    try:
        async for piece in stream:
            staging.write(piece)
        payload = staging.getvalue()

        headers = None
        if content_type:
            headers = {"Content-Type": content_type}

        self.bucket.put_object(file_key, payload, headers=headers)
        logger.info(f"File stream uploaded to OSS successfully: {file_key}")
        return len(payload)
    except OssError as e:
        # OSS-specific failures expose a dedicated ``message`` attribute.
        logger.error(f"OSS error stream uploading file {file_key}: {e}")
        raise StorageUploadError(
            message=f"Failed to stream upload file to OSS: {e.message}",
            file_key=file_key,
            cause=e,
        )
    except Exception as e:
        logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
        raise StorageUploadError(
            message=f"Failed to stream upload file to OSS: {e}",
            file_key=file_key,
            cause=e,
        )
async def download(self, file_key: str) -> bytes: async def download(self, file_key: str) -> bytes:
""" """
Download a file from OSS.
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.

View File

@@ -5,8 +5,9 @@ This module provides a storage backend that stores files on AWS S3
using the boto3 SDK. using the boto3 SDK.
""" """
import io
import logging import logging
from typing import Optional from typing import AsyncIterator, Optional
import boto3 import boto3
from botocore.exceptions import ClientError, NoCredentialsError, BotoCoreError from botocore.exceptions import ClientError, NoCredentialsError, BotoCoreError
@@ -35,6 +36,19 @@ class S3Storage(StorageBackend):
bucket_name: The name of the S3 bucket. bucket_name: The name of the S3 bucket.
region: The AWS region. region: The AWS region.
""" """
# Region name -> regional S3 endpoint URL. __init__ consults this map when no
# endpoint_url is supplied; regions missing here fall back to the generic
# "https://s3.{region}.amazonaws.com" form.
AMAZON_S3_ENDPOINT_MAP = {
    "us-east-1": "https://s3.us-east-1.amazonaws.com",  # NOTE(review): originally marked "special: no region suffix", but this URL does include the region — confirm intent
    "us-east-2": "https://s3.us-east-2.amazonaws.com",
    "us-west-1": "https://s3.us-west-1.amazonaws.com",
    "us-west-2": "https://s3.us-west-2.amazonaws.com",
    "ap-east-1": "https://s3.ap-east-1.amazonaws.com",  # Hong Kong
    "ap-southeast-1": "https://s3.ap-southeast-1.amazonaws.com",  # Singapore
    "ap-southeast-2": "https://s3.ap-southeast-2.amazonaws.com",  # Sydney
    "ap-northeast-1": "https://s3.ap-northeast-1.amazonaws.com",  # Tokyo
    "eu-central-1": "https://s3.eu-central-1.amazonaws.com",  # Frankfurt
    "eu-west-1": "https://s3.eu-west-1.amazonaws.com",  # Ireland
    # Extend with further regions as needed.
}
def __init__( def __init__(
self, self,
@@ -42,6 +56,7 @@ class S3Storage(StorageBackend):
access_key_id: str, access_key_id: str,
secret_access_key: str, secret_access_key: str,
bucket_name: str, bucket_name: str,
endpoint_url: Optional[str] = None
): ):
""" """
Initialize the S3Storage backend. Initialize the S3Storage backend.
@@ -51,6 +66,7 @@ class S3Storage(StorageBackend):
access_key_id: The AWS access key ID. access_key_id: The AWS access key ID.
secret_access_key: The AWS secret access key. secret_access_key: The AWS secret access key.
bucket_name: The name of the S3 bucket. bucket_name: The name of the S3 bucket.
endpoint_url: The complete URL to use for the constructed client.
Raises: Raises:
StorageConfigError: If any required configuration is missing. StorageConfigError: If any required configuration is missing.
@@ -69,10 +85,19 @@ class S3Storage(StorageBackend):
self.region = region self.region = region
self.bucket_name = bucket_name self.bucket_name = bucket_name
if not endpoint_url:
# 优先匹配内置映射表(解决特殊地域)
if region in self.AMAZON_S3_ENDPOINT_MAP:
endpoint_url = self.AMAZON_S3_ENDPOINT_MAP[region]
# 兜底:通用拼接(适配未配置的新地域)
else:
endpoint_url = f"https://s3.{region}.amazonaws.com"
try: try:
self.client = boto3.client( self.client = boto3.client(
"s3", "s3",
region_name=region, region_name=region,
endpoint_url=endpoint_url,
aws_access_key_id=access_key_id, aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key, aws_secret_access_key=secret_access_key,
) )
@@ -150,6 +175,62 @@ class S3Storage(StorageBackend):
cause=e, cause=e,
) )
async def upload_stream(
    self,
    file_key: str,
    stream: AsyncIterator[bytes],
    content_type: Optional[str] = None,
) -> int:
    """Upload from async stream to S3. Returns total bytes written.

    Incoming chunks are buffered until at least 5 MiB is available and then
    sent as one part of a multipart upload. The multipart upload is created
    lazily, when the first full part is ready; a stream that ends before
    that point (including an empty stream) is stored with a single
    ``put_object`` call instead, because completing a multipart upload with
    no parts is rejected by S3.

    Args:
        file_key: Object key to write in ``self.bucket_name``.
        stream: Async iterator yielding bytes chunks.
        content_type: Optional MIME type, stored as the object's ContentType.

    Returns:
        Total number of bytes read from the stream and uploaded.

    Raises:
        StorageUploadError: If reading the stream or any S3 call fails,
            including creation of the multipart upload.
    """
    extra_args = {"ContentType": content_type} if content_type else {}
    min_part_size = 5 * 1024 * 1024  # S3 minimum size for every part but the last
    upload_id = None  # set once a multipart upload has actually been started
    parts = []
    buf = io.BytesIO()
    total = 0
    finished = False
    try:
        async for chunk in stream:
            buf.write(chunk)
            total += len(chunk)
            if buf.tell() < min_part_size:
                continue
            if upload_id is None:
                mpu = self.client.create_multipart_upload(
                    Bucket=self.bucket_name, Key=file_key, **extra_args
                )
                upload_id = mpu["UploadId"]
            part_number = len(parts) + 1
            resp = self.client.upload_part(
                Bucket=self.bucket_name, Key=file_key,
                UploadId=upload_id, PartNumber=part_number, Body=buf.getvalue()
            )
            parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
            buf = io.BytesIO()

        remaining = buf.getvalue()
        if upload_id is None:
            # Whole payload fits in less than one part (possibly zero bytes).
            self.client.put_object(
                Bucket=self.bucket_name, Key=file_key, Body=remaining, **extra_args
            )
        else:
            # Upload the tail; the final part may be smaller than 5 MiB.
            if remaining:
                part_number = len(parts) + 1
                resp = self.client.upload_part(
                    Bucket=self.bucket_name, Key=file_key,
                    UploadId=upload_id, PartNumber=part_number, Body=remaining
                )
                parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
            self.client.complete_multipart_upload(
                Bucket=self.bucket_name, Key=file_key,
                UploadId=upload_id,
                MultipartUpload={"Parts": parts}
            )
        finished = True
        logger.info(f"File stream uploaded to S3 successfully: {file_key}")
        return total
    except Exception as e:
        logger.error(f"Failed to stream upload file to S3 {file_key}: {e}")
        raise StorageUploadError(
            message=f"Failed to stream upload file to S3: {e}",
            file_key=file_key,
            cause=e,
        ) from e
    finally:
        # Runs on errors and on task cancellation alike: release the parts
        # already stored so an unfinished upload does not linger in the
        # bucket. Best-effort, so an abort failure never replaces the
        # original error.
        if upload_id is not None and not finished:
            try:
                self.client.abort_multipart_upload(
                    Bucket=self.bucket_name, Key=file_key, UploadId=upload_id
                )
            except Exception as abort_error:
                logger.warning(
                    f"Failed to abort multipart upload for {file_key}: {abort_error}"
                )
async def download(self, file_key: str) -> bytes: async def download(self, file_key: str) -> bytes:
""" """
Download a file from S3. Download a file from S3.

View File

@@ -195,6 +195,6 @@ class MCPToolManager:
except Exception as e: except Exception as e:
return { return {
"success": False, "success": False,
"error": str(e), "error": "连接失败",
"message": "连接失败" "message": str(e)
} }

View File

@@ -23,7 +23,7 @@ class SimpleMCPClient:
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None): def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
self.server_url = server_url self.server_url = server_url
self.connection_config = connection_config or {} self.connection_config = connection_config or {}
self.timeout = self.connection_config.get("timeout", 30) self.timeout = self.connection_config.get("timeout", 10)
# 确定连接类型 # 确定连接类型
self.is_websocket = server_url.startswith(("ws://", "wss://")) self.is_websocket = server_url.startswith(("ws://", "wss://"))
@@ -53,6 +53,7 @@ class SimpleMCPClient:
else: else:
await self._connect_http() await self._connect_http()
except Exception as e: except Exception as e:
await self.disconnect()
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}") logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
raise MCPConnectionError(f"连接失败: {e}") raise MCPConnectionError(f"连接失败: {e}")

View File

@@ -8,34 +8,60 @@ from typing import Any
from urllib.parse import quote from urllib.parse import quote
from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \ from app.core.workflow.adapters.errors import (
UnsupportVariableType,
UnknowModelWarning,
ExceptionDefineition,
ExceptionType ExceptionType
from app.core.workflow.nodes.assigner import AssignerNodeConfig )
from app.core.workflow.nodes.assigner.config import AssignmentItem from app.core.workflow.nodes.assigner.config import AssignmentItem
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
from app.core.workflow.nodes.code import CodeNodeConfig
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig from app.core.workflow.nodes.configs import (
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig StartNodeConfig,
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \ LLMNodeConfig,
AssignerNodeConfig,
CodeNodeConfig,
LoopNodeConfig,
IterationNodeConfig,
EndNodeConfig,
HttpRequestNodeConfig,
IfElseNodeConfig,
JinjaRenderNodeConfig,
KnowledgeRetrievalNodeConfig,
NoteNodeConfig,
ParameterExtractorNodeConfig,
QuestionClassifierNodeConfig,
VariableAggregatorNodeConfig
)
from app.core.workflow.nodes.cycle_graph.config import (
ConditionDetail as LoopConditionDetail,
ConditionsConfig,
CycleVariable CycleVariable
from app.core.workflow.nodes.end import EndNodeConfig )
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \ from app.core.workflow.nodes.enums import (
HttpContentType, HttpErrorHandle ValueInputType,
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig ComparisonOperator,
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \ AssignmentOperator,
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig HttpAuthType,
from app.core.workflow.nodes.if_else import IfElseNodeConfig HttpContentType,
HttpErrorHandle,
NodeType
)
from app.core.workflow.nodes.http_request.config import (
HttpAuthConfig,
HttpContentTypeConfig,
HttpFormData,
HttpTimeOutConfig,
HttpRetryConfig,
HttpErrorDefaultTamplete,
HttpErrorHandleConfig
)
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
@@ -48,24 +74,24 @@ class DifyConverter(BaseConverter):
def __init__(self): def __init__(self):
self.CONFIG_CONVERT_MAP = { self.CONFIG_CONVERT_MAP = {
"start": self.convert_start_node_config, NodeType.START: self.convert_start_node_config,
"llm": self.convert_llm_node_config, NodeType.LLM: self.convert_llm_node_config,
"answer": self.convert_end_node_config, NodeType.END: self.convert_end_node_config,
"if-else": self.convert_if_else_node_config, NodeType.IF_ELSE: self.convert_if_else_node_config,
"loop": self.convert_loop_node_config, NodeType.LOOP: self.convert_loop_node_config,
"iteration": self.convert_iteration_node_config, NodeType.ITERATION: self.convert_iteration_node_config,
"assigner": self.convert_assigner_node_config, NodeType.ASSIGNER: self.convert_assigner_node_config,
"code": self.convert_code_node_config, NodeType.CODE: self.convert_code_node_config,
"http-request": self.convert_http_node_config, NodeType.HTTP_REQUEST: self.convert_http_node_config,
"template-transform": self.convert_jinja_render_node_config, NodeType.JINJARENDER: self.convert_jinja_render_node_config,
"knowledge-retrieval": self.convert_knowledge_node_config, NodeType.KNOWLEDGE_RETRIEVAL: self.convert_knowledge_node_config,
"parameter-extractor": self.convert_parameter_extractor_node_config, NodeType.PARAMETER_EXTRACTOR: self.convert_parameter_extractor_node_config,
"question-classifier": self.convert_question_classifier_node_config, NodeType.QUESTION_CLASSIFIER: self.convert_question_classifier_node_config,
"variable-aggregator": self.convert_variable_aggregator_node_config, NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
"tool": self.convert_tool_node_config, NodeType.TOOL: self.convert_tool_node_config,
"loop-start": lambda x: {}, NodeType.NOTES: self.convert_notes_config,
"iteration-start": lambda x: {}, NodeType.CYCLE_START: lambda x: {},
"loop-end": lambda x: {}, NodeType.BREAK: lambda x: {},
} }
def get_node_convert(self, node_type): def get_node_convert(self, node_type):
@@ -185,6 +211,9 @@ class DifyConverter(BaseConverter):
"not empty": ComparisonOperator.NOT_EMPTY, "not empty": ComparisonOperator.NOT_EMPTY,
"start with": ComparisonOperator.START_WITH, "start with": ComparisonOperator.START_WITH,
"end with": ComparisonOperator.END_WITH, "end with": ComparisonOperator.END_WITH,
"not contains": ComparisonOperator.NOT_CONTAINS,
"exists": ComparisonOperator.NOT_EMPTY,
"not exists": ComparisonOperator.EMPTY
} }
return operator_map.get(operator, operator) return operator_map.get(operator, operator)
@@ -364,7 +393,7 @@ class DifyConverter(BaseConverter):
node_data = node["data"] node_data = node["data"]
cases = [] cases = []
for case in node_data["cases"]: for case in node_data["cases"]:
case_id = case["id"] case_id = case.get("id") or case.get("case_id")
logical_operator = case["logical_operator"] logical_operator = case["logical_operator"]
conditions = [] conditions = []
for condition in case["conditions"]: for condition in case["conditions"]:
@@ -540,7 +569,8 @@ class DifyConverter(BaseConverter):
] = self.trans_variable_format(content["value"]) ] = self.trans_variable_format(content["value"])
else: else:
if node_data["body"]["data"]: if node_data["body"]["data"]:
body_content = node_data["body"]["data"][0]["value"] body_content = (node_data["body"]["data"][0].get("value") or
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
else: else:
body_content = "" body_content = ""
@@ -728,3 +758,16 @@ class DifyConverter(BaseConverter):
detail=f"Please reconfigure the tool node.", detail=f"Please reconfigure the tool node.",
)) ))
return {} return {}
@staticmethod
def convert_notes_config(node: dict):
    """Convert a Dify note node into a dumped ``NoteNodeConfig`` dict.

    Each config field is read from ``node["data"]`` and falls back to a fixed
    default when the key is absent. The config object is built with
    ``model_construct`` and returned via ``model_dump``.

    Args:
        node: Platform node dict; must contain a ``"data"`` mapping.

    Returns:
        The note configuration as a plain dict.
    """
    payload = node["data"]
    # config field -> (key in the Dify payload, default when missing)
    field_sources = {
        "author": ("author", ""),
        "text": ("text", ""),
        "width": ("width", 80),
        "height": ("height", 80),
        "theme": ("theme", "blue"),
        "show_author": ("showAuthor", True),
    }
    values = {
        field: payload.get(source_key, fallback)
        for field, (source_key, fallback) in field_sources.items()
    }
    note_config = NoteNodeConfig.model_construct(**values)
    return note_config.model_dump()

View File

@@ -44,12 +44,13 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR, "parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
"question-classifier": NodeType.QUESTION_CLASSIFIER, "question-classifier": NodeType.QUESTION_CLASSIFIER,
"variable-aggregator": NodeType.VAR_AGGREGATOR, "variable-aggregator": NodeType.VAR_AGGREGATOR,
"tool": NodeType.TOOL "tool": NodeType.TOOL,
"": NodeType.NOTES
} }
def __init__(self, config: dict[str, Any]): def __init__(self, config: dict[str, Any]):
DifyConverter.__init__(self) DifyConverter.__init__(self)
BasePlatformAdapter.__init__(self, config) BasePlatformAdapter.__init__(self, config)
def get_metadata(self) -> PlatformMetadata: def get_metadata(self) -> PlatformMetadata:
return PlatformMetadata( return PlatformMetadata(
@@ -58,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
support_node_types=list(self.NODE_TYPE_MAPPING.keys()) support_node_types=list(self.NODE_TYPE_MAPPING.keys())
) )
def map_node_type(self, platform_node_type) -> str: def map_node_type(self, platform_node_type) -> NodeType:
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN) return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@property @property
@@ -83,7 +84,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
require_fields = frozenset({'app', 'kind', 'version', 'workflow'}) require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
if not all(field in self.config for field in require_fields): if not all(field in self.config for field in require_fields):
return False return False
if self.config.get("app",{}).get("mode") == "workflow": if self.config.get("app", {}).get("mode") == "workflow":
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefineition(
type=ExceptionType.PLATFORM, type=ExceptionType.PLATFORM,
detail="workflow mode is not supported" detail="workflow mode is not supported"
@@ -162,13 +163,14 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None: def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
node_data = node["data"] node_data = node["data"]
try: try:
node_type = self.map_node_type(node_data["type"])
return NodeDefinition( return NodeDefinition(
id=node["id"], id=node["id"],
type=self.map_node_type(node_data["type"]), type=node_type,
name=node_data.get("title"), name=node_data.get("title") or "notes",
cycle=node.get("parentId"), cycle=node.get("parentId"),
description=None, description=None,
config=self._convert_node_config(node), config=self._convert_node_config(node_type, node),
position={ position={
"x": node["position"]["x"], "x": node["position"]["x"],
"y": node["position"]["y"] "y": node["position"]["y"]
@@ -182,17 +184,16 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
except Exception as e: except Exception as e:
logger.debug(f"convert node error - {e}", exc_info=True) logger.debug(f"convert node error - {e}", exc_info=True)
def _convert_node_config(self, node: dict): def _convert_node_config(self, node_type: NodeType, node: dict):
node_data = node["data"]
node_type = node_data["type"]
try: try:
node_data = node["data"]
converter = self.get_node_convert(node_type) converter = self.get_node_convert(node_type)
if node_type not in self.CONFIG_CONVERT_MAP: if node_type == NodeType.UNKNOWN:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node["id"], node_id=node["id"],
node_name=node["data"]["title"], node_name=node["data"]["title"],
detail=f"node type {node_type if node_type else 'notes'} is unsupported", detail=f"node type {node_data.get('type')} is unsupported",
)) ))
return converter(node) return converter(node)
except Exception as e: except Exception as e:
@@ -209,16 +210,15 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
source = edge["source"] source = edge["source"]
target = edge["target"] target = edge["target"]
edge_id = edge["id"]
label = None label = None
if source in self.branch_node_cache: if source in self.branch_node_cache:
case_id = "-".join(edge_id.split("-")[1:-2]) case_id = edge["sourceHandle"]
if case_id == "false": if case_id == "false":
label = f'CASE{len(self.branch_node_cache[source])+1}' label = f'CASE{len(self.branch_node_cache[source]) + 1}'
else: else:
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}' label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
if source in self.error_branch_node_cache: if source in self.error_branch_node_cache:
case_id = "-".join(edge_id.split("-")[1:-2]) case_id = edge["sourceHandle"]
if case_id == "source": if case_id == "source":
label = "SUCCESS" label = "SUCCESS"
else: else:
@@ -243,6 +243,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
name=variable["name"], name=variable["name"],
default=variable["value"], default=variable["value"],
type=self.variable_type_map(variable["value_type"]), type=self.variable_type_map(variable["value_type"]),
description=variable.get("description")
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefineition(
@@ -256,5 +257,3 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig: def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
return ExecutionConfig() return ExecutionConfig()

View File

@@ -4,65 +4,145 @@
# @Time : 2026/2/25 14:11 # @Time : 2026/2/25 14:11
from typing import Any from typing import Any
from app.core.logging_config import get_logger
from app.core.workflow.adapters.base_adapter import ( from app.core.workflow.adapters.base_adapter import (
PlatformMetadata, PlatformMetadata,
PlatformType, PlatformType,
BasePlatformAdapter, BasePlatformAdapter,
WorkflowParserResult WorkflowParserResult
) )
from app.schemas.workflow_schema import ExecutionConfig from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
logger = get_logger()
VALID_NODE_TYPES = frozenset(t.value for t in NodeType if t != NodeType.UNKNOWN)
class MemoryBearAdapter(BasePlatformAdapter): class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
NODE_TYPE_MAPPING = {} NODE_TYPE_MAPPING = {t.value: t for t in NodeType}
def __init__(self, config: dict[str, Any]) -> None:
    """Initialise both parent classes explicitly.

    The converter mixin is initialised first (it takes no arguments), then
    the platform-adapter base is initialised with the raw ``config``.

    Args:
        config: The workflow export to be parsed by this adapter.
    """
    MemoryBearConverter.__init__(self)
    BasePlatformAdapter.__init__(self, config)
@property @property
def origin_nodes(self): def origin_nodes(self):
return self.config.get("workflow").get("nodes") return self.config.get("workflow").get("nodes") or []
@property @property
def origin_edges(self): def origin_edges(self):
return self.config.get("workflow").get("edges") return self.config.get("workflow").get("edges") or []
@property @property
def origin_variables(self): def origin_variables(self):
return self.config.get("workflow").get("variables") return self.config.get("workflow").get("variables") or []
def get_metadata(self) -> PlatformMetadata: def get_metadata(self) -> PlatformMetadata:
return PlatformMetadata( return PlatformMetadata(
platform_name=PlatformType.MEMORY_BEAR, platform_name=PlatformType.MEMORY_BEAR,
version="0.2.5", version="0.2.5",
support_node_types=list(self.NODE_TYPE_MAPPING.keys()) support_node_types=list(VALID_NODE_TYPES)
) )
def map_node_type(self, platform_node_type) -> str: def map_node_type(self, platform_node_type: str) -> NodeType:
return platform_node_type return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@staticmethod @staticmethod
def _valid_nodes(node: dict[str, Any]): def _valid_node(node: dict[str, Any]) -> bool:
if "type" not in node["data"]:
return False
if "id" not in node or "type" not in node: if "id" not in node or "type" not in node:
return False return False
if not isinstance(node.get("config"), dict):
return False
return True return True
def validate_config(self) -> bool: def validate_config(self) -> bool:
require_fields = frozenset({'app', 'workflow'}) require_fields = frozenset({'app', 'workflow'})
if not all(field in self.config for field in require_fields): if not all(field in self.config for field in require_fields):
return False return False
for node in self.origin_nodes: for node in self.origin_nodes:
if not self._valid_nodes(node): if not self._valid_node(node):
return False return False
return True return True
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
    """Build a ``NodeDefinition`` from one raw node dict.

    The node type is resolved through ``map_node_type``. An unknown type is
    recorded in ``self.errors`` as ``UnsupportNodeType`` and the node is
    dropped. For a known type the node's ``config`` is passed through the
    validator returned by ``get_node_convert`` before the definition is built.

    Args:
        node: Raw node mapping; ``id``, ``name``, ``type`` and ``config`` are read.

    Returns:
        The built ``NodeDefinition``, or ``None`` when the type is unsupported
        or any exception is raised while converting.
    """
    identifier = node.get("id")
    title = node.get("name")
    try:
        resolved_type = self.map_node_type(node["type"])
        if resolved_type != NodeType.UNKNOWN:
            node_config = node.get("config") or {}
            validate = self.get_node_convert(resolved_type)
            # The validator's return value is ignored here; it is called for
            # the validation errors it may append.
            validate(identifier, title, node_config)
            return NodeDefinition(**node)

        self.errors.append(UnsupportNodeType(
            node_id=identifier,
            node_type=node["type"]
        ))
        return None
    except Exception as e:
        failure = ExceptionDefineition(
            type=ExceptionType.NODE,
            node_id=identifier,
            node_name=title,
            detail=f"convert node error - {e}"
        )
        self.errors.append(failure)
        logger.debug(f"MemoryBear convert node error - {e}", exc_info=True)
        return None
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
    """Build an ``EdgeDefinition`` from one raw edge dict.

    An edge whose ``source`` or ``target`` is not in ``valid_node_ids`` is
    skipped and a warning is recorded. Any exception raised while converting
    is recorded as an error instead.

    Args:
        edge: Raw edge mapping; ``id``, ``source`` and ``target`` are read.
        valid_node_ids: Ids of the nodes that were converted successfully.

    Returns:
        The built ``EdgeDefinition``, or ``None`` when the edge is skipped
        or conversion fails.
    """
    try:
        endpoints_known = (
            edge.get("source") in valid_node_ids
            and edge.get("target") in valid_node_ids
        )
        if endpoints_known:
            return EdgeDefinition(**edge)

        skipped = ExceptionDefineition(
            type=ExceptionType.EDGE,
            detail=f"edge {edge.get('id')} skipped: source or target node not found"
        )
        self.warnings.append(skipped)
        return None
    except Exception as e:
        failure = ExceptionDefineition(
            type=ExceptionType.EDGE,
            detail=f"convert edge error - {e}"
        )
        self.errors.append(failure)
        logger.debug(f"MemoryBear convert edge error - {e}", exc_info=True)
        return None
def _convert_variable(self, variable: dict[str, Any]) -> VariableDefinition | None:
    """Build a ``VariableDefinition`` from one raw variable dict.

    A failure is recorded in ``self.warnings`` (not ``self.errors``) and the
    variable is dropped.

    Args:
        variable: Raw variable mapping, unpacked as keyword arguments.

    Returns:
        The built ``VariableDefinition``, or ``None`` when construction raises.
    """
    try:
        definition = VariableDefinition(**variable)
    except Exception as e:
        problem = ExceptionDefineition(
            type=ExceptionType.VARIABLE,
            name=variable.get("name"),
            detail=f"convert variable error - {e}"
        )
        self.warnings.append(problem)
        logger.debug(f"MemoryBear convert variable error - {e}", exc_info=True)
        return None
    return definition
def parse_workflow(self) -> WorkflowParserResult: def parse_workflow(self) -> WorkflowParserResult:
self.nodes = self.origin_nodes for node in self.origin_nodes:
self.edges = self.origin_edges converted = self._convert_node(node)
self.conv_variables = self.origin_variables if converted:
self.nodes.append(converted)
valid_node_ids = {n.id for n in self.nodes}
for edge in self.origin_edges:
converted = self._convert_edge(edge, valid_node_ids)
if converted:
self.edges.append(converted)
for variable in self.origin_variables:
converted = self._convert_variable(variable)
if converted:
self.conv_variables.append(converted)
return WorkflowParserResult( return WorkflowParserResult(
success=True, success=not self.errors and not self.warnings,
platform=self.get_metadata(), platform=self.get_metadata(),
execution_config=ExecutionConfig(), execution_config=ExecutionConfig(),
origin_config=self.config, origin_config=self.config,
@@ -72,5 +152,4 @@ class MemoryBearAdapter(BasePlatformAdapter):
variables=self.conv_variables, variables=self.conv_variables,
warnings=self.warnings, warnings=self.warnings,
errors=self.errors, errors=self.errors,
) )

View File

@@ -0,0 +1,85 @@
# -*- coding: UTF-8 -*-
from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.configs import (
StartNodeConfig,
EndNodeConfig,
LLMNodeConfig,
AgentNodeConfig,
IfElseNodeConfig,
KnowledgeRetrievalNodeConfig,
AssignerNodeConfig,
CodeNodeConfig,
HttpRequestNodeConfig,
JinjaRenderNodeConfig,
VariableAggregatorNodeConfig,
ParameterExtractorNodeConfig,
LoopNodeConfig,
IterationNodeConfig,
QuestionClassifierNodeConfig,
ToolNodeConfig,
MemoryReadNodeConfig,
MemoryWriteNodeConfig,
NoteNodeConfig,
)
from app.core.workflow.nodes.enums import NodeType
class MemoryBearConverter(BaseConverter):
    """Validation-only converter for MemoryBear node configs.

    Node configs are returned unchanged. They are only checked against the
    config model registered for their node type, and validation failures are
    appended to ``self.errors``.
    """

    # Declared here, assigned elsewhere.
    # NOTE(review): presumably provided by the adapter class this is mixed into - confirm.
    errors: list
    warnings: list

    # Config model used to validate each node type. Types missing from this
    # map are passed through without validation.
    CONFIG_CLASS_MAP: dict[NodeType, type[BaseNodeConfig]] = {
        NodeType.START: StartNodeConfig,
        NodeType.END: EndNodeConfig,
        NodeType.ANSWER: EndNodeConfig,
        NodeType.LLM: LLMNodeConfig,
        NodeType.AGENT: AgentNodeConfig,
        NodeType.IF_ELSE: IfElseNodeConfig,
        NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNodeConfig,
        NodeType.ASSIGNER: AssignerNodeConfig,
        NodeType.CODE: CodeNodeConfig,
        NodeType.HTTP_REQUEST: HttpRequestNodeConfig,
        NodeType.JINJARENDER: JinjaRenderNodeConfig,
        NodeType.VAR_AGGREGATOR: VariableAggregatorNodeConfig,
        NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNodeConfig,
        NodeType.LOOP: LoopNodeConfig,
        NodeType.ITERATION: IterationNodeConfig,
        NodeType.QUESTION_CLASSIFIER: QuestionClassifierNodeConfig,
        NodeType.TOOL: ToolNodeConfig,
        NodeType.MEMORY_READ: MemoryReadNodeConfig,
        NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
        NodeType.NOTES: NoteNodeConfig,
    }

    @staticmethod
    def _convert_file(var):
        """Return ``None`` regardless of ``var``."""
        return None

    @staticmethod
    def _convert_array_file(var):
        """Return a new empty list regardless of ``var``."""
        return []

    def config_validate(self, node_id: str, node_name: str, config_cls: type[BaseNodeConfig], value: dict):
        """Validate ``value`` against ``config_cls``.

        Args:
            node_id: Id of the node, used in the recorded error.
            node_name: Name of the node, used in the recorded error.
            config_cls: Config model whose ``model_validate`` is called.
            value: Raw config to validate.

        Returns:
            The validated model instance, or ``None`` after appending a
            ``CONFIG`` error to ``self.errors`` when validation raises.
        """
        try:
            validated = config_cls.model_validate(value)
        except Exception as e:
            self.errors.append(ExceptionDefineition(
                type=ExceptionType.CONFIG,
                node_id=node_id,
                node_name=node_name,
                detail=str(e)
            ))
            return None
        return validated

    def get_node_convert(self, node_type: NodeType):
        """Return a callable ``(node_id, node_name, config) -> config``.

        For a node type present in ``CONFIG_CLASS_MAP`` the callable validates
        the config first (recording errors) and then returns it unchanged.
        For any other type it returns the config without validation.
        """
        config_cls = self.CONFIG_CLASS_MAP.get(node_type)

        if config_cls:
            def validate(node_id: str, node_name: str, config: dict):
                self.config_validate(node_id, node_name, config_cls, config)
                return config

            return validate

        def passthrough(node_id, node_name, config):
            return config

        return passthrough

View File

@@ -292,6 +292,8 @@ class GraphBuilder:
""" """
for node in self.nodes: for node in self.nodes:
node_type = node.get("type") node_type = node.get("type")
if node_type == NodeType.NOTES:
continue
node_id = node.get("id") node_id = node.get("id")
cycle_node = node.get("cycle") cycle_node = node.get("cycle")
if cycle_node: if cycle_node:

View File

@@ -5,7 +5,7 @@
import re import re
from typing import AsyncGenerator from typing import AsyncGenerator
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, PrivateAttr
from app.core.logging_config import get_logger from app.core.logging_config import get_logger
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
@@ -52,10 +52,11 @@ class OutputContent(BaseModel):
) )
) )
_SCOPE: str | None = None _SCOPE: str | None = PrivateAttr(default=None)
def get_scope(self) -> str: def get_scope(self) -> str | None:
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0] matches = SCOPE_PATTERN.findall(self.literal)
self._SCOPE = matches[0] if matches else None
return self._SCOPE return self._SCOPE
def depends_on_scope(self, scope: str) -> bool: def depends_on_scope(self, scope: str) -> bool:
@@ -68,6 +69,8 @@ class OutputContent(BaseModel):
Returns: Returns:
bool: True if this segment references the given scope. bool: True if this segment references the given scope.
""" """
if not self.is_variable:
return False
if self._SCOPE: if self._SCOPE:
return self._SCOPE == scope return self._SCOPE == scope
return self.get_scope() == scope return self.get_scope() == scope
@@ -152,7 +155,7 @@ class StreamOutputConfig(BaseModel):
""" """
# Case 1: resolve control branch dependency # Case 1: resolve control branch dependency
if scope in self.control_nodes.keys(): if scope in self.control_nodes:
if status is None: if status is None:
raise RuntimeError("[Stream Output] Control node activation status not provided") raise RuntimeError("[Stream Output] Control node activation status not provided")
if status in self.control_nodes[scope]: if status in self.control_nodes[scope]:

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from functools import cached_property from functools import cached_property
@@ -15,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read from app.db import get_db_read
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
from app.schemas import FileInput from app.schemas import FileInput
from app.schemas.model_schema import ModelInfo
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -619,11 +621,12 @@ class BaseNode(ABC):
@staticmethod @staticmethod
async def process_message( async def process_message(
provider: str, api_config: ModelInfo,
is_omni: bool,
content: str | dict | FileObject, content: str | dict | FileObject,
end_user_id: str,
enable_file=False enable_file=False
) -> list | str | None: ) -> list | str | None:
provider = api_config.provider
if isinstance(content, dict): if isinstance(content, dict):
content = FileObject( content = FileObject(
type=content.get("type"), type=content.get("type"),
@@ -642,16 +645,20 @@ class BaseNode(ABC):
if content.content_cache.get(provider): if content.content_cache.get(provider):
return content.content_cache[provider] return content.content_cache[provider]
with get_db_read() as db: with get_db_read() as db:
multimodel_service = MultimodalService(db, provider, is_omni=is_omni) multimodel_service = MultimodalService(db, api_config=api_config)
message = await multimodel_service.process_files( file_obj = FileInput(
[FileInput.model_construct( type=content.type,
type=content.type, url=content.url,
url=content.url, transfer_method=content.transfer_method,
transfer_method=content.transfer_method, origin_file_type=content.origin_file_type,
file_type=content.origin_file_type, upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
upload_file_id=content.file_id
)]
) )
file_obj.set_content(content.get_content())
message = await multimodel_service.process_files(
end_user_id,
[file_obj],
)
content.set_content(file_obj.get_content())
if message: if message:
content.content_cache[provider] = message content.content_cache[provider] = message
return message return message

View File

@@ -23,6 +23,7 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.notes.config import NoteNodeConfig
__all__ = [ __all__ = [
# 基础类 # 基础类
@@ -47,5 +48,6 @@ __all__ = [
"ToolNodeConfig", "ToolNodeConfig",
"MemoryReadNodeConfig", "MemoryReadNodeConfig",
"MemoryWriteNodeConfig", "MemoryWriteNodeConfig",
"CodeNodeConfig" "CodeNodeConfig",
"NoteNodeConfig"
] ]

View File

@@ -25,6 +25,7 @@ class NodeType(StrEnum):
MEMORY_WRITE = "memory-write" MEMORY_WRITE = "memory-write"
UNKNOWN = "unknown" UNKNOWN = "unknown"
NOTES = "notes"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]

View File

@@ -4,6 +4,7 @@ from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
from app.core.workflow.variable.base_variable import FileObject
class HttpAuthConfig(BaseModel): class HttpAuthConfig(BaseModel):
@@ -260,6 +261,11 @@ class HttpRequestNodeOutput(BaseModel):
description="Http response headers" description="Http response headers"
) )
files: list[FileObject] = Field(
default_factory=list,
description="List of files",
)
output: str = Field( output: str = Field(
default="SUCCESS", default="SUCCESS",
description="HTTP response body", description="HTTP response body",

View File

@@ -1,24 +1,146 @@
import asyncio import asyncio
import json import json
import logging import logging
import mimetypes
import uuid import uuid
import imghdr
from email.message import Message
from typing import Any, Callable, Coroutine from typing import Any, Callable, Coroutine
import httpx import httpx
# import filetypes # TODO: File support (Feature)
from httpx import AsyncClient, Response, Timeout from httpx import AsyncClient, Response, Timeout
import magic
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.utils.file_processer import mime_to_file_type
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.schemas import FileType, TransferMethod
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
class HttpResponse:
    """Wrap an ``httpx.Response`` and decide whether its payload is text or a file.

    The text/file decision is made once per instance and cached in
    ``_is_file``. ``body`` and ``files`` build on that decision.
    """

    # Number of leading payload bytes inspected when sniffing ``application/*`` bodies.
    _SNIFF_BYTES = 1024
    # Substrings of an ``application/*`` content type that mark it as text.
    _TEXT_APPLICATION_HINTS = ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql")
    # Byte sequences that mark a UTF-8 decodable ``application/*`` body as text.
    _TEXT_MARKERS = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
    # MIME types reported by libmagic that are mapped to ``FileType.DOCUMENT``.
    _DOCUMENT_MIME_TYPES = frozenset({
        "application/pdf",
        "application/msword",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "text/plain",
    })

    def __init__(self, response: httpx.Response):
        """Store the response and a plain-dict copy of its headers."""
        self.response = response
        self.headers = dict(response.headers)
        # Cached result of the text/file classification; ``None`` until computed.
        self._is_file: bool | None = None

    @property
    def content_type(self) -> str:
        """Raw ``content-type`` header value, or ``""`` when absent."""
        return self.headers.get("content-type", "")

    @property
    def content_disposition(self) -> Message | None:
        """``content-disposition`` header wrapped in an ``email.message.Message``.

        Returns ``None`` when the header is missing or empty. The ``Message``
        wrapper is used for its disposition and filename parsing.
        """
        raw_value = self.headers.get("content-disposition", "")
        if not raw_value:
            return None
        msg = Message()
        msg["content-disposition"] = raw_value
        return msg

    @property
    def is_file(self) -> bool:
        """Whether the payload should be treated as a file (cached after first use)."""
        if self._is_file is None:
            self._is_file = self._detect_file()
        return self._is_file

    def _detect_file(self) -> bool:
        """Classify the payload as file (``True``) or text (``False``).

        Checks, in order: an attachment/filename ``content-disposition``,
        a non-CSV ``text/*`` content type, text-like ``application/*``
        subtypes, a UTF-8 sniff of the first bytes of ``application/*``
        bodies, and finally a ``mimetypes``-based guess.
        """
        import codecs

        content_type = self.content_type.split(";")[0].strip().lower()

        disposition = self.content_disposition
        if disposition:
            if disposition.get_content_disposition() == "attachment" or disposition.get_filename():
                return True

        if content_type.startswith("text/") and "csv" not in content_type:
            return False

        if content_type.startswith("application/"):
            if any(hint in content_type for hint in self._TEXT_APPLICATION_HINTS):
                return False

            payload = self.response.content
            sample = payload[:self._SNIFF_BYTES]
            try:
                # Decode incrementally so that a multi-byte character cut in
                # half by the sample boundary is not mistaken for binary data;
                # only a sample that ends the payload must be complete UTF-8.
                decoder = codecs.getincrementaldecoder("utf-8")()
                decoder.decode(sample, final=len(payload) <= self._SNIFF_BYTES)
            except UnicodeDecodeError:
                return True
            if any(marker in sample for marker in self._TEXT_MARKERS):
                return False

        # Round-trip through an extension to normalise the content type.
        guessed_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
        if guessed_type:
            return guessed_type.split("/")[0] in ("application", "image", "audio", "video")

        return any(media_type in content_type for media_type in ("image/", "audio/", "video/"))

    @property
    def is_image(self):
        """Whether the payload is a file that ``imghdr`` recognises as an image."""
        if not self.is_file:
            return False
        # NOTE(review): ``imghdr`` is deprecated since Python 3.11 and removed in
        # 3.13 (PEP 594). ``get_file_type`` below already detects images through
        # libmagic; consider switching and dropping the ``imghdr`` import.
        return imghdr.what(None, h=self.response.content) is not None

    @property
    def url(self) -> str:
        """Final URL of the response as a string."""
        return str(self.response.url)

    @property
    def body(self) -> str:
        """Response text, or a Markdown link to ``url`` for file payloads.

        Image payloads use the Markdown image form (leading ``!``).
        """
        if self.is_file:
            return f"{'!' if self.is_image else ''}[file]({self.url})"
        return self.response.text

    @staticmethod
    def get_file_type(file_bytes) -> tuple[FileType | None, str | None]:
        """Detect the MIME type of ``file_bytes`` with libmagic and map it to a ``FileType``.

        Returns:
            ``(file_type, mime)`` for image, video, audio and the supported
            document types; ``(None, None)`` for anything else.
        """
        mime = magic.from_buffer(file_bytes, mime=True)
        if mime.startswith("image"):
            return FileType.IMAGE, mime
        if mime.startswith("video"):
            return FileType.VIDEO, mime
        if mime.startswith("audio"):
            return FileType.AUDIO, mime
        if mime in HttpResponse._DOCUMENT_MIME_TYPES:
            return FileType.DOCUMENT, mime
        return None, None

    @property
    def files(self) -> list[FileObject]:
        """The payload as a single-element ``FileObject`` list, or ``[]``.

        An empty list is returned when the payload is not a file, or when its
        detected MIME type maps to no supported file type.
        """
        # Check the cheap classification first so text responses never reach
        # libmagic and ``mime_to_file_type`` is never called with ``None``.
        if not self.is_file:
            return []

        file_type, mime_type = self.get_file_type(self.response.content)
        if not file_type:
            return []

        origin_file_type = mime_to_file_type(mime_type)
        if not origin_file_type:
            return []

        file_obj = FileObject(
            type=file_type,
            url=self.url,
            transfer_method=TransferMethod.REMOTE_URL.value,
            origin_file_type=origin_file_type,
            file_id=None,
            is_file=True
        )
        # Attach the downloaded bytes to the file object.
        file_obj.set_content(self.response.content)
        return [file_obj]
class HttpRequestNode(BaseNode): class HttpRequestNode(BaseNode):
""" """
HTTP Request Workflow Node. HTTP Request Workflow Node.
@@ -44,6 +166,7 @@ class HttpRequestNode(BaseNode):
"body": VariableType.STRING, "body": VariableType.STRING,
"status_code": VariableType.NUMBER, "status_code": VariableType.NUMBER,
"headers": VariableType.OBJECT, "headers": VariableType.OBJECT,
"files": VariableType.ARRAY_FILE,
"output": VariableType.STRING "output": VariableType.STRING
} }
@@ -232,10 +355,12 @@ class HttpRequestNode(BaseNode):
) )
resp.raise_for_status() resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded") logger.info(f"Node {self.node_id}: HTTP request succeeded")
response = HttpResponse(resp)
return HttpRequestNodeOutput( return HttpRequestNodeOutput(
body=resp.text, body=response.body,
status_code=resp.status_code, status_code=resp.status_code,
headers=resp.headers, headers=resp.headers,
files=response.files
).model_dump() ).model_dump()
except (httpx.HTTPStatusError, httpx.RequestError) as e: except (httpx.HTTPStatusError, httpx.RequestError) as e:
logger.error(f"HTTP request node exception: {e}") logger.error(f"HTTP request node exception: {e}")

View File

@@ -5,7 +5,7 @@ from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
@@ -23,6 +23,26 @@ class IfElseNode(BaseNode):
"output": VariableType.STRING "output": VariableType.STRING
} }
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
    """Describe the resolved inputs of every case of this node.

    For each expression the left operand is looked up in ``variable_pool``.
    The right operand is used as-is when its input type is ``CONSTANT`` and
    looked up otherwise. Lookups are non-strict.

    Returns:
        ``{"cases": [...]}`` with one entry per configured case, each holding
        its resolved ``expressions`` and its ``logical_operator``.
    """
    def lookup(reference):
        return self.get_variable(reference, variable_pool, strict=False)

    def describe(expression) -> dict[str, Any]:
        left_value = lookup(expression.left)
        if expression.input_type == ValueInputType.CONSTANT:
            right_value = expression.right
        else:
            right_value = lookup(expression.right)
        return {
            "left": left_value,
            "right": right_value,
            "operator": expression.operator,
        }

    return {
        "cases": [
            {
                "expressions": [describe(expression) for expression in case.expressions],
                "logical_operator": case.logical_operator,
            }
            for case in self.typed_config.cases
        ]
    }
@staticmethod @staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any: def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
match operator: match operator:

View File

@@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode):
"output": VariableType.ARRAY_STRING "output": VariableType.ARRAY_STRING
} }
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
    """Describe this node's inputs.

    Returns:
        A dict with the rendered ``query`` template and the JSON-mode dump of
        every configured knowledge base.
    """
    rendered_query = self._render_template(self.typed_config.query, variable_pool)
    serialized_bases = []
    for kb_config in self.typed_config.knowledge_bases:
        serialized_bases.append(kb_config.model_dump(mode="json"))
    return {
        "query": rendered_query,
        "knowledge_bases": serialized_bases,
    }
@staticmethod @staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType): def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
""" """
@@ -180,6 +186,8 @@ class KnowledgeRetrievalNode(BaseNode):
RuntimeError: If no valid knowledge base is found or access is denied. RuntimeError: If no valid knowledge base is found or access is denied.
""" """
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
if not self.typed_config.knowledge_bases:
return []
query = self._render_template(self.typed_config.query, variable_pool) query = self._render_template(self.typed_config.query, variable_pool)
with get_db_read() as db: with get_db_read() as db:
knowledge_bases = self.typed_config.knowledge_bases knowledge_bases = self.typed_config.knowledge_bases

View File

@@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_context from app.db import get_db_context
from app.models import ModelType from app.models import ModelType
from app.schemas.model_schema import ModelInfo
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -113,12 +114,15 @@ class LLMNode(BaseNode):
# 在 Session 关闭前提取所有需要的数据 # 在 Session 关闭前提取所有需要的数据
api_config = self.model_balance(config) api_config = self.model_balance(config)
model_name = api_config.model_name model_info = ModelInfo(
provider = api_config.provider model_name=api_config.model_name,
api_key = api_config.api_key model_type=ModelType(config.type),
api_base = api_config.api_base api_key=api_config.api_key,
is_omni = api_config.is_omni api_base=api_config.api_base,
model_type = config.type provider=api_config.provider,
is_omni=api_config.is_omni,
capability=api_config.capability
)
# 4. 创建 LLM 实例(使用已提取的数据) # 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True # 注意:对于流式输出,需要在模型初始化时设置 streaming=True
@@ -126,17 +130,18 @@ class LLMNode(BaseNode):
llm = RedBearLLM( llm = RedBearLLM(
RedBearModelConfig( RedBearModelConfig(
model_name=model_name, model_name=model_info.model_name,
provider=provider, provider=model_info.provider,
api_key=api_key, api_key=model_info.api_key,
base_url=api_base, base_url=model_info.api_base,
extra_params=extra_params, extra_params=extra_params,
is_omni=is_omni is_omni=model_info.is_omni
), ),
type=ModelType(model_type) type=model_info.model_type
) )
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") logger.debug(
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
messages_config = self.typed_config.messages messages_config = self.typed_config.messages
@@ -148,35 +153,40 @@ class LLMNode(BaseNode):
content_template = msg_config.content content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool) content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool) content = self._render_template(content_template, variable_pool)
user_id = self.get_variable("sys.user_id", variable_pool)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
if role == "system": if role == "system":
messages.append({ messages.append({
"role": "system", "role": "system",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision) "content": await self.process_message(
model_info,
content,
user_id,
self.typed_config.vision,
)
}) })
elif role in ["user", "human"]: elif role in ["user", "human"]:
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision) "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
}) })
elif role in ["ai", "assistant"]: elif role in ["ai", "assistant"]:
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision) "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
}) })
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision) "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
}) })
if self.typed_config.vision_input and self.typed_config.vision: if self.typed_config.vision_input and self.typed_config.vision:
file_content = [] file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input) files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value: for file in files.value:
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision) content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
if messages and messages[-1]["role"] == 'user': if messages and messages[-1]["role"] == 'user':
@@ -190,14 +200,19 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list): if isinstance(message["content"], list):
file_content = [] file_content = []
for file in message["content"]: for file in message["content"]:
content = await self.process_message(provider, is_omni, file, self.typed_config.vision) content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
history_message.append( history_message.append(
{"role": message["role"], "content": file_content} {"role": message["role"], "content": file_content}
) )
else: else:
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision) message["content"] = await self.process_message(
model_info,
message["content"],
user_id,
self.typed_config.vision
)
history_message.append(message) history_message.append(message)
messages = messages[:-1] + history_message + messages[-1:] messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages self.messages = messages
@@ -293,7 +308,7 @@ class LLMNode(BaseNode):
# 调用 LLM流式支持字符串或消息列表 # 调用 LLM流式支持字符串或消息列表
last_meta_data = {} last_meta_data = {}
async for chunk in llm.astream(self.messages, stream_usage=True): async for chunk in llm.astream(self.messages):
# 提取内容 # 提取内容
if hasattr(chunk, 'content'): if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content) content = self.process_model_output(chunk.content)

View File

@@ -0,0 +1,12 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
class NoteNodeConfig(BaseNodeConfig):
    """Configuration model for a note node.

    Every field has a default, so an empty config is valid.
    """

    # Author label of the note; empty string when not provided.
    author: str = Field(default="", description="author")
    # Body text of the note; empty string when not provided.
    text: str = Field(default="", description="note content")
    # Note dimensions. Units are not stated here
    # (presumably canvas pixels used by the UI -- TODO confirm).
    width: int = Field(default=80)
    height: int = Field(default=80)
    # Theme name; the set of accepted values is not validated in this model.
    theme: str = Field(default="blue")
    # Flag controlling whether the author is displayed (defaults to shown).
    show_author: bool = Field(default=True)

View File

@@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode):
} }
return None return None
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
    """Collect the effective inputs of this node as a plain dictionary.

    Args:
        state: Current workflow state (not read by this method).
        variable_pool: Variable pool passed to ``_render_template`` when
            rendering the configured text and prompt templates.

    Returns:
        A dict with the rendered ``text`` and ``prompt``, the configured
        ``params`` dumped in JSON mode, and ``model_id`` converted to ``str``.
    """
    rendered_text = self._render_template(self.typed_config.text, variable_pool)
    rendered_prompt = self._render_template(self.typed_config.prompt, variable_pool)
    dumped_params = []
    for param in self.typed_config.params:
        dumped_params.append(param.model_dump(mode="json"))
    model_id = str(self.typed_config.model_id)
    return {
        "text": rendered_text,
        "prompt": rendered_prompt,
        "params": dumped_params,
        "model_id": model_id,
    }
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
outputs = {} outputs = {}
for param in self.typed_config.params: for param in self.typed_config.params:

View File

@@ -27,7 +27,6 @@ class ToolNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return { return {
"data": VariableType.STRING, "data": VariableType.STRING,
"error_code": VariableType.STRING,
"execution_time": VariableType.NUMBER "execution_time": VariableType.NUMBER
} }
@@ -48,10 +47,7 @@ class ToolNode(BaseNode):
if not tenant_id: if not tenant_id:
logger.error(f"节点 {self.node_id} 缺少租户ID") logger.error(f"节点 {self.node_id} 缺少租户ID")
return { raise ValueError("缺少租户ID")
"success": False,
"data": "缺少租户ID"
}
# 渲染工具参数 # 渲染工具参数
rendered_parameters = {} rendered_parameters = {}
@@ -83,13 +79,8 @@ class ToolNode(BaseNode):
logger.info(f"节点 {self.node_id} 工具执行成功") logger.info(f"节点 {self.node_id} 工具执行成功")
return { return {
"data": result.data if isinstance(result.data, str) else json.dumps(result.data, ensure_ascii=False), "data": result.data if isinstance(result.data, str) else json.dumps(result.data, ensure_ascii=False),
"error_code": "",
"execution_time": result.execution_time "execution_time": result.execution_time
} }
else: else:
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}") logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
return { raise ValueError(f"工具执行失败: {result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False)}")
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
"error_code": result.error_code,
"execution_time": result.execution_time
}

View File

@@ -0,0 +1,56 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/3/10 13:36
# Maps document MIME types to the internal ``document/<kind>`` labels.
# ``mime_to_file_type`` looks types up here and returns any allowed type
# that has no entry (images, video, audio) unchanged.
TRANSFORM_FILE_TYPE = {
    'text/plain': 'document/text',
    'text/markdown': 'document/markdown',
    'text/x-markdown': 'document/x-markdown',
    'application/pdf': 'document/pdf',
    'application/msword': 'document/doc',
    'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx',
    'application/vnd.ms-powerpoint': 'document/ppt',
    'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx',
}
# Allow-list of MIME types accepted by ``mime_to_file_type``; anything not
# listed here makes that function return ``None``.
ALLOWED_FILE_TYPES = [
    # Documents (each of these also has an entry in TRANSFORM_FILE_TYPE)
    'text/plain',
    'text/markdown',
    'text/x-markdown',
    'application/pdf',
    'application/msword',
    'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
    'application/vnd.ms-powerpoint',
    'application/vnd.openxmlformats-officedocument.presentationml.presentation',
    # Images
    'image/jpg',
    'image/jpeg',
    'image/png',
    'image/gif',
    'image/bmp',
    'image/webp',
    'image/svg+xml',
    # Video
    'video/mp4',
    'video/quicktime',
    'video/x-msvideo',
    'video/x-matroska',
    'video/webm',
    'video/x-flv',
    'video/x-ms-wmv',
    # Audio
    'audio/mpeg',
    'audio/wav',
    'audio/ogg',
    'audio/aac',
    'audio/flac',
    'audio/mp4',
    'audio/x-ms-wma',
    'audio/x-m4a',
]
def mime_to_file_type(mime_type):
    """Translate a MIME type into the internal file-type label.

    Args:
        mime_type: MIME type string to translate.

    Returns:
        ``None`` when ``mime_type`` is not in ``ALLOWED_FILE_TYPES``;
        otherwise the mapped label from ``TRANSFORM_FILE_TYPE``, or
        ``mime_type`` itself when no mapping exists for it.
    """
    if mime_type in ALLOWED_FILE_TYPES:
        # Allowed types without a mapping are passed through unchanged.
        return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
    return None

View File

@@ -138,7 +138,7 @@ class WorkflowValidator:
errors.append("工作流必须至少有一个 end 节点") errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性 # 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes] node_ids = [n.get("id") for n in nodes if n.get("type") != NodeType.NOTES]
if len(node_ids) != len(set(node_ids)): if len(node_ids) != len(set(node_ids)):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1] duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}") errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")

View File

@@ -114,9 +114,16 @@ class FileObject(BaseModel):
file_id: str | None file_id: str | None
content_cache: dict = Field(default_factory=dict) content_cache: dict = Field(default_factory=dict)
is_file: bool is_file: bool
_byte_content: bytes | None = None
def get_content(self):
    """Return the value currently stored in ``_byte_content``.

    The attribute is declared with a default of ``None``, so ``None`` is
    returned until ``set_content`` has stored something.
    """
    return self._byte_content
def set_content(self, byte_content) -> None:
    """Store ``byte_content`` in ``_byte_content``, replacing any previous value.

    Args:
        byte_content: Value to store (the attribute is declared as
            ``bytes | None``; no validation is performed here).
    """
    self._byte_content = byte_content
class BaseVariable(ABC): class BaseVariable(ABC):
"""Abstract base class for all workflow variables. """Abstract base class for all workflow variables.

View File

@@ -16,7 +16,7 @@ engine = create_engine(
pool_recycle=settings.DB_POOL_RECYCLE, pool_recycle=settings.DB_POOL_RECYCLE,
pool_timeout=settings.DB_POOL_TIMEOUT, pool_timeout=settings.DB_POOL_TIMEOUT,
connect_args={ connect_args={
"options": "-c timezone=Asia/Shanghai -c statement_timeout=60000" "options": "-c timezone=UTC -c statement_timeout=60000"
}, },
) )
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

61
api/app/i18n/README.md Normal file
View File

@@ -0,0 +1,61 @@
# Internationalization (i18n) Module
This module provides internationalization support for the MemoryBear API.
## Components
- `service.py` - Translation service and core translation logic
- `middleware.py` - Language detection middleware
- `dependencies.py` - FastAPI dependency injection functions
- `exceptions.py` - Internationalized exception classes
## Usage
### Basic Translation
```python
from app.i18n import t
# Simple translation
message = t("common.success.created")
# Parameterized translation
message = t("common.validation.required", field="Name")
```
### Enum Translation
```python
from app.i18n import t_enum
# Translate enum value
role_display = t_enum("workspace_role", "manager")
```
### In FastAPI Endpoints
```python
from typing import Callable
from fastapi import Depends
from app.i18n.dependencies import get_translator
@router.post("/workspaces")
async def create_workspace(
data: WorkspaceCreate,
t: Callable = Depends(get_translator)
):
workspace = await workspace_service.create(data)
return {
"success": True,
"message": t("workspace.created_successfully"),
"data": workspace
}
```
## Configuration
See `app/core/config.py` for i18n configuration options:
- `I18N_DEFAULT_LANGUAGE` - Default language (default: "zh")
- `I18N_SUPPORTED_LANGUAGES` - Supported languages (default: "zh,en")
- `I18N_ENABLE_TRANSLATION_CACHE` - Enable caching (default: true)
- `I18N_LOG_MISSING_TRANSLATIONS` - Log missing translations (default: true)

124
api/app/i18n/__init__.py Normal file
View File

@@ -0,0 +1,124 @@
"""
Internationalization (i18n) module for MemoryBear Enterprise.
This module provides complete i18n support for the backend API including:
- Translation loading from multiple directories (community + enterprise)
- Translation service with caching and fallback
- Language detection middleware
- Dependency injection for FastAPI
- Convenience functions for easy usage
Usage:
from app.i18n import t, t_enum
# Simple translation
message = t("common.success.created")
# Parameterized translation
error = t("common.validation.required", field="名称")
# Enum translation
role_display = t_enum("workspace_role", "manager")
"""
from app.i18n.dependencies import (
get_current_language,
get_enum_translator,
get_translator,
)
from app.i18n.exceptions import (
BadRequestError,
ConflictError,
FileNotFoundError,
FileTooLargeError,
ForbiddenError,
I18nException,
InternalServerError,
InvalidCredentialsError,
InvalidFileTypeError,
NotFoundError,
QuotaExceededError,
RateLimitExceededError,
ServiceUnavailableError,
TenantNotFoundError,
TenantSuspendedError,
TokenExpiredError,
TokenInvalidError,
UnauthorizedError,
UserAlreadyExistsError,
UserNotFoundError,
ValidationError,
WorkspaceNotFoundError,
WorkspacePermissionDeniedError,
get_current_locale,
set_current_locale,
)
from app.i18n.loader import TranslationLoader
from app.i18n.logger import (
TranslationLogger,
get_translation_logger,
log_missing_translation,
log_translation_error,
)
from app.i18n.middleware import LanguageMiddleware
from app.i18n.serializers import (
I18nResponseMixin,
WorkspaceSerializer,
WorkspaceMemberSerializer,
WorkspaceInviteSerializer,
)
from app.i18n.service import (
TranslationService,
get_translation_service,
t,
t_enum,
)
# Public API of the ``app.i18n`` package: every name imported above is
# re-exported here.
__all__ = [
    # Loader, middleware, translation service and FastAPI dependencies
    "TranslationLoader",
    "LanguageMiddleware",
    "TranslationService",
    "get_translation_service",
    "t",
    "t_enum",
    "get_current_language",
    "get_translator",
    "get_enum_translator",
    # Context management
    "get_current_locale",
    "set_current_locale",
    # Logging
    "TranslationLogger",
    "get_translation_logger",
    "log_missing_translation",
    "log_translation_error",
    # Serializers
    "I18nResponseMixin",
    "WorkspaceSerializer",
    "WorkspaceMemberSerializer",
    "WorkspaceInviteSerializer",
    # Exception classes
    # NOTE: "FileNotFoundError" and "ValidationError" below are the classes
    # imported from app.i18n.exceptions; a star-import of this package will
    # shadow the builtin FileNotFoundError in the importing module.
    "I18nException",
    "BadRequestError",
    "UnauthorizedError",
    "ForbiddenError",
    "NotFoundError",
    "ConflictError",
    "ValidationError",
    "InternalServerError",
    "ServiceUnavailableError",
    "WorkspaceNotFoundError",
    "WorkspacePermissionDeniedError",
    "UserNotFoundError",
    "UserAlreadyExistsError",
    "TenantNotFoundError",
    "TenantSuspendedError",
    "InvalidCredentialsError",
    "TokenExpiredError",
    "TokenInvalidError",
    "FileNotFoundError",
    "FileTooLargeError",
    "InvalidFileTypeError",
    "RateLimitExceededError",
    "QuotaExceededError",
]

291
api/app/i18n/cache.py Normal file
View File

@@ -0,0 +1,291 @@
"""
Advanced caching system for i18n translations.
This module provides:
- LRU cache for hot translations
- Lazy loading mechanism
- Memory optimization
- Cache statistics
"""
import logging
from functools import lru_cache
from typing import Any, Dict, Optional
from collections import OrderedDict
import time
logger = logging.getLogger(__name__)
class TranslationCache:
    """
    Two-tier in-memory cache for translation lookups.

    The main tier holds the full nested translation data per locale
    (``{locale: {namespace: {key: value}}}``). The LRU tier holds resolved
    string translations keyed by ``(locale, namespace, key_path)`` so that
    repeated lookups skip the nested walk.

    Features:
    - LRU tier bounded by ``max_lru_size`` (``<= 0`` disables the tier)
    - LRU entries of a locale are dropped whenever its data is replaced
    - Cache hit/miss statistics
    """

    def __init__(self, max_lru_size: int = 1000, enable_lazy_load: bool = True):
        """
        Initialize the translation cache.

        Args:
            max_lru_size: Maximum number of entries in the LRU tier.
                A value ``<= 0`` disables the LRU tier.
            enable_lazy_load: Stored on the instance; not used by this
                class itself.
        """
        self.max_lru_size = max_lru_size
        self.enable_lazy_load = enable_lazy_load

        # Main cache: {locale: {namespace: {key: value}}}
        self._main_cache: Dict[str, Dict[str, Any]] = {}

        # LRU tier: {(locale, namespace, key_path_tuple): translation}
        self._lru_cache: OrderedDict = OrderedDict()

        # Loaded locales tracker
        self._loaded_locales: set = set()

        # Statistics
        self._stats = self._empty_stats()

        logger.info(
            f"TranslationCache initialized with LRU size: {max_lru_size}, "
            f"lazy loading: {enable_lazy_load}"
        )

    @staticmethod
    def _empty_stats() -> Dict[str, int]:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            "hits": 0,
            "misses": 0,
            "lru_hits": 0,
            "lru_misses": 0,
            "lazy_loads": 0
        }

    def _evict_locale_from_lru(self, locale: str) -> None:
        """Drop every LRU entry that belongs to ``locale``."""
        stale_keys = [key for key in self._lru_cache if key[0] == locale]
        for key in stale_keys:
            del self._lru_cache[key]

    def set_locale_data(self, locale: str, data: Dict[str, Any]):
        """
        Set (or replace) translation data for a locale.

        LRU entries previously resolved for this locale are discarded so
        that replaced data is never shadowed by stale cached strings.

        Args:
            locale: Locale code
            data: Translation data dictionary ({namespace: {key: value}})
        """
        self._main_cache[locale] = data
        self._loaded_locales.add(locale)
        # Without this, a reload would keep serving the old strings.
        self._evict_locale_from_lru(locale)
        logger.debug(f"Loaded locale '{locale}' into cache")

    def _lookup(self, locale: str, namespace: str, key_path: tuple) -> Optional[str]:
        """Resolve ``key_path`` in the main cache; ``None`` unless it ends at a string."""
        namespaces = self._main_cache.get(locale)
        if namespaces is None or namespace not in namespaces:
            return None

        current = namespaces[namespace]
        for key in key_path:
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                return None

        # Only leaf strings are translations; nested dicts are not.
        return current if isinstance(current, str) else None

    def get_translation(
        self,
        locale: str,
        namespace: str,
        key_path: list
    ) -> Optional[str]:
        """
        Get a translation, consulting the LRU tier before the main cache.

        Args:
            locale: Locale code
            namespace: Translation namespace
            key_path: List of nested keys

        Returns:
            Translation string or None if not found
        """
        path = tuple(key_path)
        # A tuple key keeps ["a.b"] and ["a", "b"] distinct, which a
        # dot-joined string key cannot do.
        cache_key = (locale, namespace, path)

        if cache_key in self._lru_cache:
            self._stats["lru_hits"] += 1
            self._stats["hits"] += 1
            # Move to end (most recently used)
            self._lru_cache.move_to_end(cache_key)
            return self._lru_cache[cache_key]

        self._stats["lru_misses"] += 1

        value = self._lookup(locale, namespace, path)
        if value is None:
            self._stats["misses"] += 1
            return None

        self._stats["hits"] += 1
        self._add_to_lru(cache_key, value)
        return value

    def _add_to_lru(self, key, value: str):
        """
        Add a translation to the LRU tier, evicting the oldest entries if full.

        Args:
            key: Cache key
            value: Translation value
        """
        if self.max_lru_size <= 0:
            # LRU tier disabled; popitem() on an empty dict would raise.
            return

        if key in self._lru_cache:
            self._lru_cache.move_to_end(key)
        else:
            while len(self._lru_cache) >= self.max_lru_size:
                self._lru_cache.popitem(last=False)

        self._lru_cache[key] = value

    def is_locale_loaded(self, locale: str) -> bool:
        """
        Check if a locale is loaded.

        Args:
            locale: Locale code

        Returns:
            True if locale is loaded
        """
        return locale in self._loaded_locales

    def get_loaded_locales(self) -> list:
        """
        Get list of loaded locales.

        Returns:
            List of locale codes
        """
        return list(self._loaded_locales)

    def clear_lru(self):
        """Clear the LRU cache."""
        self._lru_cache.clear()
        logger.info("LRU cache cleared")

    def clear_locale(self, locale: str):
        """
        Clear cache for a specific locale (main data and LRU entries).

        Args:
            locale: Locale code
        """
        if locale in self._main_cache:
            del self._main_cache[locale]
            self._loaded_locales.discard(locale)

        self._evict_locale_from_lru(locale)

        logger.info(f"Cleared cache for locale '{locale}'")

    def clear_all(self):
        """Clear all caches."""
        self._main_cache.clear()
        self._lru_cache.clear()
        self._loaded_locales.clear()
        logger.info("All caches cleared")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with counters, hit rates (percent, two decimals)
            and current cache sizes
        """
        total_requests = self._stats["hits"] + self._stats["misses"]
        hit_rate = (
            self._stats["hits"] / total_requests * 100
            if total_requests > 0
            else 0
        )

        lru_total = self._stats["lru_hits"] + self._stats["lru_misses"]
        lru_hit_rate = (
            self._stats["lru_hits"] / lru_total * 100
            if lru_total > 0
            else 0
        )

        return {
            "total_requests": total_requests,
            "hits": self._stats["hits"],
            "misses": self._stats["misses"],
            "hit_rate": round(hit_rate, 2),
            "lru_hits": self._stats["lru_hits"],
            "lru_misses": self._stats["lru_misses"],
            "lru_hit_rate": round(lru_hit_rate, 2),
            "lru_size": len(self._lru_cache),
            "lru_max_size": self.max_lru_size,
            "loaded_locales": len(self._loaded_locales),
            "lazy_loads": self._stats["lazy_loads"]
        }

    def reset_stats(self):
        """Reset cache statistics."""
        self._stats = self._empty_stats()
        logger.info("Cache statistics reset")

    def get_memory_usage(self) -> Dict[str, Any]:
        """
        Estimate memory usage of the cache.

        The figures are rough: ``sys.getsizeof`` is shallow, so only the
        container objects down to namespace level are counted.

        Returns:
            Dictionary with memory usage information
        """
        import sys

        main_cache_size = sys.getsizeof(self._main_cache)
        lru_cache_size = sys.getsizeof(self._lru_cache)

        # Rough estimate of nested data
        for locale_data in self._main_cache.values():
            main_cache_size += sys.getsizeof(locale_data)
            for namespace_data in locale_data.values():
                main_cache_size += sys.getsizeof(namespace_data)

        return {
            "main_cache_bytes": main_cache_size,
            "lru_cache_bytes": lru_cache_size,
            "total_bytes": main_cache_size + lru_cache_size,
            "main_cache_mb": round(main_cache_size / 1024 / 1024, 2),
            "lru_cache_mb": round(lru_cache_size / 1024 / 1024, 2),
            "total_mb": round((main_cache_size + lru_cache_size) / 1024 / 1024, 2)
        }
@lru_cache(maxsize=128)
def get_cached_translation_key(locale: str, namespace: str, key: str) -> str:
    """
    Build the ``locale:namespace:key`` cache key string.

    Results are memoized (up to 128 distinct argument triples) so repeated
    requests for the same key reuse the already-built string.

    Args:
        locale: Locale code
        namespace: Translation namespace
        key: Translation key

    Returns:
        Cache key string
    """
    return "{}:{}:{}".format(locale, namespace, key)

View File

@@ -0,0 +1,158 @@
"""
FastAPI dependency injection functions for i18n.
This module provides dependency injection functions that can be used
in FastAPI route handlers to access the current language and translator.
"""
import logging
from typing import Callable
from fastapi import Request
from app.i18n.service import get_translation_service
logger = logging.getLogger(__name__)
async def get_current_language(request: Request) -> str:
    """
    Resolve the language of the current request.

    Reads ``request.state.language``. When that attribute is absent or
    ``None``, the configured ``I18N_DEFAULT_LANGUAGE`` is returned instead
    and a warning is logged.

    Args:
        request: FastAPI request object

    Returns:
        Language code (e.g., "zh", "en")

    Usage:
        @router.get("/example")
        async def example(language: str = Depends(get_current_language)):
            return {"language": language}
    """
    detected = getattr(request.state, "language", None)
    if detected is not None:
        return detected

    # Nothing stored on the request: fall back to the configured default.
    from app.core.config import settings
    fallback = settings.I18N_DEFAULT_LANGUAGE
    logger.warning(
        f"Language not found in request.state, using default: {fallback}"
    )
    return fallback
async def get_translator(request: Request) -> Callable:
    """
    Build a translation function bound to the language of this request.

    The language is resolved once via ``get_current_language``; the
    returned callable forwards every call to the translation service with
    that language.

    Args:
        request: FastAPI request object

    Returns:
        Translation function with signature: t(key: str, **params) -> str

    Usage:
        @router.post("/workspaces")
        async def create_workspace(
            data: WorkspaceCreate,
            t: Callable = Depends(get_translator)
        ):
            workspace = await workspace_service.create(data)
            return {
                "success": True,
                "message": t("workspace.created_successfully"),
                "data": workspace
            }

        # With parameters
        @router.get("/items")
        async def get_items(t: Callable = Depends(get_translator)):
            return {"message": t("items.found", count=5)}
    """
    request_language = await get_current_language(request)
    translation_service = get_translation_service()

    def translate(key: str, **params) -> str:
        """
        Translate ``key`` in the language captured for this request.

        Args:
            key: Translation key (e.g., "common.success.created")
            **params: Parameters for parameterized messages

        Returns:
            Translated string
        """
        return translation_service.translate(key, request_language, **params)

    return translate
async def get_enum_translator(request: Request) -> Callable:
    """
    Build an enum-translation function bound to the language of this request.

    The language is resolved once via ``get_current_language``; the
    returned callable forwards every call to the translation service's
    ``translate_enum`` with that language.

    Args:
        request: FastAPI request object

    Returns:
        Enum translation function with signature:
        t_enum(enum_type: str, value: str) -> str

    Usage:
        @router.get("/workspace/{id}")
        async def get_workspace(
            id: str,
            t_enum: Callable = Depends(get_enum_translator)
        ):
            workspace = await workspace_service.get(id)
            return {
                "id": workspace.id,
                "role": workspace.role,
                "role_display": t_enum("workspace_role", workspace.role),
                "status": workspace.status,
                "status_display": t_enum("workspace_status", workspace.status)
            }
    """
    request_language = await get_current_language(request)
    translation_service = get_translation_service()

    def translate_enum(enum_type: str, value: str) -> str:
        """
        Translate an enum value in the language captured for this request.

        Args:
            enum_type: Enum type name (e.g., "workspace_role")
            value: Enum value (e.g., "manager")

        Returns:
            Translated enum display name
        """
        return translation_service.translate_enum(enum_type, value, request_language)

    return translate_enum

495
api/app/i18n/exceptions.py Normal file
View File

@@ -0,0 +1,495 @@
"""
Internationalized exception classes for i18n system.
This module provides exception classes that automatically translate
error messages based on the current request's language.
"""
import logging
from contextvars import ContextVar
from typing import Any, Dict, Optional
from fastapi import HTTPException, Request
from app.i18n.service import get_translation_service
logger = logging.getLogger(__name__)
# Context variable to store current locale
_current_locale: ContextVar[Optional[str]] = ContextVar("current_locale", default=None)
def set_current_locale(locale: str) -> None:
    """
    Store ``locale`` in the module-level ``_current_locale`` context variable.

    ``I18nException`` reads this value when no explicit locale is passed.
    Intended to be called by the LanguageMiddleware for each request.

    Args:
        locale: Locale code (e.g., "zh", "en")
    """
    _current_locale.set(locale)
def get_current_locale() -> Optional[str]:
    """
    Read the locale from the module-level ``_current_locale`` context variable.

    Returns:
        Locale code, or None (the variable's default) if it was never set
        in the current context
    """
    return _current_locale.get()
class I18nException(HTTPException):
    """
    HTTP exception whose message is translated at construction time.

    The target locale is chosen in this order:
    1. The ``locale`` argument, when given
    2. The locale stored in the ``_current_locale`` context variable
    3. ``settings.I18N_DEFAULT_LANGUAGE``

    The resulting ``detail`` is a dict with ``error_code`` and ``message``,
    plus ``params`` when message parameters were supplied.

    Usage:
        # Simple error
        raise I18nException(
            error_key="errors.workspace.not_found",
            status_code=404
        )

        # Error with parameters and an explicit error code
        raise I18nException(
            error_key="errors.workspace.not_found",
            error_code="WORKSPACE_NOT_FOUND",
            status_code=404,
            workspace_id="123"
        )
    """

    def __init__(
        self,
        error_key: str,
        status_code: int = 400,
        error_code: Optional[str] = None,
        locale: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        **params
    ):
        """
        Translate the message and initialize the underlying HTTPException.

        Args:
            error_key: Translation key for the error message
                (e.g., "errors.workspace.not_found")
            status_code: HTTP status code (default: 400)
            error_code: Custom error code for API clients
                (default: derived from error_key)
            locale: Target locale for translation; resolved from the
                request context when omitted
            headers: Additional HTTP headers
            **params: Parameters for parameterized error messages
        """
        self.error_key = error_key
        self.error_code = error_code or self._generate_error_code(error_key)
        self.params = params

        if locale is None:
            locale = self._get_current_locale()

        message = get_translation_service().translate(error_key, locale, **params)

        detail = {"error_code": self.error_code, "message": message}
        if params:
            # Expose the message parameters to API clients as well.
            detail["params"] = params

        super().__init__(status_code=status_code, detail=detail, headers=headers)

        logger.debug(
            f"I18nException raised: {self.error_code} (key: {error_key}, locale: {locale})"
        )

    def _get_current_locale(self) -> str:
        """
        Resolve the locale from the context variable or the configured default.

        Returns:
            Locale code (e.g., "zh", "en")
        """
        try:
            active = _current_locale.get()
        except Exception as e:
            logger.debug(f"Could not get locale from context: {e}")
            active = None

        if active:
            return active

        # Nothing usable in the context: use the configured default.
        from app.core.config import settings
        return settings.I18N_DEFAULT_LANGUAGE

    def _generate_error_code(self, error_key: str) -> str:
        """
        Derive an error code from a translation key.

        Converts "errors.workspace.not_found" to "WORKSPACE_NOT_FOUND".

        Args:
            error_key: Translation key

        Returns:
            Error code in UPPER_SNAKE_CASE
        """
        prefix = "errors."
        trimmed = error_key[len(prefix):] if error_key.startswith(prefix) else error_key
        return trimmed.replace(".", "_").upper()
# Specific exception classes for common errors
class BadRequestError(I18nException):
    """I18n exception fixed to HTTP status 400 (bad request)."""

    def __init__(
        self,
        error_key: str = "errors.common.bad_request",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 400, error_code, **params)
class UnauthorizedError(I18nException):
    """I18n exception fixed to HTTP status 401 (unauthorized)."""

    def __init__(
        self,
        error_key: str = "errors.auth.unauthorized",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 401, error_code, **params)
class ForbiddenError(I18nException):
    """I18n exception fixed to HTTP status 403 (forbidden)."""

    def __init__(
        self,
        error_key: str = "errors.auth.forbidden",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 403, error_code, **params)
class NotFoundError(I18nException):
    """I18n exception fixed to HTTP status 404 (not found)."""

    def __init__(
        self,
        error_key: str = "errors.common.not_found",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 404, error_code, **params)
class ConflictError(I18nException):
    """I18n exception fixed to HTTP status 409 (conflict)."""

    def __init__(
        self,
        error_key: str = "errors.common.conflict",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 409, error_code, **params)
class ValidationError(I18nException):
    """I18n exception fixed to HTTP status 422 (validation failed)."""

    def __init__(
        self,
        error_key: str = "errors.common.validation_failed",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 422, error_code, **params)
class InternalServerError(I18nException):
    """I18n exception fixed to HTTP status 500 (internal server error)."""

    def __init__(
        self,
        error_key: str = "errors.common.internal_error",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 500, error_code, **params)
class ServiceUnavailableError(I18nException):
    """I18n exception fixed to HTTP status 503 (service unavailable)."""

    def __init__(
        self,
        error_key: str = "errors.common.service_unavailable",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 503, error_code, **params)
# Domain-specific exception classes
class WorkspaceNotFoundError(NotFoundError):
    """404 error for a missing workspace (code ``WORKSPACE_NOT_FOUND``).

    A truthy ``workspace_id`` is forwarded as the ``workspace_id`` message
    parameter; a falsy one is omitted.
    """

    def __init__(self, workspace_id: Optional[str] = None, **params):
        extra = {"workspace_id": workspace_id} if workspace_id else {}
        super().__init__(
            error_key="errors.workspace.not_found",
            error_code="WORKSPACE_NOT_FOUND",
            **params,
            **extra,
        )
class WorkspacePermissionDeniedError(ForbiddenError):
    """403 error for denied workspace access (code ``WORKSPACE_PERMISSION_DENIED``).

    A truthy ``workspace_id`` is forwarded as the ``workspace_id`` message
    parameter; a falsy one is omitted.
    """

    def __init__(self, workspace_id: Optional[str] = None, **params):
        extra = {"workspace_id": workspace_id} if workspace_id else {}
        super().__init__(
            error_key="errors.workspace.permission_denied",
            error_code="WORKSPACE_PERMISSION_DENIED",
            **params,
            **extra,
        )
class UserNotFoundError(NotFoundError):
    """404 error for a missing user (code ``USER_NOT_FOUND``).

    A truthy ``user_id`` is forwarded as the ``user_id`` message parameter;
    a falsy one is omitted.
    """

    def __init__(self, user_id: Optional[str] = None, **params):
        extra = {"user_id": user_id} if user_id else {}
        super().__init__(
            error_key="errors.user.not_found",
            error_code="USER_NOT_FOUND",
            **params,
            **extra,
        )
class UserAlreadyExistsError(ConflictError):
    """409 error for a duplicate user (code ``USER_ALREADY_EXISTS``).

    A truthy ``identifier`` is forwarded as the ``identifier`` message
    parameter; a falsy one is omitted.
    """

    def __init__(self, identifier: Optional[str] = None, **params):
        extra = {"identifier": identifier} if identifier else {}
        super().__init__(
            error_key="errors.user.already_exists",
            error_code="USER_ALREADY_EXISTS",
            **params,
            **extra,
        )
class TenantNotFoundError(NotFoundError):
    """404 error for a missing tenant (code ``TENANT_NOT_FOUND``).

    A truthy ``tenant_id`` is forwarded as the ``tenant_id`` message
    parameter; a falsy one is omitted.
    """

    def __init__(self, tenant_id: Optional[str] = None, **params):
        extra = {"tenant_id": tenant_id} if tenant_id else {}
        super().__init__(
            error_key="errors.tenant.not_found",
            error_code="TENANT_NOT_FOUND",
            **params,
            **extra,
        )
class TenantSuspendedError(ForbiddenError):
    """403 error for a suspended tenant (code ``TENANT_SUSPENDED``).

    A truthy ``tenant_id`` is forwarded as the ``tenant_id`` message
    parameter; a falsy one is omitted.
    """

    def __init__(self, tenant_id: Optional[str] = None, **params):
        extra = {"tenant_id": tenant_id} if tenant_id else {}
        super().__init__(
            error_key="errors.tenant.suspended",
            error_code="TENANT_SUSPENDED",
            **params,
            **extra,
        )
class InvalidCredentialsError(UnauthorizedError):
    """401 error for invalid credentials (code ``INVALID_CREDENTIALS``)."""

    def __init__(self, **params):
        # Positional order matches UnauthorizedError: key, code.
        super().__init__("errors.auth.invalid_credentials", "INVALID_CREDENTIALS", **params)
class TokenExpiredError(UnauthorizedError):
    """401 error for an expired token (code ``TOKEN_EXPIRED``)."""

    def __init__(self, **params):
        # Positional order matches UnauthorizedError: key, code.
        super().__init__("errors.auth.token_expired", "TOKEN_EXPIRED", **params)
class TokenInvalidError(UnauthorizedError):
    """401 error for an invalid token (code ``TOKEN_INVALID``)."""

    def __init__(self, **params):
        # Positional order matches UnauthorizedError: key, code.
        super().__init__("errors.auth.token_invalid", "TOKEN_INVALID", **params)
class FileNotFoundError(NotFoundError):
    """404 error for a missing file (code ``FILE_NOT_FOUND``).

    NOTE: within this module the class shadows the builtin
    ``FileNotFoundError``. A truthy ``file_id`` is forwarded as the
    ``file_id`` message parameter; a falsy one is omitted.
    """

    def __init__(self, file_id: Optional[str] = None, **params):
        extra = {"file_id": file_id} if file_id else {}
        super().__init__(
            error_key="errors.file.not_found",
            error_code="FILE_NOT_FOUND",
            **params,
            **extra,
        )
class FileTooLargeError(BadRequestError):
    """400 error for an oversized file (code ``FILE_TOO_LARGE``).

    A truthy ``max_size`` is forwarded as the ``max_size`` message
    parameter; a falsy one is omitted.
    """

    def __init__(self, max_size: Optional[str] = None, **params):
        extra = {"max_size": max_size} if max_size else {}
        super().__init__(
            error_key="errors.file.too_large",
            error_code="FILE_TOO_LARGE",
            **params,
            **extra,
        )
class InvalidFileTypeError(BadRequestError):
    """400 error for an unsupported file type (code ``INVALID_FILE_TYPE``).

    A truthy ``file_type`` is forwarded as the ``file_type`` message
    parameter; a falsy one is omitted.
    """

    def __init__(self, file_type: Optional[str] = None, **params):
        extra = {"file_type": file_type} if file_type else {}
        super().__init__(
            error_key="errors.file.invalid_type",
            error_code="INVALID_FILE_TYPE",
            **params,
            **extra,
        )
class RateLimitExceededError(I18nException):
    """429 error for an exceeded rate limit (code ``RATE_LIMIT_EXCEEDED``)."""

    def __init__(self, **params):
        # Positional order matches I18nException: key, status, code.
        super().__init__(
            "errors.api.rate_limit_exceeded",
            429,
            "RATE_LIMIT_EXCEEDED",
            **params
        )
class QuotaExceededError(ForbiddenError):
    """403 error for an exceeded quota (code ``QUOTA_EXCEEDED``).

    A truthy ``resource`` is forwarded as the ``resource`` message
    parameter; a falsy one is omitted.
    """

    def __init__(self, resource: Optional[str] = None, **params):
        extra = {"resource": resource} if resource else {}
        super().__init__(
            error_key="errors.api.quota_exceeded",
            error_code="QUOTA_EXCEEDED",
            **params,
            **extra,
        )

199
api/app/i18n/loader.py Normal file
View File

@@ -0,0 +1,199 @@
"""
Translation file loader for i18n system.
This module handles loading translation files from multiple directories
(community edition + enterprise edition) and provides hot reload support.
"""
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class TranslationLoader:
"""
Translation file loader that supports:
- Loading from multiple directories (community + enterprise)
- Hot reload of translation files
- Automatic locale detection
"""
def __init__(self, locales_dirs: Optional[List[str]] = None):
    """
    Set up the loader with the directories to read translations from.

    Args:
        locales_dirs: Directories containing translation files. When
            None, the list comes from ``_detect_locales_dirs``.
    """
    sources = self._detect_locales_dirs() if locales_dirs is None else locales_dirs
    self.locales_dirs = list(map(Path, sources))
    logger.info(f"TranslationLoader initialized with directories: {self.locales_dirs}")
def _detect_locales_dirs(self) -> List[str]:
    """
    Work out the translation directories from application settings.

    The core directory comes first. The premium directory follows: either
    the configured ``I18N_PREMIUM_LOCALES_DIR`` or, when that setting is
    empty, the relative path ``premium/locales``. Only directories that
    exist are returned.

    Returns:
        List of translation directory paths
    """
    from app.core.config import settings

    found: List[str] = []

    # 1. Core locales directory (community edition, required)
    core_dir = Path(settings.I18N_CORE_LOCALES_DIR)
    if not core_dir.exists():
        logger.warning(f"Core locales directory not found: {core_dir}")
    else:
        found.append(str(core_dir))
        logger.debug(f"Found core locales directory: {core_dir}")

    # 2. Premium locales directory (enterprise edition, optional)
    configured_premium = settings.I18N_PREMIUM_LOCALES_DIR
    if configured_premium:
        premium_dir = Path(configured_premium)
        if premium_dir.exists():
            found.append(str(premium_dir))
            logger.debug(f"Found premium locales directory: {premium_dir}")
    else:
        # Nothing configured: probe the conventional relative location.
        premium_dir = Path("premium/locales")
        if premium_dir.exists():
            found.append(str(premium_dir))
            logger.debug(f"Auto-detected premium locales directory: {premium_dir}")

    if not found:
        logger.error("No translation directories found!")

    return found
def get_available_locales(self) -> List[str]:
"""
Get list of all available locales across all directories.
Returns:
List of locale codes (e.g., ['zh', 'en'])
"""
locales = set()
for locales_dir in self.locales_dirs:
if not locales_dir.exists():
continue
for locale_dir in locales_dir.iterdir():
if locale_dir.is_dir() and not locale_dir.name.startswith('.'):
locales.add(locale_dir.name)
return sorted(list(locales))
def load_locale(self, locale: str) -> Dict[str, Any]:
"""
Load all translation files for a specific locale from all directories.
Translation files are merged with priority:
- Later directories override earlier directories
- Enterprise translations override community translations
Args:
locale: Locale code (e.g., 'zh', 'en')
Returns:
Dictionary of translations organized by namespace
Format: {namespace: {key: value, ...}, ...}
"""
translations = {}
# Load from each directory in order (later directories override earlier)
for locales_dir in self.locales_dirs:
locale_dir = locales_dir / locale
if not locale_dir.exists():
logger.debug(f"Locale directory not found: {locale_dir}")
continue
# Load all JSON files in this locale directory
for json_file in locale_dir.glob("*.json"):
namespace = json_file.stem
try:
with open(json_file, "r", encoding="utf-8") as f:
new_translations = json.load(f)
# Merge translations (deep merge)
if namespace in translations:
translations[namespace] = self._deep_merge(
translations[namespace],
new_translations
)
logger.debug(
f"Merged translations: {locale}/{namespace} from {json_file}"
)
else:
translations[namespace] = new_translations
logger.debug(
f"Loaded translations: {locale}/{namespace} from {json_file}"
)
except json.JSONDecodeError as e:
logger.error(
f"Failed to parse JSON file {json_file}: {e}"
)
except Exception as e:
logger.error(
f"Failed to load translation file {json_file}: {e}"
)
if not translations:
logger.warning(f"No translations found for locale: {locale}")
return translations
def reload(self, locale: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
"""
Reload translation files.
Args:
locale: Specific locale to reload. If None, reloads all locales.
Returns:
Dictionary of reloaded translations
Format: {locale: {namespace: {key: value}}}
"""
if locale:
logger.info(f"Reloading translations for locale: {locale}")
return {locale: self.load_locale(locale)}
else:
logger.info("Reloading all translations")
all_translations = {}
for loc in self.get_available_locales():
all_translations[loc] = self.load_locale(loc)
return all_translations
def _deep_merge(self, base: Dict, override: Dict) -> Dict:
"""
Deep merge two dictionaries.
Args:
base: Base dictionary
override: Dictionary with values to override
Returns:
Merged dictionary
"""
result = base.copy()
for key, value in override.items():
if (
key in result
and isinstance(result[key], dict)
and isinstance(value, dict)
):
result[key] = self._deep_merge(result[key], value)
else:
result[key] = value
return result

382
api/app/i18n/logger.py Normal file
View File

@@ -0,0 +1,382 @@
"""
Translation logging for i18n system.
This module provides:
- TranslationLogger for recording missing translations
- Missing translation report generation
- Integration with existing logging system
- Structured logging for translation events
"""
import logging
from typing import Dict, List, Optional, Set
from datetime import datetime
from collections import defaultdict
from pathlib import Path
import json
from app.core.logging_config import get_logger
logger = get_logger(__name__)
class TranslationLogger:
    """
    Logger for translation events and missing translations.

    Features:
    - Records missing translations with context
    - Generates missing translation reports
    - Integrates with existing logging system
    - Provides structured logging for analysis
    """

    def __init__(self, log_file: Optional[str] = None):
        """
        Initialize translation logger.

        Creates the log directory if needed and attaches a dedicated file
        handler to the ``i18n.missing_translations`` logger.

        Args:
            log_file: Optional custom log file path for missing translations
        """
        self.log_file = log_file or "logs/i18n/missing_translations.log"
        # Unique missing keys per locale.
        self._missing_translations: Dict[str, Set[str]] = defaultdict(set)
        # Chronological list of missing-translation events (oldest first).
        self._missing_with_context: List[Dict] = []
        self._max_context_entries = 10000  # Keep last 10k entries

        # Ensure log directory exists
        log_path = Path(self.log_file)
        log_path.parent.mkdir(parents=True, exist_ok=True)

        # Create dedicated file handler for missing translations
        self._file_handler = logging.FileHandler(
            self.log_file,
            encoding='utf-8'
        )
        self._file_handler.setLevel(logging.WARNING)

        # Create formatter
        formatter = logging.Formatter(
            fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        self._file_handler.setFormatter(formatter)

        # Create dedicated logger for missing translations
        self._logger = logging.getLogger("i18n.missing_translations")
        self._logger.setLevel(logging.WARNING)
        self._logger.addHandler(self._file_handler)
        self._logger.propagate = False  # Don't propagate to root logger

        logger.info(f"TranslationLogger initialized with log file: {self.log_file}")

    def log_missing_translation(
        self,
        key: str,
        locale: str,
        context: Optional[Dict] = None
    ):
        """
        Log a missing translation.

        Args:
            key: Translation key that was not found
            locale: Locale code
            context: Optional context information (e.g., request path, user info)
        """
        # Add to missing set
        self._missing_translations[locale].add(key)

        # Create context entry
        entry = {
            "timestamp": datetime.now().isoformat(),
            "key": key,
            "locale": locale,
            "context": context or {}
        }

        # Keep only recent entries to avoid memory bloat
        if len(self._missing_with_context) >= self._max_context_entries:
            self._missing_with_context.pop(0)
        self._missing_with_context.append(entry)

        # Log to file
        context_str = f" (context: {context})" if context else ""
        self._logger.warning(
            f"Missing translation: key='{key}', locale='{locale}'{context_str}"
        )

    def log_translation_error(
        self,
        error_type: str,
        message: str,
        key: Optional[str] = None,
        locale: Optional[str] = None,
        context: Optional[Dict] = None
    ):
        """
        Log a translation error.

        Args:
            error_type: Type of error (e.g., "format_error", "parameter_missing")
            message: Error message
            key: Translation key (optional)
            locale: Locale code (optional)
            context: Optional context information; appended to the log line
                when provided
        """
        # The context used to be collected into an unused dict and never
        # written; include it in the same form log_missing_translation() uses.
        context_str = f" (context: {context})" if context else ""
        self._logger.error(
            f"Translation error: {error_type} - {message} "
            f"(key: {key}, locale: {locale}){context_str}"
        )

    def log_translation_success(
        self,
        key: str,
        locale: str,
        duration_ms: Optional[float] = None
    ):
        """
        Log a successful translation (debug level).

        Args:
            key: Translation key
            locale: Locale code
            duration_ms: Optional duration in milliseconds
        """
        # Compare against None: a measured duration of 0.0 ms is still a value.
        duration_str = f" ({duration_ms:.3f}ms)" if duration_ms is not None else ""
        logger.debug(
            f"Translation success: key='{key}', locale='{locale}'{duration_str}"
        )

    def get_missing_translations(
        self,
        locale: Optional[str] = None
    ) -> Dict[str, List[str]]:
        """
        Get missing translations.

        Args:
            locale: Specific locale (optional, returns all if None)

        Returns:
            Dictionary of missing translations by locale (keys sorted)
        """
        if locale:
            # .get() so that querying an unknown locale does not create an
            # empty entry in the defaultdict.
            return {locale: sorted(self._missing_translations.get(locale, set()))}
        return {
            loc: sorted(keys)
            for loc, keys in self._missing_translations.items()
        }

    def get_missing_with_context(
        self,
        locale: Optional[str] = None,
        limit: Optional[int] = None
    ) -> List[Dict]:
        """
        Get missing translations with context.

        Args:
            locale: Filter by locale (optional)
            limit: Maximum number of entries to return (optional). The most
                recent entries are kept; 0 or a negative value returns none.

        Returns:
            New list of missing translation entries with context (oldest first)
        """
        # Always work on a copy so callers cannot mutate the internal buffer.
        entries = list(self._missing_with_context)

        # Filter by locale if specified
        if locale:
            entries = [e for e in entries if e["locale"] == locale]

        # Apply limit if specified. Note entries[-0:] would return everything,
        # so non-positive limits are handled explicitly.
        if limit is not None:
            entries = entries[-limit:] if limit > 0 else []
        return entries

    def generate_report(
        self,
        locale: Optional[str] = None,
        output_file: Optional[str] = None
    ) -> Dict:
        """
        Generate a missing translation report.

        Args:
            locale: Specific locale (optional, generates for all if None)
            output_file: Optional file path to save report as JSON

        Returns:
            Report dictionary
        """
        missing = self.get_missing_translations(locale)
        report = {
            "generated_at": datetime.now().isoformat(),
            "total_missing": sum(len(keys) for keys in missing.values()),
            "missing_by_locale": {
                loc: {
                    "count": len(keys),
                    "keys": keys
                }
                for loc, keys in missing.items()
            },
            "recent_context": self.get_missing_with_context(locale, limit=100)
        }

        # Save to file if specified
        if output_file:
            output_path = Path(output_file)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(report, f, indent=2, ensure_ascii=False)
            logger.info(f"Missing translation report saved to: {output_file}")
        return report

    def get_statistics(self) -> Dict:
        """
        Get statistics about missing translations.

        Returns:
            Dictionary with statistics
        """
        total_missing = sum(len(keys) for keys in self._missing_translations.values())

        # Count by namespace (the part of the key before the first dot)
        namespace_counts = defaultdict(int)
        for keys in self._missing_translations.values():
            for key in keys:
                namespace = key.split('.')[0] if '.' in key else 'unknown'
                namespace_counts[namespace] += 1

        return {
            "total_missing": total_missing,
            "locales_affected": len(self._missing_translations),
            "missing_by_locale": {
                loc: len(keys)
                for loc, keys in self._missing_translations.items()
            },
            "missing_by_namespace": dict(namespace_counts),
            "total_context_entries": len(self._missing_with_context)
        }

    def clear(self, locale: Optional[str] = None):
        """
        Clear missing translation records.

        Args:
            locale: Specific locale to clear (optional, clears all if None)
        """
        if locale:
            self._missing_translations.pop(locale, None)
            self._missing_with_context = [
                e for e in self._missing_with_context
                if e["locale"] != locale
            ]
            logger.info(f"Cleared missing translations for locale: {locale}")
        else:
            self._missing_translations.clear()
            self._missing_with_context.clear()
            logger.info("Cleared all missing translations")

    def export_to_json(self, output_file: str):
        """
        Export all missing translations to JSON file.

        Args:
            output_file: Output file path
        """
        data = {
            "exported_at": datetime.now().isoformat(),
            "missing_translations": self.get_missing_translations(),
            "statistics": self.get_statistics(),
            "recent_context": self.get_missing_with_context(limit=1000)
        }
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        logger.info(f"Missing translations exported to: {output_file}")

    def __del__(self):
        """Cleanup file handler on deletion."""
        try:
            handler = getattr(self, '_file_handler', None)
            if handler is not None:
                # Detach before closing so the shared logger never tries to
                # emit through an already-closed handler.
                self._logger.removeHandler(handler)
                handler.close()
        except Exception:
            # Best effort only: finalizers may run during interpreter
            # shutdown, when logging internals are already torn down.
            pass
# Process-wide TranslationLogger singleton; stays None until first requested.
_translation_logger: Optional[TranslationLogger] = None


def get_translation_logger() -> TranslationLogger:
    """
    Return the shared TranslationLogger, creating it on first use.

    Returns:
        The module-level TranslationLogger instance
    """
    global _translation_logger
    if _translation_logger is not None:
        return _translation_logger
    # First call: build the instance and remember it for later callers.
    _translation_logger = TranslationLogger()
    return _translation_logger
def log_missing_translation(
    key: str,
    locale: str,
    context: Optional[Dict] = None
):
    """
    Record a missing translation on the shared TranslationLogger.

    Convenience wrapper around ``get_translation_logger()``.

    Args:
        key: Translation key that could not be resolved
        locale: Locale code the lookup was made for
        context: Optional context information
    """
    get_translation_logger().log_missing_translation(
        key=key,
        locale=locale,
        context=context,
    )
def log_translation_error(
    error_type: str,
    message: str,
    key: Optional[str] = None,
    locale: Optional[str] = None,
    context: Optional[Dict] = None
):
    """
    Record a translation error on the shared TranslationLogger.

    Convenience wrapper around ``get_translation_logger()``.

    Args:
        error_type: Type of error
        message: Error message
        key: Translation key (optional)
        locale: Locale code (optional)
        context: Optional context information
    """
    shared_logger = get_translation_logger()
    shared_logger.log_translation_error(
        error_type=error_type,
        message=message,
        key=key,
        locale=locale,
        context=context,
    )

337
api/app/i18n/metrics.py Normal file
View File

@@ -0,0 +1,337 @@
"""
Performance monitoring and metrics for i18n system.
This module provides:
- Translation request counters
- Translation timing metrics
- Missing translation tracking
- Performance monitoring decorators
- Prometheus-compatible metrics
"""
import logging
import time
from functools import wraps
from typing import Any, Callable, Dict, Optional
from collections import defaultdict
from datetime import datetime
logger = logging.getLogger(__name__)
class TranslationMetrics:
    """
    Metrics collector for translation operations.

    Tracks:
    - Translation request counts
    - Translation timing (latency)
    - Missing translations
    - Cache performance
    - Locale usage

    All data is held in memory on the instance; timing samples are limited
    to the most recent ``_max_timing_samples`` entries.
    """

    def __init__(self):
        """Initialize metrics collector."""
        # Function-scope import so this class does not depend on a change to
        # the module's import block.
        from collections import deque

        # Request counters by locale
        self._request_counts: Dict[str, int] = defaultdict(int)
        # Missing translation tracker
        self._missing_translations: Dict[str, set] = defaultdict(set)
        # Timing metrics (in milliseconds). A bounded deque drops the oldest
        # sample in O(1); list.pop(0) shifted the whole buffer on every call.
        self._max_timing_samples = 10000  # Keep last 10k samples
        self._timing_data = deque(maxlen=self._max_timing_samples)
        # Locale usage
        self._locale_usage: Dict[str, int] = defaultdict(int)
        # Namespace usage
        self._namespace_usage: Dict[str, int] = defaultdict(int)
        # Error counts
        self._error_counts: Dict[str, int] = defaultdict(int)
        # Start time
        self._start_time = datetime.now()
        logger.info("TranslationMetrics initialized")

    def record_request(self, locale: str, namespace: Optional[str] = None):
        """
        Record a translation request.

        Args:
            locale: Locale code
            namespace: Translation namespace (optional)
        """
        self._request_counts[locale] += 1
        self._locale_usage[locale] += 1
        if namespace:
            self._namespace_usage[namespace] += 1

    def record_missing(self, key: str, locale: str):
        """
        Record a missing translation.

        Args:
            key: Translation key
            locale: Locale code
        """
        self._missing_translations[locale].add(key)
        logger.debug(f"Missing translation recorded: {key} (locale: {locale})")

    def record_timing(self, duration_ms: float, locale: str, operation: str = "translate"):
        """
        Record translation operation timing.

        Args:
            duration_ms: Duration in milliseconds
            locale: Locale code
            operation: Operation type
        """
        # The deque's maxlen evicts the oldest sample automatically.
        self._timing_data.append({
            "duration_ms": duration_ms,
            "locale": locale,
            "operation": operation,
            "timestamp": time.time()
        })

    def record_error(self, error_type: str):
        """
        Record an error.

        Args:
            error_type: Type of error
        """
        self._error_counts[error_type] += 1

    def get_summary(self) -> Dict[str, Any]:
        """
        Get metrics summary.

        Returns:
            Dictionary with metrics summary
        """
        total_requests = sum(self._request_counts.values())
        total_missing = sum(len(keys) for keys in self._missing_translations.values())

        # Calculate timing statistics
        timing_stats = self._calculate_timing_stats()

        # Calculate uptime
        uptime_seconds = (datetime.now() - self._start_time).total_seconds()

        return {
            "uptime_seconds": round(uptime_seconds, 2),
            "total_requests": total_requests,
            "requests_per_locale": dict(self._request_counts),
            "total_missing_translations": total_missing,
            "missing_by_locale": {
                locale: len(keys)
                for locale, keys in self._missing_translations.items()
            },
            "timing": timing_stats,
            "locale_usage": dict(self._locale_usage),
            "namespace_usage": dict(self._namespace_usage),
            "error_counts": dict(self._error_counts)
        }

    @staticmethod
    def _nearest_rank(sorted_values: list, percent: int) -> float:
        """Return the ``percent``-th percentile of a non-empty sorted list.

        Uses the nearest-rank definition: the value at 1-based rank
        ceil(n * percent / 100), computed with integer arithmetic.
        """
        rank = (len(sorted_values) * percent + 99) // 100
        return sorted_values[max(rank, 1) - 1]

    def _calculate_timing_stats(self) -> Dict[str, Any]:
        """
        Calculate timing statistics.

        Returns:
            Dictionary with count, average, min, max and p50/p95/p99 in
            milliseconds (all zero when no samples were recorded)
        """
        if not self._timing_data:
            return {
                "count": 0,
                "avg_ms": 0,
                "min_ms": 0,
                "max_ms": 0,
                "p50_ms": 0,
                "p95_ms": 0,
                "p99_ms": 0
            }

        durations = sorted(d["duration_ms"] for d in self._timing_data)
        count = len(durations)
        avg = sum(durations) / count

        # The previous index int(count * p) pointed one rank too high
        # (e.g. the p50 of two samples was the larger one).
        return {
            "count": count,
            "avg_ms": round(avg, 3),
            "min_ms": round(durations[0], 3),
            "max_ms": round(durations[-1], 3),
            "p50_ms": round(self._nearest_rank(durations, 50), 3),
            "p95_ms": round(self._nearest_rank(durations, 95), 3),
            "p99_ms": round(self._nearest_rank(durations, 99), 3)
        }

    def get_missing_translations(self, locale: Optional[str] = None) -> Dict[str, list]:
        """
        Get missing translations.

        Args:
            locale: Specific locale (optional, returns all if None)

        Returns:
            Dictionary of missing translations by locale
        """
        if locale:
            return {locale: list(self._missing_translations.get(locale, set()))}
        return {
            locale: list(keys)
            for locale, keys in self._missing_translations.items()
        }

    def reset(self):
        """Reset all metrics."""
        self._request_counts.clear()
        self._missing_translations.clear()
        self._timing_data.clear()
        self._locale_usage.clear()
        self._namespace_usage.clear()
        self._error_counts.clear()
        self._start_time = datetime.now()
        logger.info("Metrics reset")

    @staticmethod
    def _escape_label_value(value: Any) -> str:
        """Escape a value for use inside a quoted Prometheus label.

        Backslash, double quote and line feed are escaped so an unusual
        locale or error name cannot corrupt the exposition output.
        """
        return (
            str(value)
            .replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
        )

    def export_prometheus(self) -> str:
        """
        Export metrics in Prometheus format.

        Returns:
            Prometheus-formatted metrics string, terminated by a line feed
        """
        escape = self._escape_label_value
        lines = []

        # Translation requests counter
        lines.append("# HELP i18n_translation_requests_total Total number of translation requests")
        lines.append("# TYPE i18n_translation_requests_total counter")
        for locale, count in self._request_counts.items():
            lines.append(f'i18n_translation_requests_total{{locale="{escape(locale)}"}} {count}')

        # Missing translations counter
        lines.append("# HELP i18n_missing_translations_total Total number of missing translations")
        lines.append("# TYPE i18n_missing_translations_total counter")
        for locale, keys in self._missing_translations.items():
            lines.append(f'i18n_missing_translations_total{{locale="{escape(locale)}"}} {len(keys)}')

        # Timing metrics
        timing_stats = self._calculate_timing_stats()
        lines.append("# HELP i18n_translation_duration_ms Translation operation duration in milliseconds")
        lines.append("# TYPE i18n_translation_duration_ms summary")
        lines.append(f'i18n_translation_duration_ms{{quantile="0.5"}} {timing_stats["p50_ms"]}')
        lines.append(f'i18n_translation_duration_ms{{quantile="0.95"}} {timing_stats["p95_ms"]}')
        lines.append(f'i18n_translation_duration_ms{{quantile="0.99"}} {timing_stats["p99_ms"]}')
        lines.append(f'i18n_translation_duration_ms_sum {sum(d["duration_ms"] for d in self._timing_data)}')
        lines.append(f'i18n_translation_duration_ms_count {timing_stats["count"]}')

        # Error counter
        lines.append("# HELP i18n_errors_total Total number of i18n errors")
        lines.append("# TYPE i18n_errors_total counter")
        for error_type, count in self._error_counts.items():
            lines.append(f'i18n_errors_total{{type="{escape(error_type)}"}} {count}')

        # The text exposition format requires the last line to end with "\n".
        return "\n".join(lines) + "\n"
# Process-wide TranslationMetrics singleton; stays None until first requested.
_metrics: Optional[TranslationMetrics] = None


def get_metrics() -> TranslationMetrics:
    """
    Return the shared TranslationMetrics, creating it on first use.

    Returns:
        The module-level TranslationMetrics instance
    """
    global _metrics
    if _metrics is not None:
        return _metrics
    # First call: build the collector and remember it for later callers.
    _metrics = TranslationMetrics()
    return _metrics
def monitor_performance(operation: str = "translate"):
    """
    Decorator to monitor translation operation performance.

    On success the call duration is recorded on the shared metrics instance
    together with the call's ``locale`` argument. If the wrapped function
    raises, the exception type name is recorded as an error and the
    exception is re-raised unchanged.

    The locale is taken from the ``locale`` keyword argument, or from the
    positional argument bound to the wrapped function's ``locale`` parameter
    (so it works for plain functions and for methods alike). When it cannot
    be determined, or is not a non-empty string, "unknown" is recorded.

    Args:
        operation: Operation name for metrics

    Returns:
        Decorated function

    Example:
        @monitor_performance("translate")
        def translate(key: str, locale: str) -> str:
            ...
    """
    # Function-scope import: only this decorator needs it.
    import inspect

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    def decorator(func: Callable) -> Callable:
        # Resolve once, at decoration time, where "locale" sits among the
        # positional parameters. The old code always read args[1], which is
        # the key (not the locale) for a method such as translate(self, key, locale).
        locale_index = None
        try:
            parameters = list(inspect.signature(func).parameters.values())
        except (TypeError, ValueError):
            # Some callables expose no introspectable signature.
            parameters = []
        for index, parameter in enumerate(parameters):
            if parameter.name == "locale" and parameter.kind in positional_kinds:
                locale_index = index
                break

        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                # Record error
                get_metrics().record_error(type(e).__name__)
                raise

            # Record timing
            duration_ms = (time.perf_counter() - start_time) * 1000

            # Extract locale from kwargs first, then from positional args.
            # (Previously the "unknown" default made the positional
            # fallback unreachable.)
            locale = kwargs.get("locale")
            if locale is None and locale_index is not None and len(args) > locale_index:
                locale = args[locale_index]
            if not isinstance(locale, str) or not locale:
                locale = "unknown"

            get_metrics().record_timing(duration_ms, locale, operation)
            return result

        return wrapper

    return decorator
def track_missing_translation(key: str, locale: str):
    """
    Record a missing translation on the shared metrics instance.

    Args:
        key: Translation key that could not be resolved
        locale: Locale code the lookup was made for
    """
    get_metrics().record_missing(key=key, locale=locale)
def track_translation_request(locale: str, namespace: str = None):
    """
    Record a translation request on the shared metrics instance.

    Args:
        locale: Locale code of the request
        namespace: Translation namespace (optional)
    """
    get_metrics().record_request(locale=locale, namespace=namespace)

202
api/app/i18n/middleware.py Normal file
View File

@@ -0,0 +1,202 @@
"""
Language detection middleware for i18n system.
This middleware determines the language to use for each request based on:
1. Query parameter (?lang=en)
2. Accept-Language HTTP header
3. User language preference (from database)
4. Tenant default language
5. System default language
The detected language is injected into request.state.language and
added to the response Content-Language header.
"""
import logging
import re
from typing import Optional
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
class LanguageMiddleware(BaseHTTPMiddleware):
    """
    Language detection middleware.

    Determines the language for each request based on multiple sources
    with a clear priority order, validates the language is supported,
    and injects it into the request context.
    """

    # Matches a quality parameter such as "q=0.8"; compiled once and shared
    # by every request instead of being rebuilt per header item.
    _QUALITY_PATTERN = re.compile(r"\s*q\s*=\s*([\d.]+)", re.IGNORECASE)

    async def dispatch(self, request: Request, call_next):
        """
        Process the request and determine the language.

        The resolved language is stored on ``request.state.language``, set as
        the current locale for exception handling, and echoed in the
        response's ``Content-Language`` header.

        Args:
            request: The incoming request
            call_next: The next middleware/handler in the chain

        Returns:
            Response with Content-Language header added
        """
        # Determine the language for this request
        language = await self._determine_language(request)

        # Validate language is supported
        from app.core.config import settings
        if language not in settings.I18N_SUPPORTED_LANGUAGES:
            logger.warning(
                f"Unsupported language '{language}' requested, "
                f"falling back to default: {settings.I18N_DEFAULT_LANGUAGE}"
            )
            language = settings.I18N_DEFAULT_LANGUAGE

        # Inject language into request state
        request.state.language = language

        # Also set in context variable for exception handling
        from app.i18n.exceptions import set_current_locale
        set_current_locale(language)

        logger.debug(f"Request language set to: {language}")

        # Process the request
        response = await call_next(request)

        # Add Content-Language header to response
        response.headers["Content-Language"] = language
        return response

    async def _determine_language(self, request: Request) -> str:
        """
        Determine the language to use based on priority order.

        Priority:
        1. Query parameter (?lang=en)
        2. Accept-Language HTTP header
        3. User language preference (from database)
        4. Tenant default language
        5. System default language

        The returned value is not guaranteed to be supported (a query
        parameter or stored preference is returned as-is); dispatch()
        performs the final validation.

        Args:
            request: The incoming request

        Returns:
            Language code (e.g., "zh", "en")
        """
        from app.core.config import settings

        # 1. Check query parameter (?lang=en)
        if "lang" in request.query_params:
            lang = request.query_params["lang"].strip().lower()
            if lang:
                logger.debug(f"Language from query parameter: {lang}")
                return lang

        # 2. Check Accept-Language HTTP header
        if "Accept-Language" in request.headers:
            lang = self._parse_accept_language(
                request.headers["Accept-Language"]
            )
            if lang:
                logger.debug(f"Language from Accept-Language header: {lang}")
                return lang

        # 3. Check user language preference (requires authentication)
        # Note: This assumes user is already loaded into request.state by auth middleware
        if hasattr(request.state, "user") and request.state.user:
            user = request.state.user
            if hasattr(user, "preferred_language") and user.preferred_language:
                logger.debug(
                    f"Language from user preference: {user.preferred_language}"
                )
                return user.preferred_language

        # 4. Check tenant default language
        # Note: This assumes tenant is already loaded into request.state
        if hasattr(request.state, "tenant") and request.state.tenant:
            tenant = request.state.tenant
            if hasattr(tenant, "default_language") and tenant.default_language:
                logger.debug(
                    f"Language from tenant default: {tenant.default_language}"
                )
                return tenant.default_language

        # 5. Fall back to system default language
        logger.debug(
            f"Using system default language: {settings.I18N_DEFAULT_LANGUAGE}"
        )
        return settings.I18N_DEFAULT_LANGUAGE

    def _parse_accept_language(self, header: str) -> Optional[str]:
        """
        Parse the Accept-Language HTTP header.

        The Accept-Language header format:
            Accept-Language: zh-CN,zh;q=0.9,en;q=0.8,en-US;q=0.7

        This method:
        1. Parses all language codes and their quality values
        2. Extracts the base language code (zh-CN -> zh)
        3. Drops entries with a quality of 0 (explicitly not acceptable)
        4. Sorts by quality value (higher first, header order kept on ties)
        5. Returns the first supported language

        Args:
            header: Accept-Language header value

        Returns:
            Language code if found and supported, None otherwise

        Examples:
            _parse_accept_language("zh-CN,zh;q=0.9,en;q=0.8")
            # => "zh" (if zh is supported)
            _parse_accept_language("en-US,en;q=0.9")
            # => "en" (if en is supported)
            _parse_accept_language("en;q=0")
            # => None (en is explicitly refused by the client)
        """
        from app.core.config import settings

        if not header:
            return None

        # Parse language preferences with quality values
        languages = []
        for item in header.split(","):
            item = item.strip()
            if not item:
                continue

            # Split language code and its ";"-separated parameters
            parts = item.split(";")
            lang_code = parts[0].strip()

            # Extract base language code (zh-CN -> zh, en-US -> en)
            base_lang = lang_code.split("-")[0].lower()

            # Extract quality value (default: 1.0). Every parameter is
            # checked, not only the first one after the language tag.
            quality = 1.0
            for param in parts[1:]:
                q_match = self._QUALITY_PATTERN.match(param)
                if q_match:
                    try:
                        quality = float(q_match.group(1))
                    except ValueError:
                        # Malformed number such as "q=..": keep the default.
                        quality = 1.0
                    break

            # q=0 means "not acceptable"; such a language must never be chosen.
            if quality <= 0:
                continue

            languages.append((base_lang, quality))

        # Sort by quality value (descending); list.sort is stable, so equal
        # qualities keep the order in which the client listed them.
        languages.sort(key=lambda x: x[1], reverse=True)

        # Return the first supported language
        for lang_code, _ in languages:
            if lang_code in settings.I18N_SUPPORTED_LANGUAGES:
                return lang_code
        return None

221
api/app/i18n/serializers.py Normal file
View File

@@ -0,0 +1,221 @@
"""
国际化响应序列化器
提供基础的 I18nResponseMixin 类,用于为 API 响应添加国际化字段。
"""
from typing import Any, Dict, List, Union
from pydantic import BaseModel
class I18nResponseMixin:
    """Mixin that adds internationalized fields to response data.

    For every configured enum field, a translated ``<field>_display`` entry
    is added next to the original value.

    Usage:
        1. Inherit from this class
        2. Override _get_enum_fields() to declare the enum fields to translate
        3. Call serialize_with_i18n() to serialize the data

    Example:
        class WorkspaceSerializer(I18nResponseMixin):
            def _get_enum_fields(self) -> Dict[str, str]:
                return {
                    "role": "workspace_role",
                    "status": "workspace_status"
                }

            def serialize(self, workspace: Workspace, locale: str = "zh") -> Dict:
                data = {
                    "id": str(workspace.id),
                    "name": workspace.name,
                    "role": workspace.role,
                    "status": workspace.status
                }
                return self.serialize_with_i18n(data, locale)
    """

    def serialize_with_i18n(
        self,
        data: Any,
        locale: str = "zh"
    ) -> Union[Dict, List[Dict], Any]:
        """Serialize data and add internationalized fields.

        Args:
            data: Data to serialize (dict, list, or Pydantic model). List
                items that are dicts or Pydantic models are serialized
                individually; other items are kept unchanged.
            locale: Language code

        Returns:
            The serialized data including ``_display`` fields; values of any
            other type are returned unchanged
        """
        # Convert a Pydantic model to a dict first
        if isinstance(data, BaseModel):
            data = data.model_dump()

        if isinstance(data, dict):
            return self._serialize_dict(data, locale)
        if isinstance(data, list):
            return [self._serialize_item(item, locale) for item in data]
        return data

    def _serialize_item(self, item: Any, locale: str) -> Any:
        """Serialize one list element; non-dict, non-model items pass through."""
        # Models inside a list used to be returned untouched, so they never
        # received their _display fields.
        if isinstance(item, BaseModel):
            item = item.model_dump()
        if isinstance(item, dict):
            return self._serialize_dict(item, locale)
        return item

    @staticmethod
    def _enum_lookup_key(value: Any) -> str:
        """Return the string used to look up an enum value's translation.

        str() of an Enum member is "ClassName.MEMBER" (also for ``str``
        mix-in enums), which never matches a translation key, so members are
        unwrapped to their underlying value first.
        """
        # Function-scope import so this class does not depend on a change to
        # the module's import block.
        from enum import Enum

        if isinstance(value, Enum):
            value = value.value
        return str(value)

    def _serialize_dict(self, data: Dict, locale: str) -> Dict:
        """Serialize a dict and add the ``_display`` fields.

        Args:
            data: Dictionary data (not modified; a shallow copy is returned)
            locale: Language code

        Returns:
            Copy of the dictionary with ``_display`` fields added
        """
        from app.i18n.service import get_translation_service

        translation_service = get_translation_service()
        result = data.copy()

        # Add a _display entry for every configured enum field that is
        # present and not None.
        for field, enum_type in self._get_enum_fields().items():
            value = result.get(field)
            if value is None:
                continue
            result[f"{field}_display"] = translation_service.translate_enum(
                enum_type=enum_type,
                value=self._enum_lookup_key(value),
                locale=locale
            )
        return result

    def _get_enum_fields(self) -> Dict[str, str]:
        """Return the enum fields that need translation.

        Subclasses override this to return a mapping from field name to enum
        type; the default declares no fields.

        Returns:
            Mapping from field name to enum type,
            e.g. {"role": "workspace_role", "status": "workspace_status"}
        """
        return {}
class WorkspaceSerializer(I18nResponseMixin):
    """Workspace serializer.

    Adds internationalized fields to workspace responses.
    """

    def _get_enum_fields(self) -> Dict[str, str]:
        """Declare the enum fields of a workspace."""
        return dict(role="workspace_role", status="workspace_status")

    def serialize(self, workspace_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
        """Serialize workspace data.

        Args:
            workspace_data: Workspace data (dict or Pydantic model)
            locale: Language code

        Returns:
            Serialized workspace data including internationalized fields
        """
        serialized = self.serialize_with_i18n(workspace_data, locale)
        return serialized

    def serialize_list(self, workspaces: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
        """Serialize a list of workspaces.

        Args:
            workspaces: List of workspaces
            locale: Language code

        Returns:
            List of serialized workspaces
        """
        serialized_workspaces = []
        for workspace in workspaces:
            serialized_workspaces.append(self.serialize(workspace, locale))
        return serialized_workspaces
class WorkspaceMemberSerializer(I18nResponseMixin):
    """Workspace member serializer.

    Adds internationalized fields to workspace member responses.
    """

    def _get_enum_fields(self) -> Dict[str, str]:
        """Declare the enum fields of a workspace member."""
        return dict(role="workspace_role")

    def serialize(self, member_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
        """Serialize workspace member data.

        Args:
            member_data: Member data (dict or Pydantic model)
            locale: Language code

        Returns:
            Serialized member data including internationalized fields
        """
        serialized = self.serialize_with_i18n(member_data, locale)
        return serialized

    def serialize_list(self, members: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
        """Serialize a list of workspace members.

        Args:
            members: List of members
            locale: Language code

        Returns:
            List of serialized members
        """
        serialized_members = []
        for member in members:
            serialized_members.append(self.serialize(member, locale))
        return serialized_members
class WorkspaceInviteSerializer(I18nResponseMixin):
    """Workspace invite serializer.

    Adds internationalized fields to workspace invite responses.
    """

    def _get_enum_fields(self) -> Dict[str, str]:
        """Declare the enum fields of a workspace invite."""
        return dict(status="invite_status", role="workspace_role")

    def serialize(self, invite_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
        """Serialize workspace invite data.

        Args:
            invite_data: Invite data (dict or Pydantic model)
            locale: Language code

        Returns:
            Serialized invite data including internationalized fields
        """
        serialized = self.serialize_with_i18n(invite_data, locale)
        return serialized

    def serialize_list(self, invites: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
        """Serialize a list of workspace invites.

        Args:
            invites: List of invites
            locale: Language code

        Returns:
            List of serialized invites
        """
        serialized_invites = []
        for invite in invites:
            serialized_invites.append(self.serialize(invite, locale))
        return serialized_invites

370
api/app/i18n/service.py Normal file
View File

@@ -0,0 +1,370 @@
"""
Translation service for i18n system.
This module provides the core translation functionality including:
- Translation lookup with fallback mechanism
- Parameterized message support
- Enum value translation
- Memory caching for performance
- Performance monitoring and metrics
"""
import logging
from functools import lru_cache
from typing import Any, Dict, Optional
from app.i18n.loader import TranslationLoader
from app.i18n.cache import TranslationCache
from app.i18n.metrics import get_metrics, monitor_performance, track_missing_translation, track_translation_request
from app.i18n.logger import get_translation_logger
logger = logging.getLogger(__name__)
class TranslationService:
    """
    Translation service backed by an in-memory ``TranslationCache``.

    Provides:
    - Translation lookup through the cache
    - Parameterized message support ({param} syntax)
    - Fallback mechanism (requested locale → fallback locale → key)
    - Enum value translation
    - Loading from one or more locale directories via ``TranslationLoader``
    """

    def __init__(self, locales_dirs: Optional[list] = None):
        """
        Initialize the translation service and preload every available locale.

        Args:
            locales_dirs: List of directories containing translation files.
                Passed straight to ``TranslationLoader``; ``None`` leaves
                directory selection to the loader.
        """
        # Imported here rather than at module level, as in the original code.
        from app.core.config import settings

        self.loader = TranslationLoader(locales_dirs)
        self.default_locale = settings.I18N_DEFAULT_LANGUAGE
        self.fallback_locale = settings.I18N_FALLBACK_LANGUAGE
        self.log_missing = settings.I18N_LOG_MISSING_TRANSLATIONS
        # Stored for callers; not consulted anywhere inside this class.
        self.enable_cache = settings.I18N_ENABLE_TRANSLATION_CACHE

        # Cache with an LRU layer; the size falls back to 1000 when the
        # setting is absent.
        lru_cache_size = getattr(settings, 'I18N_LRU_CACHE_SIZE', 1000)
        self.cache = TranslationCache(
            max_lru_size=lru_cache_size,
            enable_lazy_load=False  # Load all at startup for now
        )

        # Load all translations into the cache before anything can be served.
        self._load_all_locales()

        self.metrics = get_metrics()
        self.translation_logger = get_translation_logger()

        logger.info(
            "TranslationService initialized with default locale: %s, "
            "LRU cache size: %s",
            self.default_locale,
            lru_cache_size,
        )

    def _load_all_locales(self):
        """Load every locale reported by the loader into the cache."""
        available_locales = self.loader.get_available_locales()
        logger.info("Loading translations for locales: %s", available_locales)

        for locale in available_locales:
            locale_data = self.loader.load_locale(locale)
            self.cache.set_locale_data(locale, locale_data)

        logger.info("Loaded %s locales into cache", len(available_locales))

    @staticmethod
    def _split_key(key: str) -> Optional[tuple]:
        """
        Split a dotted translation key into its namespace and nested key path.

        Args:
            key: Translation key (format: "namespace.key.subkey")

        Returns:
            ``(namespace, key_path)`` where ``key_path`` is a list of the
            remaining dot-separated segments, or ``None`` when the key
            contains no dot at all.
        """
        namespace, separator, remainder = key.partition(".")
        if not separator:
            return None
        return namespace, remainder.split(".")

    def _apply_params(self, translation, key: str, locale: str, params: dict):
        """
        Substitute ``params`` into ``translation`` using ``str.format``.

        Formatting problems are logged (application log and translation
        logger) and never raised: on failure the unformatted template is
        returned so the caller still gets a usable message.
        """
        try:
            return translation.format(**params)
        except KeyError as e:
            # The template references a placeholder the caller did not supply.
            error_msg = f"Missing parameter in translation '{key}': {e}"
            logger.error(error_msg)
            self.translation_logger.log_translation_error(
                error_type="parameter_missing",
                message=error_msg,
                key=key,
                locale=locale,
                context={"params": list(params.keys())}
            )
        except Exception as e:
            # Best-effort boundary: any other formatting failure is logged only.
            error_msg = f"Error formatting translation '{key}': {e}"
            logger.error(error_msg)
            self.translation_logger.log_translation_error(
                error_type="format_error",
                message=error_msg,
                key=key,
                locale=locale
            )
        return translation

    @monitor_performance("translate")
    def translate(
        self,
        key: str,
        locale: Optional[str] = None,
        **params
    ) -> str:
        """
        Translate a key to the target locale.

        Supports:
        - Dot-separated keys (e.g., "common.success.created")
        - Parameterized messages (e.g., "Hello {name}")
        - Fallback to the configured fallback locale

        Args:
            key: Translation key (format: "namespace.key.subkey")
            locale: Target locale (defaults to the default locale)
            **params: Parameters for parameterized messages

        Returns:
            Translated string. The key itself is returned when the key has
            no namespace separator or no translation is found. If parameter
            substitution fails, the unformatted template is returned.

        Examples:
            translate("common.success.created", "zh")
            # => "创建成功"
            translate("common.validation.required", "zh", field="名称")
            # => "名称不能为空"
        """
        if locale is None:
            locale = self.default_locale

        parsed = self._split_key(key)
        if parsed is None:
            if self.log_missing:
                logger.warning("Invalid translation key format: %s", key)
            return key
        namespace, key_path = parsed

        track_translation_request(locale, namespace)

        translation = self.cache.get_translation(locale, namespace, key_path)

        # Retry in the fallback locale unless that is what was just tried.
        if translation is None and locale != self.fallback_locale:
            translation = self.cache.get_translation(
                self.fallback_locale, namespace, key_path
            )

        if translation is None:
            # Missing-translation reporting (log, metric, translation logger)
            # is entirely gated on the log_missing setting.
            if self.log_missing:
                logger.warning(
                    "Missing translation: %s (locale: %s)", key, locale
                )
                track_missing_translation(key, locale)
                self.translation_logger.log_missing_translation(
                    key=key,
                    locale=locale,
                    context={"namespace": namespace}
                )
            return key

        if params:
            translation = self._apply_params(translation, key, locale, params)
        return translation

    def _get_translation(
        self,
        locale: str,
        namespace: str,
        key_path: list
    ) -> Optional[str]:
        """
        Get translation from cache (deprecated, use cache.get_translation).

        Args:
            locale: Locale code
            namespace: Translation namespace
            key_path: List of nested keys

        Returns:
            Whatever the cache returns for the lookup; None if not found
        """
        return self.cache.get_translation(locale, namespace, key_path)

    @monitor_performance("translate_enum")
    def translate_enum(
        self,
        enum_type: str,
        value: str,
        locale: Optional[str] = None
    ) -> str:
        """
        Translate an enum value via the key "enums.<enum_type>.<value>".

        Args:
            enum_type: Enum type name (e.g., "workspace_role")
            value: Enum value (e.g., "manager"). An ``enum.Enum`` member is
                also accepted and is replaced by its ``.value``.
            locale: Target locale

        Returns:
            Translated enum display name (or the generated key when missing)

        Examples:
            translate_enum("workspace_role", "manager", "zh")
            # => "管理员"
            translate_enum("invite_status", "pending", "en")
            # => "Pending"
        """
        from enum import Enum

        # Formatting an Enum member in an f-string can yield "Class.MEMBER"
        # (Python >= 3.11), which would never match a translation key.
        if isinstance(value, Enum):
            value = value.value
        key = f"enums.{enum_type}.{value}"
        return self.translate(key, locale)

    def has_translation(self, key: str, locale: str) -> bool:
        """
        Check if a translation exists for the given key and locale.

        Only the given locale is consulted; the fallback locale is not tried.

        Args:
            key: Translation key
            locale: Locale code

        Returns:
            True if translation exists, False otherwise (including keys
            without a namespace separator)
        """
        parsed = self._split_key(key)
        if parsed is None:
            return False
        namespace, key_path = parsed
        return self.cache.get_translation(locale, namespace, key_path) is not None

    def reload(self, locale: Optional[str] = None):
        """
        Reload translation files.

        Args:
            locale: Specific locale to reload. If None (or empty), reloads
                all locales.
        """
        logger.info("Reloading translations for locale: %s", locale or 'all')

        if locale:
            locale_data = self.loader.load_locale(locale)
            self.cache.set_locale_data(locale, locale_data)
            # Clear LRU cache for this locale
            self.cache.clear_locale(locale)
        else:
            self._load_all_locales()
            # Clear all LRU cache
            self.cache.clear_lru()

        logger.info("Translation reload completed")

    def get_available_locales(self) -> list:
        """
        Get list of all locales currently loaded in the cache.

        Returns:
            List of locale codes
        """
        return self.cache.get_loaded_locales()

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache statistics, as reported by the cache
        """
        return self.cache.get_stats()

    def get_metrics_summary(self) -> Dict[str, Any]:
        """
        Get metrics summary.

        Returns:
            Dictionary with metrics summary, as reported by the metrics object
        """
        return self.metrics.get_summary()

    def get_memory_usage(self) -> Dict[str, Any]:
        """
        Get memory usage information.

        Returns:
            Dictionary with memory usage information, as reported by the cache
        """
        return self.cache.get_memory_usage()

    def get_loaded_dirs(self) -> list:
        """
        Get list of translation directories used by the loader.

        Returns:
            List of directory paths
        """
        return self.loader.locales_dirs
# Module-level TranslationService instance, created lazily on first access.
_translation_service: Optional[TranslationService] = None


def get_translation_service() -> TranslationService:
    """
    Return the shared translation service, creating it on first use.

    Returns:
        The module-wide TranslationService instance
    """
    global _translation_service
    if _translation_service is not None:
        return _translation_service
    _translation_service = TranslationService()
    return _translation_service
# Module-level shortcuts that delegate to the shared TranslationService
def t(key: str, locale: Optional[str] = None, **params) -> str:
    """
    Shortcut for ``get_translation_service().translate(...)``.

    Args:
        key: Translation key
        locale: Target locale (optional; the service default is used when omitted)
        **params: Values substituted into parameterized messages

    Returns:
        Translated string

    Examples:
        t("common.success.created")
        t("common.validation.required", field="名称")
        t("workspace.member_count", count=5)
    """
    return get_translation_service().translate(key, locale, **params)
def t_enum(enum_type: str, value: str, locale: Optional[str] = None) -> str:
    """
    Shortcut for ``get_translation_service().translate_enum(...)``.

    Args:
        enum_type: Enum type name
        value: Enum value
        locale: Target locale

    Returns:
        Translated enum display name

    Examples:
        t_enum("workspace_role", "manager")
        t_enum("invite_status", "pending", "en")
    """
    return get_translation_service().translate_enum(enum_type, value, locale)

View File

@@ -0,0 +1,26 @@
# English Translation Files
This directory contains English translation files.
## File Structure
- `common.json` - Common translations (success messages, actions, validation)
- `auth.json` - Authentication module translations
- `workspace.json` - Workspace module translations
- `tenant.json` - Tenant module translations
- `errors.json` - Error message translations
- `enums.json` - Enum value translations
## Translation File Format
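All translation files use JSON format and support nested structures. The file name (without `.json`) is the key namespace and nested keys are joined with dots; for example, `success.created` in `common.json` is looked up as `common.success.created`.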
Example:
```json
{
"success": {
"created": "Created successfully",
"updated": "Updated successfully"
}
}
```

View File

@@ -0,0 +1,55 @@
{
"login": {
"success": "Login successful",
"failed": "Login failed",
"invalid_credentials": "Invalid username or password",
"account_locked": "Account has been locked",
"account_disabled": "Account has been disabled"
},
"logout": {
"success": "Logout successful",
"failed": "Logout failed"
},
"token": {
"refresh_success": "Token refreshed successfully",
"invalid": "Invalid token",
"expired": "Token has expired",
"blacklisted": "Token has been invalidated",
"invalid_refresh_token": "Invalid refresh token",
"refresh_token_blacklisted": "Refresh token has been invalidated"
},
"registration": {
"success": "Registration successful",
"failed": "Registration failed",
"email_exists": "Email already in use",
"username_exists": "Username already taken"
},
"password": {
"reset_success": "Password reset successful",
"reset_failed": "Password reset failed",
"change_success": "Password changed successfully",
"change_failed": "Password change failed",
"incorrect": "Incorrect password",
"too_weak": "Password is too weak",
"mismatch": "Passwords do not match"
},
"invite": {
"invalid": "Invalid or expired invite code",
"email_mismatch": "Invite email does not match login email",
"accept_success": "Invite accepted successfully",
"accept_failed": "Failed to accept invite",
"password_verification_failed": "Failed to accept invite: password verification failed",
"bind_workspace_success": "Workspace bound successfully",
"bind_workspace_failed": "Failed to bind workspace"
},
"user": {
"not_found": "User not found",
"already_exists": "User already exists",
"created_with_invite": "User created successfully and joined workspace"
},
"session": {
"expired": "Session expired, please log in again",
"invalid": "Invalid session",
"single_session_enabled": "Single-session login enabled; sessions on other devices will be logged out"
}
}

View File

@@ -0,0 +1,132 @@
{
"success": {
"created": "Created successfully",
"updated": "Updated successfully",
"deleted": "Deleted successfully",
"retrieved": "Retrieved successfully",
"saved": "Saved successfully",
"uploaded": "Uploaded successfully",
"downloaded": "Downloaded successfully",
"sent": "Sent successfully",
"completed": "Completed",
"confirmed": "Confirmed",
"cancelled": "Cancelled",
"archived": "Archived",
"restored": "Restored"
},
"actions": {
"create": "Create",
"update": "Update",
"delete": "Delete",
"view": "View",
"edit": "Edit",
"save": "Save",
"cancel": "Cancel",
"confirm": "Confirm",
"submit": "Submit",
"upload": "Upload",
"download": "Download",
"send": "Send",
"search": "Search",
"filter": "Filter",
"sort": "Sort",
"export": "Export",
"import": "Import",
"refresh": "Refresh",
"reset": "Reset",
"back": "Back",
"next": "Next",
"previous": "Previous",
"finish": "Finish",
"close": "Close",
"open": "Open",
"archive": "Archive",
"restore": "Restore",
"duplicate": "Duplicate",
"share": "Share",
"invite": "Invite",
"remove": "Remove",
"add": "Add",
"select": "Select",
"clear": "Clear"
},
"validation": {
"required": "{field} is required",
"invalid_format": "{field} format is invalid",
"too_long": "{field} cannot exceed {max} characters",
"too_short": "{field} must be at least {min} characters",
"invalid_email": "Invalid email format",
"invalid_url": "Invalid URL format",
"invalid_phone": "Invalid phone number format",
"invalid_date": "Invalid date format",
"invalid_number": "Must be a valid number",
"out_of_range": "{field} must be between {min} and {max}",
"already_exists": "{field} already exists",
"not_found": "{field} not found",
"invalid_value": "Invalid value for {field}",
"password_mismatch": "Passwords do not match",
"weak_password": "Password is too weak, please use a stronger password",
"invalid_credentials": "Invalid username or password",
"unauthorized": "Unauthorized access",
"forbidden": "Permission denied",
"expired": "{field} has expired",
"invalid_token": "Invalid token",
"file_too_large": "File size cannot exceed {max}",
"invalid_file_type": "Unsupported file type",
"duplicate": "Duplicate {field}"
},
"status": {
"active": "Active",
"inactive": "Inactive",
"pending": "Pending",
"processing": "Processing",
"completed": "Completed",
"failed": "Failed",
"cancelled": "Cancelled",
"archived": "Archived",
"deleted": "Deleted",
"draft": "Draft",
"published": "Published",
"suspended": "Suspended",
"expired": "Expired"
},
"messages": {
"loading": "Loading...",
"saving": "Saving...",
"processing": "Processing...",
"uploading": "Uploading...",
"downloading": "Downloading...",
"no_data": "No data available",
"no_results": "No results found",
"confirm_delete": "Are you sure you want to delete? This action cannot be undone.",
"confirm_action": "Are you sure you want to perform this action?",
"operation_success": "Operation successful",
"operation_failed": "Operation failed",
"please_wait": "Please wait...",
"try_again": "Please try again",
"contact_support": "If the problem persists, please contact support"
},
"pagination": {
"page": "Page {page}",
"of": "of {total}",
"items": "{total} items",
"per_page": "{count} per page",
"showing": "Showing {from} to {to} of {total}",
"first": "First",
"last": "Last",
"next": "Next",
"previous": "Previous"
},
"time": {
"just_now": "Just now",
"minutes_ago": "{count} minutes ago",
"hours_ago": "{count} hours ago",
"days_ago": "{count} days ago",
"weeks_ago": "{count} weeks ago",
"months_ago": "{count} months ago",
"years_ago": "{count} years ago",
"today": "Today",
"yesterday": "Yesterday",
"tomorrow": "Tomorrow"
}
}

View File

@@ -0,0 +1,132 @@
{
"workspace_role": {
"owner": "Owner",
"manager": "Manager",
"member": "Member",
"guest": "Guest"
},
"workspace_status": {
"active": "Active",
"inactive": "Inactive",
"archived": "Archived",
"suspended": "Suspended",
"deleted": "Deleted"
},
"invite_status": {
"pending": "Pending",
"accepted": "Accepted",
"rejected": "Rejected",
"revoked": "Revoked",
"expired": "Expired"
},
"user_status": {
"active": "Active",
"inactive": "Inactive",
"suspended": "Suspended",
"deleted": "Deleted",
"pending": "Pending"
},
"tenant_status": {
"active": "Active",
"inactive": "Inactive",
"suspended": "Suspended",
"expired": "Expired",
"trial": "Trial"
},
"file_status": {
"uploading": "Uploading",
"processing": "Processing",
"completed": "Completed",
"failed": "Failed",
"deleted": "Deleted"
},
"task_status": {
"pending": "Pending",
"running": "Running",
"completed": "Completed",
"failed": "Failed",
"cancelled": "Cancelled",
"paused": "Paused"
},
"priority": {
"low": "Low",
"medium": "Medium",
"high": "High",
"urgent": "Urgent"
},
"visibility": {
"public": "Public",
"private": "Private",
"internal": "Internal",
"shared": "Shared"
},
"permission": {
"read": "Read",
"write": "Write",
"delete": "Delete",
"admin": "Admin",
"owner": "Owner"
},
"notification_type": {
"info": "Info",
"warning": "Warning",
"error": "Error",
"success": "Success"
},
"language": {
"zh": "Chinese (Simplified)",
"en": "English",
"ja": "Japanese",
"ko": "Korean",
"fr": "French",
"de": "German",
"es": "Spanish"
},
"timezone": {
"utc": "UTC",
"asia_shanghai": "Asia/Shanghai",
"asia_tokyo": "Asia/Tokyo",
"america_new_york": "America/New_York",
"europe_london": "Europe/London"
},
"date_format": {
"short": "Short",
"medium": "Medium",
"long": "Long",
"full": "Full"
},
"sort_order": {
"asc": "Ascending",
"desc": "Descending"
},
"filter_operator": {
"equals": "Equals",
"not_equals": "Not Equals",
"contains": "Contains",
"not_contains": "Not Contains",
"starts_with": "Starts With",
"ends_with": "Ends With",
"greater_than": "Greater Than",
"less_than": "Less Than",
"greater_or_equal": "Greater or Equal",
"less_or_equal": "Less or Equal",
"in": "In",
"not_in": "Not In",
"is_null": "Is Null",
"is_not_null": "Is Not Null"
},
"log_level": {
"debug": "Debug",
"info": "Info",
"warning": "Warning",
"error": "Error",
"critical": "Critical"
},
"api_method": {
"get": "GET",
"post": "POST",
"put": "PUT",
"patch": "PATCH",
"delete": "DELETE"
}
}

View File

@@ -0,0 +1,138 @@
{
"common": {
"internal_error": "Internal server error",
"network_error": "Network connection error",
"timeout": "Request timeout",
"service_unavailable": "Service temporarily unavailable",
"bad_request": "Bad request parameters",
"unauthorized": "Unauthorized access",
"forbidden": "Access forbidden",
"not_found": "Resource not found",
"method_not_allowed": "Method not allowed",
"conflict": "Resource conflict",
"too_many_requests": "Too many requests, please try again later",
"validation_failed": "Validation failed",
"database_error": "Database operation failed",
"file_operation_error": "File operation failed"
},
"auth": {
"invalid_credentials": "Invalid username or password",
"token_expired": "Session expired, please log in again",
"token_invalid": "Invalid authentication token",
"token_missing": "Authentication token missing",
"unauthorized": "Unauthorized access",
"forbidden": "Permission denied",
"account_locked": "Account has been locked",
"account_disabled": "Account has been disabled",
"account_not_verified": "Account not verified",
"password_incorrect": "Incorrect password",
"password_too_weak": "Password is too weak",
"password_expired": "Password expired, please change it",
"email_not_verified": "Email not verified",
"phone_not_verified": "Phone number not verified",
"verification_code_invalid": "Invalid verification code",
"verification_code_expired": "Verification code expired",
"login_failed": "Login failed",
"logout_failed": "Logout failed",
"session_expired": "Session expired",
"already_logged_in": "Already logged in",
"not_logged_in": "Not logged in"
},
"user": {
"not_found": "User not found",
"already_exists": "User already exists",
"email_already_exists": "Email already in use",
"phone_already_exists": "Phone number already in use",
"username_already_exists": "Username already taken",
"invalid_email": "Invalid email format",
"invalid_phone": "Invalid phone number format",
"invalid_username": "Invalid username format",
"create_failed": "Failed to create user",
"update_failed": "Failed to update user",
"delete_failed": "Failed to delete user",
"cannot_delete_self": "Cannot delete yourself",
"cannot_update_self_role": "Cannot update your own role",
"profile_update_failed": "Failed to update profile",
"avatar_upload_failed": "Failed to upload avatar",
"password_change_failed": "Failed to change password",
"old_password_incorrect": "Old password is incorrect"
},
"workspace": {
"not_found": "Workspace not found",
"already_exists": "Workspace already exists",
"name_required": "Workspace name is required",
"name_too_long": "Workspace name is too long",
"create_failed": "Failed to create workspace",
"update_failed": "Failed to update workspace",
"delete_failed": "Failed to delete workspace",
"permission_denied": "Permission denied to access this workspace",
"not_member": "Not a workspace member",
"already_member": "Already a workspace member",
"member_limit_reached": "Member limit reached",
"cannot_leave_last_manager": "Cannot leave, you are the last manager",
"cannot_remove_last_manager": "Cannot remove the last manager",
"cannot_remove_self": "Cannot remove yourself",
"invite_not_found": "Invite not found",
"invite_expired": "Invite has expired",
"invite_already_accepted": "Invite already accepted",
"invite_already_revoked": "Invite already revoked",
"invite_send_failed": "Failed to send invite",
"archived": "Workspace is archived",
"suspended": "Workspace is suspended"
},
"tenant": {
"not_found": "Tenant not found",
"already_exists": "Tenant already exists",
"create_failed": "Failed to create tenant",
"update_failed": "Failed to update tenant",
"delete_failed": "Failed to delete tenant",
"suspended": "Tenant is suspended",
"expired": "Tenant has expired",
"license_invalid": "Invalid license",
"license_expired": "License has expired",
"quota_exceeded": "Quota exceeded"
},
"file": {
"not_found": "File not found",
"upload_failed": "File upload failed",
"download_failed": "File download failed",
"delete_failed": "File deletion failed",
"too_large": "File size exceeds limit",
"invalid_type": "Unsupported file type",
"invalid_format": "Invalid file format",
"corrupted": "File is corrupted",
"storage_full": "Storage is full",
"access_denied": "Access denied to this file"
},
"api": {
"rate_limit_exceeded": "API rate limit exceeded",
"quota_exceeded": "API quota exceeded",
"invalid_api_key": "Invalid API key",
"api_key_expired": "API key has expired",
"api_key_revoked": "API key has been revoked",
"endpoint_not_found": "API endpoint not found",
"method_not_allowed": "Method not allowed",
"invalid_request": "Invalid request",
"missing_parameter": "Missing required parameter: {param}",
"invalid_parameter": "Invalid parameter: {param}"
},
"database": {
"connection_failed": "Database connection failed",
"query_failed": "Database query failed",
"transaction_failed": "Database transaction failed",
"constraint_violation": "Data constraint violation",
"duplicate_key": "Duplicate data",
"foreign_key_violation": "Foreign key constraint violation",
"deadlock": "Database deadlock"
},
"validation": {
"invalid_input": "Invalid input data",
"missing_field": "Missing required field: {field}",
"invalid_field": "Invalid field: {field}",
"field_too_long": "Field too long: {field}",
"field_too_short": "Field too short: {field}",
"invalid_format": "Invalid format: {field}",
"invalid_value": "Invalid value: {field}",
"out_of_range": "Value out of range: {field}"
}
}

Some files were not shown because too many files have changed in this diff Show More