Merge branch 'release/v0.2.10'
This commit is contained in:
@@ -14,7 +14,6 @@ from . import (
|
||||
document_controller,
|
||||
emotion_config_controller,
|
||||
emotion_controller,
|
||||
end_user_controller,
|
||||
file_controller,
|
||||
file_storage_controller,
|
||||
home_page_controller,
|
||||
@@ -99,6 +98,5 @@ manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
manager_router.include_router(i18n_controller.router)
|
||||
manager_router.include_router(end_user_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -292,10 +292,19 @@ def get_opening(
|
||||
):
|
||||
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
|
||||
# 根据应用类型获取 features
|
||||
from app.models.app_model import App as AppModel
|
||||
app = db.get(AppModel, app_id)
|
||||
if app and app.type == "workflow":
|
||||
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
else:
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
|
||||
opening = features.get("opening_statement", {})
|
||||
return success(data=app_schema.OpeningResponse(
|
||||
enabled=opening.get("enabled", False),
|
||||
@@ -1070,6 +1079,14 @@ async def update_workflow_config(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if payload.variables:
|
||||
from app.services.workflow_service import WorkflowService
|
||||
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
|
||||
[v.model_dump() for v in payload.variables]
|
||||
)
|
||||
# Patch default values back into VariableDefinition objects
|
||||
for var_def, resolved_def in zip(payload.variables, resolved):
|
||||
var_def.default = resolved_def.get("default", var_def.default)
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@@ -53,22 +53,24 @@ async def login_for_access_token(
|
||||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||||
if form_data.invite:
|
||||
auth_service.bind_workspace_with_invite(db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id)
|
||||
auth_service.bind_workspace_with_invite(
|
||||
db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
except BusinessException as e:
|
||||
# 用户不存在且有邀请码,尝试注册
|
||||
if e.code == BizCode.USER_NOT_FOUND:
|
||||
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
|
||||
user = auth_service.register_user_with_invite(
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
elif e.code == BizCode.PASSWORD_ERROR:
|
||||
# 用户存在但密码错误
|
||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||
|
||||
@@ -314,8 +314,10 @@ async def parse_documents(
|
||||
)
|
||||
|
||||
# 4. Check if the file exists
|
||||
api_logger.debug(f"Constructed file path: {file_path}")
|
||||
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
||||
if not os.path.exists(file_path):
|
||||
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
|
||||
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found (possibly deleted)"
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
"""End User 管理接口 - 无需认证"""
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.memory_api_schema import (
|
||||
CreateEndUserRequest,
|
||||
CreateEndUserResponse,
|
||||
)
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter(prefix="/end_users", tags=["End Users"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_end_user(
|
||||
data: CreateEndUserRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create an end user.
|
||||
|
||||
Creates a new end user for the given workspace.
|
||||
If an end user with the same other_id already exists in the workspace,
|
||||
returns the existing one.
|
||||
"""
|
||||
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
|
||||
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=None,
|
||||
workspace_id=data.workspace_id,
|
||||
other_id=data.other_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
|
||||
result = {
|
||||
"id": str(end_user.id),
|
||||
"other_id": end_user.other_id or "",
|
||||
"other_name": end_user.other_name or "",
|
||||
"workspace_id": str(end_user.workspace_id),
|
||||
}
|
||||
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
@@ -3,9 +3,10 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.db import get_db, SessionLocal
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.home_page_repository import HomePageRepository
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.home_page_service import HomePageService
|
||||
|
||||
@@ -31,9 +32,32 @@ def get_workspace_list(
|
||||
|
||||
@router.get("/version", response_model=ApiResponse)
|
||||
def get_system_version():
|
||||
"""获取系统版本号+说明"""
|
||||
current_version = settings.SYSTEM_VERSION
|
||||
version_info = HomePageService.load_version_introduction(current_version)
|
||||
"""获取系统版本号 + 说明"""
|
||||
current_version = None
|
||||
version_info = None
|
||||
|
||||
# 1️⃣ 优先从数据库获取最新已发布的版本
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
current_version, version_info = HomePageRepository.get_latest_version_introduction(db)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# 2️⃣ 降级:使用环境变量中的版本号
|
||||
if not current_version:
|
||||
current_version = settings.SYSTEM_VERSION
|
||||
version_info = HomePageService.load_version_introduction(current_version)
|
||||
|
||||
# 3️⃣ 如果数据库和 JSON 都没有,返回基本信息
|
||||
if not version_info:
|
||||
version_info = {
|
||||
"introduction": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []},
|
||||
"introduction_en": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []}
|
||||
}
|
||||
|
||||
return success(
|
||||
data={
|
||||
"version": current_version,
|
||||
|
||||
@@ -352,6 +352,7 @@ async def delete_knowledge(
|
||||
# 2. Soft-delete knowledge base
|
||||
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
|
||||
db_knowledge.status = 2
|
||||
db_knowledge.updated_at = datetime.datetime.now()
|
||||
db.commit()
|
||||
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
|
||||
return success(msg="The knowledge base has been successfully deleted")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -47,64 +49,64 @@ def get_workspace_total_end_users(
|
||||
|
||||
@router.get("/end_users", response_model=ApiResponse)
|
||||
async def get_workspace_end_users(
|
||||
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
||||
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
pagesize: int = Query(10, ge=1, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||
|
||||
优化策略:
|
||||
1. 批量查询 end_users(一次查询而非循环)
|
||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||
3. RAG 模式使用批量查询(一次 SQL)
|
||||
4. 只返回必要字段减少数据传输
|
||||
5. 添加短期缓存减少重复查询
|
||||
6. 并发执行配置查询和记忆数量查询
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||
"memory_num": {"total": 数量},
|
||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||
}
|
||||
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
||||
|
||||
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
||||
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含宿主列表和分页信息
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 尝试从缓存获取(30秒缓存)
|
||||
cache_key = f"end_users:workspace:{workspace_id}"
|
||||
try:
|
||||
cached_data = await aio_redis_get(cache_key)
|
||||
if cached_data:
|
||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||
|
||||
# 如果未提供 workspace_id,使用当前用户的工作空间
|
||||
if workspace_id is None:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
|
||||
# 获取 end_users(已优化为批量查询)
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
||||
|
||||
# 获取分页的 end_users
|
||||
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
current_user=current_user,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keyword=keyword
|
||||
)
|
||||
|
||||
end_users = end_users_result.get("items", [])
|
||||
total = end_users_result.get("total", 0)
|
||||
|
||||
if not end_users:
|
||||
api_logger.info("工作空间下没有宿主")
|
||||
# 缓存空结果,避免重复查询
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
return success(data=[], msg="宿主列表获取成功")
|
||||
|
||||
api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}")
|
||||
return success(data={
|
||||
"items": [],
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
}, msg="宿主列表获取成功")
|
||||
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
|
||||
|
||||
# 并发执行两个独立的查询任务
|
||||
async def get_memory_configs():
|
||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||
@@ -116,7 +118,7 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
return {}
|
||||
|
||||
|
||||
async def get_memory_nums():
|
||||
"""获取记忆数量"""
|
||||
if current_workspace_type == "rag":
|
||||
@@ -130,26 +132,18 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
|
||||
elif current_workspace_type == "neo4j":
|
||||
# Neo4j 模式:并发查询(带并发限制)
|
||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||
MAX_CONCURRENT_QUERIES = 10
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||
|
||||
async def get_neo4j_memory_num(end_user_id: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await memory_storage_service.search_all(end_user_id)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||
return {"total": 0}
|
||||
|
||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||
|
||||
# Neo4j 模式:批量查询(简化版本,只返回total)
|
||||
try:
|
||||
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
||||
return {uid: {"total": count} for uid, count in batch_result.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
|
||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||
try:
|
||||
from app.celery_app import celery_app as _celery_app
|
||||
@@ -170,13 +164,13 @@ async def get_workspace_end_users(
|
||||
get_memory_configs(),
|
||||
get_memory_nums()
|
||||
)
|
||||
|
||||
# 构建结果(优化:使用列表推导式)
|
||||
result = []
|
||||
|
||||
# 构建结果列表
|
||||
items = []
|
||||
for end_user in end_users:
|
||||
user_id = str(end_user.id)
|
||||
config_info = memory_configs_map.get(user_id, {})
|
||||
result.append({
|
||||
items.append({
|
||||
'end_user': {
|
||||
'id': user_id,
|
||||
'other_name': end_user.other_name
|
||||
@@ -187,12 +181,6 @@ async def get_workspace_end_users(
|
||||
"memory_config_name": config_info.get("memory_config_name")
|
||||
}
|
||||
})
|
||||
|
||||
# 写入缓存(30秒过期)
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||
try:
|
||||
@@ -202,7 +190,18 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
# 构建分页响应
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
}
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
|
||||
@@ -592,7 +591,7 @@ async def dashboard_data(
|
||||
"total_api_call": None
|
||||
}
|
||||
|
||||
# 1. 获取记忆总量(total_memory)
|
||||
# 1. 获取记忆总量(total_memory)—— neo4j 独有逻辑:查询 neo4j 存储节点
|
||||
try:
|
||||
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
|
||||
db=db,
|
||||
@@ -601,49 +600,33 @@ async def dashboard_data(
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
# 包含自有app + 被分享给本工作空间的app
|
||||
from app.services import app_service as _app_svc
|
||||
_, total_app = _app_svc.AppService(db).list_apps(
|
||||
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||
)
|
||||
neo4j_data["total_app"] = total_app
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||
|
||||
# 2. 获取知识库类型统计(total_knowledge)
|
||||
try:
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
memory_agent_service = MemoryAgentService()
|
||||
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
only_active=True,
|
||||
current_workspace_id=workspace_id,
|
||||
db=db
|
||||
)
|
||||
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
|
||||
api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
||||
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
||||
neo4j_data.update(common_stats)
|
||||
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
||||
|
||||
# 3. 获取API调用统计(total_api_call)
|
||||
# 计算昨日对比
|
||||
try:
|
||||
# 使用 AppStatisticsService 获取真实的API调用统计
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
storage_type=storage_type,
|
||||
today_data=neo4j_data
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
neo4j_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
|
||||
neo4j_data.update(changes)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取API调用统计失败: {str(e)}")
|
||||
neo4j_data["total_api_call"] = 0
|
||||
|
||||
api_logger.warning(f"计算neo4j昨日对比失败: {str(e)}")
|
||||
neo4j_data.update({
|
||||
"total_memory_change": None,
|
||||
"total_app_change": None,
|
||||
"total_knowledge_change": None,
|
||||
"total_api_call_change": None,
|
||||
})
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
|
||||
@@ -656,44 +639,37 @@ async def dashboard_data(
|
||||
"total_api_call": None
|
||||
}
|
||||
|
||||
# 获取RAG相关数据
|
||||
# 1. 获取记忆总量(total_memory)—— rag 独有逻辑:查询 document 表的 chunk_num
|
||||
try:
|
||||
# total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
|
||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
# 包含自有app + 被分享给本工作空间的app
|
||||
from app.services import app_service as _app_svc
|
||||
_, total_app = _app_svc.AppService(db).list_apps(
|
||||
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||
)
|
||||
rag_data["total_app"] = total_app
|
||||
|
||||
# total_knowledge: 使用 total_kb(总知识库数)
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
|
||||
try:
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
rag_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||
rag_data["total_api_call"] = 0
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
api_logger.info(f"成功获取RAG记忆总量: {total_chunk}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
api_logger.warning(f"获取RAG记忆总量失败: {str(e)}")
|
||||
|
||||
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
||||
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
||||
rag_data.update(common_stats)
|
||||
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
||||
|
||||
# 计算昨日对比
|
||||
try:
|
||||
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
storage_type=storage_type,
|
||||
today_data=rag_data
|
||||
)
|
||||
rag_data.update(changes)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"计算RAG昨日对比失败: {str(e)}")
|
||||
rag_data.update({
|
||||
"total_memory_change": None,
|
||||
"total_app_change": None,
|
||||
"total_knowledge_change": None,
|
||||
"total_api_call_change": None,
|
||||
})
|
||||
|
||||
result["rag_data"] = rag_data
|
||||
api_logger.info("成功获取rag_data")
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from app.services.memory_storage_service import (
|
||||
analytics_hot_memory_tags,
|
||||
analytics_recent_activity_stats,
|
||||
kb_type_distribution,
|
||||
search_all,
|
||||
search_all_batch,
|
||||
search_chunk,
|
||||
search_detials,
|
||||
search_dialogue,
|
||||
@@ -409,7 +409,10 @@ async def search_all_num(
|
||||
) -> dict:
|
||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_all(end_user_id)
|
||||
if not end_user_id:
|
||||
return success(data={"total": 0}, msg="查询成功")
|
||||
batch_result = await search_all_batch([end_user_id])
|
||||
result = {"total": batch_result.get(end_user_id, 0)}
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search all failed: {str(e)}")
|
||||
|
||||
@@ -163,6 +163,7 @@ def _get_ontology_service(
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
support_thinking="thinking" in (api_key_config.capability or []),
|
||||
max_retries=3,
|
||||
timeout=60.0
|
||||
)
|
||||
|
||||
@@ -453,31 +453,10 @@ async def chat(
|
||||
# 流式返回
|
||||
agent_config = agent_config_4_app_release(release)
|
||||
|
||||
if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
||||
agent_config.model_parameters["deep_thinking"] = False
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
@@ -503,20 +482,6 @@ async def chat(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
# 非流式返回
|
||||
# result = await service.chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
result = await app_chat_service.agnet_chat(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
@@ -575,48 +540,6 @@ async def chat(
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
# 多 Agent 流式返回
|
||||
# if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.multi_agent_chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
|
||||
# # 多 Agent 非流式返回
|
||||
# result = await service.multi_agent_chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
config = workflow_config_4_app_release(release)
|
||||
if not config.id:
|
||||
@@ -714,7 +637,8 @@ async def config_query(
|
||||
"app_type": release.app.type,
|
||||
"variables": release.config.get("variables"),
|
||||
"memory": release.config.get("memory", {}).get("enabled"),
|
||||
"features": release.config.get("features")
|
||||
"features": release.config.get("features"),
|
||||
"model_parameters": release.config.get("model_parameters")
|
||||
}
|
||||
elif release.app.type == AppType.MULTI_AGENT:
|
||||
content = {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
认证方式: API Key
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller
|
||||
|
||||
# 创建 V1 API 路由器
|
||||
service_router = APIRouter()
|
||||
@@ -16,5 +16,6 @@ service_router.include_router(rag_api_document_controller.router)
|
||||
service_router.include_router(rag_api_file_controller.router)
|
||||
service_router.include_router(rag_api_chunk_controller.router)
|
||||
service_router.include_router(memory_api_controller.router)
|
||||
service_router.include_router(end_user_api_controller.router)
|
||||
|
||||
__all__ = ["service_router"]
|
||||
|
||||
@@ -144,6 +144,11 @@ async def chat(
|
||||
# print(app.current_release.default_model_config_id)
|
||||
agent_config = agent_config_4_app_release(app.current_release)
|
||||
# print(agent_config.default_model_config_id)
|
||||
|
||||
# thinking 开关:仅当 agent 配置了 deep_thinking 且请求 thinking=True 时才启用
|
||||
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
||||
agent_config.model_parameters["deep_thinking"] = False
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
|
||||
92
api/app/controllers/service/end_user_api_controller.py
Normal file
92
api/app/controllers/service/end_user_api_controller.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""End User 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def create_end_user(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Request body"),
|
||||
):
|
||||
"""
|
||||
Create or retrieve an end user for the workspace.
|
||||
|
||||
Creates a new end user and connects it to a memory configuration.
|
||||
If an end user with the same other_id already exists in the workspace,
|
||||
returns the existing one.
|
||||
|
||||
Optionally accepts a memory_config_id to connect the end user to a specific
|
||||
memory configuration. If not provided, falls back to the workspace default config.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = CreateEndUserRequest(**body)
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
|
||||
logger.info("Create end user request - other_id: %s, workspace_id: %s", payload.other_id, workspace_id)
|
||||
|
||||
# Resolve memory_config_id: explicit > workspace default
|
||||
memory_config_id = None
|
||||
config_service = MemoryConfigService(db)
|
||||
|
||||
if payload.memory_config_id:
|
||||
try:
|
||||
memory_config_id = uuid.UUID(payload.memory_config_id)
|
||||
except ValueError:
|
||||
raise BusinessException(
|
||||
f"Invalid memory_config_id format: {payload.memory_config_id}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
f"Memory config not found: {payload.memory_config_id}",
|
||||
BizCode.MEMORY_CONFIG_NOT_FOUND
|
||||
)
|
||||
memory_config_id = config.config_id
|
||||
else:
|
||||
default_config = config_service.get_workspace_default_config(workspace_id)
|
||||
if default_config:
|
||||
memory_config_id = default_config.config_id
|
||||
logger.info(f"Using workspace default memory config: {memory_config_id}")
|
||||
else:
|
||||
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
||||
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_or_create_end_user_with_config(
|
||||
app_id=api_key_auth.resource_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=payload.other_id,
|
||||
memory_config_id=memory_config_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
|
||||
result = {
|
||||
"id": str(end_user.id),
|
||||
"other_id": end_user.other_id or "",
|
||||
"other_name": end_user.other_name or "",
|
||||
"workspace_id": str(end_user.workspace_id),
|
||||
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
|
||||
}
|
||||
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
@@ -6,6 +6,8 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import (
|
||||
CreateEndUserRequest,
|
||||
CreateEndUserResponse,
|
||||
ListConfigsResponse,
|
||||
MemoryReadRequest,
|
||||
MemoryReadResponse,
|
||||
@@ -113,3 +115,31 @@ async def list_memory_configs(
|
||||
|
||||
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||
|
||||
|
||||
@router.post("/end_users")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def create_end_user(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create an end user.
|
||||
|
||||
Creates a new end user for the authorized workspace.
|
||||
If an end user with the same other_id already exists, returns the existing one.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = CreateEndUserRequest(**body)
|
||||
logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = memory_api_service.create_end_user(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
other_id=payload.other_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {result['id']}")
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
|
||||
@@ -11,17 +11,14 @@ LangChain Agent 封装
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType, ModelProvider
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.errors import GraphRecursionError
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -41,7 +38,10 @@ class LangChainAgent:
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
|
||||
deep_thinking: bool = False, # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
|
||||
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -64,6 +64,7 @@ class LangChainAgent:
|
||||
self.streaming = streaming
|
||||
self.is_omni = is_omni
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
self.deep_thinking = deep_thinking and ("thinking" in (capability or []))
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
@@ -86,6 +87,13 @@ class LangChainAgent:
|
||||
f"auto_calculated={max_iterations is None}"
|
||||
)
|
||||
|
||||
# 根据 capability 校验是否真正支持深度思考
|
||||
actual_deep_thinking = self.deep_thinking
|
||||
if deep_thinking and not actual_deep_thinking:
|
||||
logger.warning(
|
||||
f"模型 {model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking"
|
||||
)
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
@@ -93,10 +101,13 @@ class LangChainAgent:
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
deep_thinking=actual_deep_thinking,
|
||||
thinking_budget_tokens=thinking_budget_tokens if actual_deep_thinking else None,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"streaming": streaming # 使用参数控制流式
|
||||
"streaming": streaming
|
||||
}
|
||||
)
|
||||
|
||||
@@ -226,10 +237,9 @@ class LangChainAgent:
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
"""
|
||||
messages = []
|
||||
messages:list = [SystemMessage(content=self.system_prompt)]
|
||||
|
||||
# 添加系统提示词
|
||||
messages.append(SystemMessage(content=self.system_prompt))
|
||||
|
||||
# 添加历史消息
|
||||
if history:
|
||||
@@ -254,6 +264,33 @@ class LangChainAgent:
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _extract_tokens_from_message(msg) -> int:
|
||||
"""从 AIMessage 或类似对象中提取 total_tokens,兼容多种 provider 格式
|
||||
|
||||
支持的格式:
|
||||
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
|
||||
- response_metadata.usage.total_tokens (部分 provider)
|
||||
- usage_metadata.total_tokens (LangChain 新版)
|
||||
"""
|
||||
total = 0
|
||||
# 1. response_metadata
|
||||
response_meta = getattr(msg, "response_metadata", None)
|
||||
if response_meta and isinstance(response_meta, dict):
|
||||
# 尝试 token_usage 路径
|
||||
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
|
||||
if isinstance(token_usage, dict):
|
||||
total = token_usage.get("total_tokens", 0)
|
||||
# 2. usage_metadata(LangChain 新版 AIMessage 属性)
|
||||
if not total:
|
||||
usage_meta = getattr(msg, "usage_metadata", None)
|
||||
if usage_meta:
|
||||
if isinstance(usage_meta, dict):
|
||||
total = usage_meta.get("total_tokens", 0)
|
||||
else:
|
||||
total = getattr(usage_meta, "total_tokens", 0)
|
||||
return total or 0
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建多模态消息内容
|
||||
@@ -288,17 +325,23 @@ class LangChainAgent:
|
||||
|
||||
return content_parts
|
||||
|
||||
@staticmethod
|
||||
def _extract_reasoning_content(msg) -> str:
|
||||
"""从 AIMessage 中提取深度思考内容(reasoning_content)
|
||||
|
||||
所有 provider 统一通过 additional_kwargs.reasoning_content 传递:
|
||||
- DeepSeek-R1 / QwQ: 原生字段
|
||||
- Volcano (Doubao-thinking): 由 VolcanoChatOpenAI 从 delta.reasoning_content 注入
|
||||
"""
|
||||
additional = getattr(msg, "additional_kwargs", None) or {}
|
||||
return additional.get("reasoning_content") or additional.get("reasoning", "")
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -306,31 +349,12 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||
context: 上下文信息(如知识库检索结果)
|
||||
files: 多模态文件
|
||||
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat = message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -354,7 +378,7 @@ class LangChainAgent:
|
||||
{"messages": messages},
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
)
|
||||
except RecursionError as e:
|
||||
except (RecursionError, GraphRecursionError) as e:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||
extra={"error": str(e)}
|
||||
@@ -377,6 +401,7 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
reasoning_content = ""
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
@@ -411,16 +436,13 @@ class LangChainAgent:
|
||||
else:
|
||||
content = str(msg.content)
|
||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
total_tokens = self._extract_tokens_from_message(msg)
|
||||
reasoning_content = self._extract_reasoning_content(msg) if self.deep_thinking else ""
|
||||
break
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -431,6 +453,8 @@ class LangChainAgent:
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
if reasoning_content:
|
||||
response["reasoning_content"] = reasoning_content
|
||||
|
||||
logger.debug(
|
||||
"Agent 调用完成",
|
||||
@@ -451,22 +475,20 @@ class LangChainAgent:
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> AsyncGenerator[str | int | dict[str, str], None]:
|
||||
"""执行流式对话
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件
|
||||
|
||||
Yields:
|
||||
str: 消息内容块
|
||||
int: token 统计
|
||||
Dict: 深度思考内容 {"type": "reasoning", "content": "..."}
|
||||
"""
|
||||
logger.info("=" * 80)
|
||||
logger.info(" chat_stream 方法开始执行")
|
||||
@@ -474,23 +496,6 @@ class LangChainAgent:
|
||||
logger.info(f" Has tools: {bool(self.tools)}")
|
||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||
logger.info("=" * 80)
|
||||
message_chat = message
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -500,17 +505,19 @@ class LangChainAgent:
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
yielded_content = False
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content = ''
|
||||
full_reasoning = ''
|
||||
try:
|
||||
last_event = {}
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
last_event = event
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
@@ -519,12 +526,18 @@ class LangChainAgent:
|
||||
# LLM 流式输出
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk and hasattr(chunk, "content"):
|
||||
# 提取深度思考内容(仅在启用深度思考时)
|
||||
if self.deep_thinking:
|
||||
reasoning_chunk = self._extract_reasoning_content(chunk)
|
||||
if reasoning_chunk:
|
||||
full_reasoning += reasoning_chunk
|
||||
yield {"type": "reasoning", "content": reasoning_chunk}
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
@@ -535,29 +548,32 @@ class LangChainAgent:
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk:
|
||||
if hasattr(chunk, "content"):
|
||||
# 提取深度思考内容(仅在启用深度思考时)
|
||||
if self.deep_thinking:
|
||||
reasoning_chunk = self._extract_reasoning_content(chunk)
|
||||
if reasoning_chunk:
|
||||
full_reasoning += reasoning_chunk
|
||||
yield {"type": "reasoning", "content": reasoning_chunk}
|
||||
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
@@ -568,22 +584,18 @@ class LangChainAgent:
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
@@ -593,19 +605,20 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get(
|
||||
"total_tokens",
|
||||
0
|
||||
) if response_meta else 0
|
||||
yield total_tokens
|
||||
stream_total_tokens = self._extract_tokens_from_message(msg)
|
||||
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
||||
yield stream_total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
|
||||
except GraphRecursionError:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),模型可能不支持正确的工具调用停止判断"
|
||||
)
|
||||
if not full_content:
|
||||
yield "抱歉,我在处理您的请求时遇到了问题(已达最大处理步骤限制)。请尝试简化问题或更换模型后重试。"
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -19,6 +19,7 @@ class BizCode(IntEnum):
|
||||
TENANT_NOT_FOUND = 3002
|
||||
WORKSPACE_NO_ACCESS = 3003
|
||||
WORKSPACE_INVITE_NOT_FOUND = 3004
|
||||
WORKSPACE_ACCESS_DENIED = 3005
|
||||
# API Key 管理(3xxx)
|
||||
API_KEY_NOT_FOUND = 3007
|
||||
API_KEY_DUPLICATE_NAME = 3008
|
||||
@@ -113,6 +114,8 @@ HTTP_MAPPING = {
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.TENANT_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||
BizCode.WORKSPACE_INVITE_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_ACCESS_DENIED: 403,
|
||||
BizCode.NOT_FOUND: 400,
|
||||
BizCode.USER_NOT_FOUND: 200,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 400,
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Perceptual Memory Retrieval Node & Service
|
||||
|
||||
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
|
||||
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
|
||||
with BM25+embedding fusion reranking.
|
||||
|
||||
Also provides the perceptual_retrieve_node for use as a LangGraph node.
|
||||
"""
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_perceptual,
|
||||
search_perceptual_by_embedding,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class PerceptualSearchService:
|
||||
"""
|
||||
感知记忆检索服务。
|
||||
|
||||
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
|
||||
调用方只需提供 query / keywords、end_user_id、memory_config,即可获得
|
||||
格式化并排序后的感知记忆列表和拼接文本。
|
||||
|
||||
Usage:
|
||||
service = PerceptualSearchService(end_user_id=..., memory_config=...)
|
||||
results = await service.search(query="...", keywords=[...], limit=10)
|
||||
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
|
||||
"""
|
||||
|
||||
DEFAULT_ALPHA = 0.6
|
||||
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
end_user_id: str,
|
||||
memory_config: Any,
|
||||
alpha: float = DEFAULT_ALPHA,
|
||||
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
|
||||
):
|
||||
self.end_user_id = end_user_id
|
||||
self.memory_config = memory_config
|
||||
self.alpha = alpha
|
||||
self.content_score_threshold = content_score_threshold
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
keywords: Optional[List[str]] = None,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
|
||||
|
||||
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
|
||||
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
|
||||
|
||||
Args:
|
||||
query: 原始用户查询(用于向量检索和 BM25 补查)
|
||||
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
|
||||
limit: 最大返回数量
|
||||
|
||||
Returns:
|
||||
{
|
||||
"memories": [格式化后的记忆 dict, ...],
|
||||
"content": "拼接的纯文本摘要",
|
||||
"keyword_raw": int,
|
||||
"embedding_raw": int,
|
||||
}
|
||||
"""
|
||||
if keywords is None:
|
||||
keywords = [query] if query else []
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
kw_task = self._keyword_search(connector, keywords, limit)
|
||||
emb_task = self._embedding_search(connector, query, limit)
|
||||
|
||||
kw_results, emb_results = await asyncio.gather(
|
||||
kw_task, emb_task, return_exceptions=True
|
||||
)
|
||||
if isinstance(kw_results, Exception):
|
||||
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
|
||||
kw_results = []
|
||||
if isinstance(emb_results, Exception):
|
||||
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
|
||||
emb_results = []
|
||||
|
||||
# 补查 BM25:找出 embedding 命中但 keyword 未命中的 id,
|
||||
# 用原始 query 对这些节点补查全文索引拿 BM25 score
|
||||
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
|
||||
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
|
||||
|
||||
if emb_only_ids and query:
|
||||
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
|
||||
# 把补查到的 BM25 score 注入到 embedding 结果中
|
||||
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
|
||||
for r in emb_results:
|
||||
rid = r.get("id", "")
|
||||
if rid in backfill_map:
|
||||
r["bm25_backfill_score"] = backfill_map[rid]
|
||||
logger.info(
|
||||
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
|
||||
f"{len(backfill_map)} got BM25 scores"
|
||||
)
|
||||
|
||||
reranked = self._rerank(kw_results, emb_results, limit)
|
||||
|
||||
memories = []
|
||||
content_parts = []
|
||||
for record in reranked:
|
||||
fmt = self._format_result(record)
|
||||
fmt["score"] = round(record.get("content_score", 0), 4)
|
||||
memories.append(fmt)
|
||||
content_parts.append(self._build_content_text(fmt))
|
||||
|
||||
logger.info(
|
||||
f"[PerceptualSearch] {len(memories)} results after rerank "
|
||||
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
|
||||
)
|
||||
return {
|
||||
"memories": memories,
|
||||
"content": "\n\n".join(content_parts),
|
||||
"keyword_raw": len(kw_results),
|
||||
"embedding_raw": len(emb_results),
|
||||
}
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
async def _bm25_backfill(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
query: str,
|
||||
target_ids: set,
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
对指定 id 集合补查全文索引 BM25 score。
|
||||
|
||||
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
|
||||
"""
|
||||
escaped = escape_lucene_query(query)
|
||||
if not escaped.strip():
|
||||
return []
|
||||
try:
|
||||
r = await search_perceptual(
|
||||
connector=connector, q=escaped,
|
||||
end_user_id=self.end_user_id,
|
||||
limit=limit * 5, # 多查一些以提高命中率
|
||||
)
|
||||
all_hits = r.get("perceptuals", [])
|
||||
return [h for h in all_hits if h.get("id") in target_ids]
|
||||
except Exception as e:
|
||||
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
|
||||
return []
|
||||
|
||||
async def _keyword_search(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
keywords: List[str],
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
|
||||
seen_ids: set = set()
|
||||
all_results: List[dict] = []
|
||||
|
||||
async def _one(kw: str):
|
||||
escaped = escape_lucene_query(kw)
|
||||
if not escaped.strip():
|
||||
return []
|
||||
r = await search_perceptual(
|
||||
connector=connector, q=escaped,
|
||||
end_user_id=self.end_user_id, limit=limit,
|
||||
)
|
||||
return r.get("perceptuals", [])
|
||||
|
||||
tasks = [_one(kw) for kw in keywords[:10]]
|
||||
batch = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for result in batch:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
|
||||
continue
|
||||
for rec in result:
|
||||
rid = rec.get("id", "")
|
||||
if rid and rid not in seen_ids:
|
||||
seen_ids.add(rid)
|
||||
all_results.append(rec)
|
||||
|
||||
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
|
||||
return all_results[:limit]
|
||||
|
||||
async def _embedding_search(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
query_text: str,
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
|
||||
try:
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
with get_db_context() as db:
|
||||
cfg = MemoryConfigService(db).get_embedder_config(
|
||||
str(self.memory_config.embedding_model_id)
|
||||
)
|
||||
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
|
||||
|
||||
r = await search_perceptual_by_embedding(
|
||||
connector=connector, embedder_client=client,
|
||||
query_text=query_text, end_user_id=self.end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
return r.get("perceptuals", [])
|
||||
except Exception as e:
|
||||
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
|
||||
return []
|
||||
|
||||
def _rerank(
|
||||
self,
|
||||
keyword_results: List[dict],
|
||||
embedding_results: List[dict],
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""BM25 + embedding 融合排序。
|
||||
|
||||
对 embedding 结果中带有 bm25_backfill_score 的条目,
|
||||
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
|
||||
"""
|
||||
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
|
||||
emb_backfill_items = []
|
||||
for item in embedding_results:
|
||||
backfill_score = item.get("bm25_backfill_score")
|
||||
if backfill_score is not None and item.get("id"):
|
||||
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
|
||||
|
||||
# 合并后统一归一化 BM25 scores
|
||||
all_bm25_items = keyword_results + emb_backfill_items
|
||||
all_bm25_items = self._normalize_scores(all_bm25_items)
|
||||
|
||||
# 建立 id -> normalized BM25 score 的映射
|
||||
bm25_norm_map: Dict[str, float] = {}
|
||||
for item in all_bm25_items:
|
||||
item_id = item.get("id", "")
|
||||
if item_id:
|
||||
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
|
||||
|
||||
# 归一化 embedding scores
|
||||
embedding_results = self._normalize_scores(embedding_results)
|
||||
|
||||
# 合并
|
||||
combined: Dict[str, dict] = {}
|
||||
for item in keyword_results:
|
||||
item_id = item.get("id", "")
|
||||
if not item_id:
|
||||
continue
|
||||
combined[item_id] = item.copy()
|
||||
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
||||
combined[item_id]["embedding_score"] = 0.0
|
||||
|
||||
for item in embedding_results:
|
||||
item_id = item.get("id", "")
|
||||
if not item_id:
|
||||
continue
|
||||
if item_id in combined:
|
||||
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
combined[item_id] = item.copy()
|
||||
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
||||
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
for item in combined.values():
|
||||
bm25 = float(item.get("bm25_score", 0) or 0)
|
||||
emb = float(item.get("embedding_score", 0) or 0)
|
||||
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
|
||||
|
||||
results = list(combined.values())
|
||||
before = len(results)
|
||||
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
|
||||
results.sort(key=lambda x: x["content_score"], reverse=True)
|
||||
results = results[:limit]
|
||||
|
||||
logger.info(
|
||||
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
|
||||
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
|
||||
)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
|
||||
"""Z-score + sigmoid 归一化。"""
|
||||
if not items:
|
||||
return items
|
||||
scores = [float(it.get(field, 0) or 0) for it in items]
|
||||
if len(scores) <= 1:
|
||||
for it in items:
|
||||
it[f"normalized_{field}"] = 1.0
|
||||
return items
|
||||
mean = sum(scores) / len(scores)
|
||||
var = sum((s - mean) ** 2 for s in scores) / len(scores)
|
||||
std = math.sqrt(var)
|
||||
if std == 0:
|
||||
for it in items:
|
||||
it[f"normalized_{field}"] = 1.0
|
||||
else:
|
||||
for it, s in zip(items, scores):
|
||||
z = (s - mean) / std
|
||||
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def _format_result(record: dict) -> dict:
|
||||
return {
|
||||
"id": record.get("id", ""),
|
||||
"perceptual_type": record.get("perceptual_type", ""),
|
||||
"file_name": record.get("file_name", ""),
|
||||
"file_path": record.get("file_path", ""),
|
||||
"summary": record.get("summary", ""),
|
||||
"topic": record.get("topic", ""),
|
||||
"domain": record.get("domain", ""),
|
||||
"keywords": record.get("keywords", []),
|
||||
"created_at": str(record.get("created_at", "")),
|
||||
"file_type": record.get("file_type", ""),
|
||||
"score": record.get("score", 0),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_content_text(formatted: dict) -> str:
|
||||
parts = []
|
||||
if formatted["summary"]:
|
||||
parts.append(formatted["summary"])
|
||||
if formatted["topic"]:
|
||||
parts.append(f"[主题: {formatted['topic']}]")
|
||||
if formatted["keywords"]:
|
||||
kw_list = formatted["keywords"]
|
||||
if isinstance(kw_list, list):
|
||||
parts.append(f"[关键词: {', '.join(kw_list)}]")
|
||||
if formatted["file_name"]:
|
||||
parts.append(f"[文件: {formatted['file_name']}]")
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
|
||||
"""Extract search keywords from problem extension results."""
|
||||
keywords = []
|
||||
context = problem_extension.get("context", {})
|
||||
if isinstance(context, dict):
|
||||
for original_q, extended_qs in context.items():
|
||||
keywords.append(original_q)
|
||||
if isinstance(extended_qs, list):
|
||||
keywords.extend(extended_qs)
|
||||
return keywords
|
||||
|
||||
|
||||
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
|
||||
"""
|
||||
LangGraph node: perceptual memory retrieval.
|
||||
|
||||
Uses PerceptualSearchService to run keyword + embedding search with
|
||||
BM25 fusion reranking, then writes results to state['perceptual_data'].
|
||||
"""
|
||||
end_user_id = state.get("end_user_id", "")
|
||||
problem_extension = state.get("problem_extension", {})
|
||||
original_query = state.get("data", "")
|
||||
memory_config = state.get("memory_config", None)
|
||||
|
||||
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
|
||||
|
||||
keywords = _extract_keywords_from_problems(problem_extension)
|
||||
if not keywords:
|
||||
keywords = [original_query] if original_query else []
|
||||
|
||||
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
|
||||
|
||||
service = PerceptualSearchService(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
search_result = await service.search(
|
||||
query=original_query,
|
||||
keywords=keywords,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
result = {
|
||||
"memories": search_result["memories"],
|
||||
"content": search_result["content"],
|
||||
"_intermediate": {
|
||||
"type": "perceptual_retrieve",
|
||||
"title": "感知记忆检索",
|
||||
"data": search_result["memories"],
|
||||
"query": original_query,
|
||||
"result_count": len(search_result["memories"]),
|
||||
},
|
||||
}
|
||||
return {"perceptual_data": result}
|
||||
@@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
print(time.time() - start)
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": data,
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||
PerceptualSearchService,
|
||||
)
|
||||
from app.core.memory.agent.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
@@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
if storage_type != "rag":
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
|
||||
|
||||
async def _perceptual_search():
|
||||
service = PerceptualSearchService(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
return await service.search(query=data, limit=5)
|
||||
|
||||
hybrid_task = SearchService().execute_hybrid_search(
|
||||
**search_params,
|
||||
memory_config=memory_config,
|
||||
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
|
||||
expand_communities=False,
|
||||
)
|
||||
perceptual_task = _perceptual_search()
|
||||
|
||||
gather_results = await asyncio.gather(
|
||||
hybrid_task, perceptual_task, return_exceptions=True
|
||||
)
|
||||
hybrid_result = gather_results[0]
|
||||
perceptual_results = gather_results[1]
|
||||
|
||||
# 处理 hybrid search 异常
|
||||
if isinstance(hybrid_result, Exception):
|
||||
raise hybrid_result
|
||||
retrieve_info, question, raw_results = hybrid_result
|
||||
|
||||
# 处理感知记忆结果
|
||||
if isinstance(perceptual_results, Exception):
|
||||
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
|
||||
perceptual_results = []
|
||||
|
||||
# 拼接感知记忆内容到 retrieve_info
|
||||
if perceptual_results and isinstance(perceptual_results, dict):
|
||||
perceptual_content = perceptual_results.get("content", "")
|
||||
if perceptual_content:
|
||||
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
|
||||
count = len(perceptual_results.get("memories", []))
|
||||
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
|
||||
|
||||
# 调试:打印 community 检索结果数量
|
||||
if raw_results and isinstance(raw_results, dict):
|
||||
reranked = raw_results.get('reranked_results', {})
|
||||
@@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"error": str(e)
|
||||
}
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
duration = end - start
|
||||
log_time('检索', duration)
|
||||
return {"summary": summary}
|
||||
|
||||
@@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
retrieve_info_str = list(set(retrieve_info_str))
|
||||
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages = await summary_llm(state, history, retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
aimessages = await summary_llm(
|
||||
state,
|
||||
history,
|
||||
retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2',
|
||||
'retrieve_summary', RetrieveSummaryResponse,
|
||||
"1"
|
||||
)
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
|
||||
retrieve_info_str += i + '\n'
|
||||
history = await summary_history(state)
|
||||
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
@@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str += i + '\n'
|
||||
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
|
||||
@@ -15,7 +15,10 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
||||
Problem_Extension,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||
retrieve,
|
||||
retrieve_nodes,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||
perceptual_retrieve_node,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||
Input_Summary,
|
||||
@@ -48,13 +51,14 @@ async def make_read_graph():
|
||||
"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow.add_node("content_input", content_input_node)
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension)
|
||||
workflow.add_node("Input_Summary", Input_Summary)
|
||||
# workflow.add_node("Retrieve", retrieve_nodes)
|
||||
workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Retrieve", retrieve_nodes)
|
||||
# workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
|
||||
workflow.add_node("Verify", Verify)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||
workflow.add_node("Summary", Summary)
|
||||
@@ -65,14 +69,15 @@ async def make_read_graph():
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
|
||||
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
|
||||
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# Compile workflow
|
||||
@@ -80,7 +85,5 @@ async def make_read_graph():
|
||||
yield graph
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建工作流失败: {e}")
|
||||
logger.error(f"创建工作流失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
"""
|
||||
Write messages to RAG storage system
|
||||
|
||||
Combines user and AI messages into a single string format and stores them
|
||||
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for the conversation
|
||||
user_message: User's input message content
|
||||
ai_message: AI's response message content
|
||||
user_rag_memory_id: RAG memory identifier for storage location
|
||||
"""
|
||||
# RAG mode: combine messages into string format (maintain original logic)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
@@ -118,7 +98,7 @@ async def write(
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
|
||||
|
||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
||||
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
to long-term memory storage.
|
||||
|
||||
Args:
|
||||
long_term_messages: Long-term message data to be saved
|
||||
actual_config_id: Configuration identifier for memory settings
|
||||
end_user_id: User identifier for memory association
|
||||
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
scope: Scope/window size for memory processing
|
||||
"""
|
||||
with get_db_context() as db_session:
|
||||
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
if not result:
|
||||
logger.warning(f"No write data found for user {end_user_id}")
|
||||
return
|
||||
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data) == scope:
|
||||
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
"""Window-based dialogue processing"""
|
||||
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
scope = scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_has_history:
|
||||
end_user_visit_count, redis_messages = is_end_user_has_history
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
return
|
||||
end_user_visit_count += 1
|
||||
if end_user_visit_count < scope:
|
||||
redis_messages.extend(langchain_messages)
|
||||
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
|
||||
else:
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = redis_messages
|
||||
redis_messages.extend(langchain_messages)
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||
end_user_id,
|
||||
"",
|
||||
"",
|
||||
None,
|
||||
end_user_id,
|
||||
config_id,
|
||||
formatted_messages
|
||||
write_message_task.delay(
|
||||
end_user_id, # end_user_id: User ID
|
||||
redis_messages, # message: JSON string format message list
|
||||
config_id, # config_id: Configuration ID string
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""Time-based memory processing"""
|
||||
count_store.update_sessions_count(end_user_id, 0, [])
|
||||
|
||||
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
|
||||
@@ -1,49 +1,25 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||
end_user_id: str = '', scope: int = 6):
|
||||
async def long_term_storage(
|
||||
long_term_type: str,
|
||||
langchain_messages: list,
|
||||
memory_config_id: str,
|
||||
end_user_id: str,
|
||||
scope: int = 6
|
||||
):
|
||||
"""
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
|
||||
Args:
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
memory_config: Memory configuration identifier
|
||||
memory_config_id: Memory configuration identifier
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
if langchain_messages is None:
|
||||
langchain_messages = []
|
||||
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
config_id=memory_config_id, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
||||
# Dialogue window with 6 rounds of conversation
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
"""Time-based strategy"""
|
||||
# Time-based strategy
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
"""Strategy 3: Aggregate judgment"""
|
||||
# Aggregate judgment
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
||||
async def write_long_term(
|
||||
storage_type: str,
|
||||
end_user_id: str,
|
||||
messages: list[dict],
|
||||
user_rag_memory_id: str,
|
||||
actual_config_id: str
|
||||
):
|
||||
"""
|
||||
Write long-term memory with different storage types
|
||||
|
||||
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
|
||||
Args:
|
||||
storage_type: Type of storage (RAG or traditional)
|
||||
end_user_id: User group identifier
|
||||
message_chat: User message content
|
||||
aimessages: AI response messages
|
||||
messages: message list
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_config_id: Actual configuration ID
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
message_content = []
|
||||
for message in messages:
|
||||
message_content.append(f'{message.get("role")}:{message.get("content")}')
|
||||
messages_string = "\n".join(message_content)
|
||||
await write_rag(end_user_id, messages_string, user_rag_memory_id)
|
||||
else:
|
||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
await long_term_storage(long_term_type=CHUNK,
|
||||
langchain_messages=messages,
|
||||
memory_config_id=actual_config_id,
|
||||
end_user_id=end_user_id,
|
||||
scope=SCOPE)
|
||||
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
@@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||
@@ -31,10 +30,10 @@ def _clean_expand_fields(obj):
|
||||
|
||||
|
||||
async def expand_communities_to_statements(
|
||||
community_results: List[dict],
|
||||
end_user_id: str,
|
||||
existing_content: str = "",
|
||||
limit: int = 10,
|
||||
community_results: List[dict],
|
||||
end_user_id: str,
|
||||
existing_content: str = "",
|
||||
limit: int = 10,
|
||||
) -> Tuple[List[dict], List[str]]:
|
||||
"""
|
||||
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||
@@ -76,17 +75,18 @@ async def expand_communities_to_statements(
|
||||
if s.get("statement") and s["statement"] not in existing_lines
|
||||
]
|
||||
cleaned = _clean_expand_fields(expanded_stmts)
|
||||
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||
logger.info(
|
||||
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||
return cleaned, new_texts
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
|
||||
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
||||
"""
|
||||
Extract only meaningful content from search results, dropping all metadata.
|
||||
@@ -107,19 +107,19 @@ class SearchService:
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return str(result)
|
||||
|
||||
|
||||
content_parts = []
|
||||
|
||||
|
||||
# Statements: extract statement field
|
||||
if 'statement' in result and result['statement']:
|
||||
content_parts.append(result['statement'])
|
||||
|
||||
|
||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||
is_community = (
|
||||
node_type == "community"
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
node_type == "community"
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
)
|
||||
if is_community:
|
||||
name = result.get('name', '')
|
||||
@@ -130,16 +130,16 @@ class SearchService:
|
||||
elif 'content' in result and result['content']:
|
||||
# Summaries / Chunks
|
||||
content_parts.append(result['content'])
|
||||
|
||||
|
||||
# Entities: extract name and fact_summary (commented out in original)
|
||||
# if 'name' in result and result['name']:
|
||||
# content_parts.append(result['name'])
|
||||
# if result.get('fact_summary'):
|
||||
# content_parts.append(result['fact_summary'])
|
||||
|
||||
|
||||
# Return concatenated content or empty string
|
||||
return '\n'.join(content_parts) if content_parts else ""
|
||||
|
||||
|
||||
def clean_query(self, query: str) -> str:
|
||||
"""
|
||||
Clean and escape query text for Lucene.
|
||||
@@ -155,33 +155,33 @@ class SearchService:
|
||||
Cleaned and escaped query string
|
||||
"""
|
||||
q = str(query).strip()
|
||||
|
||||
|
||||
# Remove wrapping quotes
|
||||
if (q.startswith("'") and q.endswith("'")) or (
|
||||
q.startswith('"') and q.endswith('"')
|
||||
q.startswith('"') and q.endswith('"')
|
||||
):
|
||||
q = q[1:-1]
|
||||
|
||||
|
||||
# Remove newlines and carriage returns
|
||||
q = q.replace('\r', ' ').replace('\n', ' ').strip()
|
||||
|
||||
|
||||
# Apply Lucene escaping
|
||||
q = escape_lucene_query(q)
|
||||
|
||||
|
||||
return q
|
||||
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config = None,
|
||||
expand_communities: bool = True,
|
||||
self,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config=None,
|
||||
expand_communities: bool = True,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
@@ -205,10 +205,10 @@ class SearchService:
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries", "communities"]
|
||||
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
|
||||
try:
|
||||
# Execute search
|
||||
answer = await run_hybrid_search(
|
||||
@@ -221,18 +221,18 @@ class SearchService:
|
||||
memory_config=memory_config,
|
||||
rerank_alpha=rerank_alpha
|
||||
)
|
||||
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
# Prioritize summaries as they contain synthesized contextual information
|
||||
answer_list = []
|
||||
|
||||
|
||||
# For hybrid search, use reranked_results
|
||||
if search_type == "hybrid":
|
||||
reranked_results = answer.get('reranked_results', {})
|
||||
|
||||
|
||||
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in reranked_results:
|
||||
category_results = reranked_results[category]
|
||||
@@ -242,7 +242,7 @@ class SearchService:
|
||||
# For keyword or embedding search, results are directly in answer dict
|
||||
# Apply same priority order
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in answer:
|
||||
category_results = answer[category]
|
||||
@@ -261,7 +261,7 @@ class SearchService:
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
answer_list.extend(cleaned_stmts)
|
||||
|
||||
|
||||
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||
content_list = []
|
||||
for ans in answer_list:
|
||||
@@ -269,19 +269,18 @@ class SearchService:
|
||||
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||
|
||||
|
||||
# Filter out empty strings and join with newlines
|
||||
clean_content = '\n'.join([c for c in content_list if c])
|
||||
|
||||
|
||||
# Log first 200 chars
|
||||
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
|
||||
|
||||
|
||||
# Return raw results if requested
|
||||
if return_raw_results:
|
||||
return clean_content, cleaned_query, answer
|
||||
else:
|
||||
return clean_content, cleaned_query, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Annotated, TypedDict
|
||||
@@ -52,6 +51,7 @@ class ReadState(TypedDict):
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve: dict
|
||||
perceptual_data: dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
|
||||
@@ -3,8 +3,9 @@ import uuid
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
@@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import (
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -66,10 +67,10 @@ class RedisWriteStore:
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session_write] 保存会话失败: {e}")
|
||||
logger.error(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
@@ -99,7 +100,7 @@ class RedisWriteStore:
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
@@ -108,16 +109,16 @@ class RedisWriteStore:
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
|
||||
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||
logger.error(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
@@ -144,7 +145,7 @@ class RedisWriteStore:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
@@ -158,12 +159,12 @@ class RedisWriteStore:
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
@@ -173,23 +174,21 @@ class RedisWriteStore:
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
|
||||
if not results:
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
@@ -203,11 +202,11 @@ class RedisWriteStore:
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
@@ -221,7 +220,7 @@ class RedisWriteStore:
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
@@ -230,15 +229,14 @@ class RedisWriteStore:
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||
print(result_items)
|
||||
result_items = sort_and_limit_results(filtered_items)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
@@ -258,7 +256,7 @@ class RedisWriteStore:
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -278,7 +276,7 @@ class RedisCountStore:
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
self.uuid = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
@@ -295,26 +293,26 @@ class RedisCountStore:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"id": self.uuid,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
|
||||
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
@@ -327,7 +325,7 @@ class RedisCountStore:
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
@@ -335,35 +333,40 @@ class RedisCountStore:
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
|
||||
if count is not None:
|
||||
messages = deserialize_messages(messages_str)
|
||||
return [int(count), messages]
|
||||
|
||||
messages: list[dict] = deserialize_messages(messages_str)
|
||||
return int(count), messages
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[get_sessions_count] 查询失败: {e}")
|
||||
logger.error(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||
messages: Any) -> bool:
|
||||
|
||||
def update_sessions_count(
|
||||
self,
|
||||
end_user_id: str,
|
||||
new_count: int,
|
||||
messages: Any
|
||||
) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
@@ -378,39 +381,39 @@ class RedisCountStore:
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
|
||||
if not session_id:
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', int(new_count))
|
||||
pipe.hset(key, 'count', str(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
|
||||
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[update_sessions_count] 更新失败: {e}")
|
||||
logger.debug(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
@@ -428,7 +431,7 @@ class RedisCountStore:
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -451,9 +454,9 @@ class RedisSessionStore:
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
@@ -483,14 +486,14 @@ class RedisSessionStore:
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session] 保存会话失败: {e}")
|
||||
logger.error(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
@@ -520,8 +523,8 @@ class RedisSessionStore:
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
@@ -535,10 +538,10 @@ class RedisSessionStore:
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
@@ -556,21 +559,21 @@ class RedisSessionStore:
|
||||
continue
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
@@ -591,7 +594,7 @@ class RedisSessionStore:
|
||||
return bool(results[0])
|
||||
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
@@ -632,7 +635,7 @@ class RedisSessionStore:
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
logger.debug("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 批量获取所有数据
|
||||
@@ -678,7 +681,7 @@ class RedisSessionStore:
|
||||
deleted_count += len(batch)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
|
||||
|
||||
@@ -152,6 +152,24 @@ async def write(
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
|
||||
# Neo4j 写入前:清洗用户/AI助手实体之间的别名交叉污染
|
||||
# 从 Neo4j 查询已有的 AI 助手别名,与本轮实体中的 AI 助手别名合并,
|
||||
# 确保用户实体的 aliases 不包含 AI 助手的名字
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
clean_cross_role_aliases,
|
||||
fetch_neo4j_assistant_aliases,
|
||||
)
|
||||
neo4j_assistant_aliases = set()
|
||||
if all_entity_nodes:
|
||||
_eu_id = all_entity_nodes[0].end_user_id
|
||||
if _eu_id:
|
||||
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, _eu_id)
|
||||
clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases)
|
||||
logger.info(f"Neo4j 写入前别名清洗完成,AI助手别名排除集大小: {len(neo4j_assistant_aliases)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Neo4j 写入前别名清洗失败(不影响主流程): {e}")
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 秒
|
||||
|
||||
@@ -56,7 +56,7 @@ class LLMClient(ABC):
|
||||
self.max_retries = self.config.max_retries
|
||||
self.timeout = self.config.timeout
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"初始化 LLM 客户端: provider={self.provider}, "
|
||||
f"model={self.model_name}, max_retries={self.max_retries}"
|
||||
)
|
||||
|
||||
@@ -43,6 +43,7 @@ load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
def _parse_datetime(value: Any) -> Optional[datetime]:
|
||||
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
|
||||
if value is None:
|
||||
@@ -75,7 +76,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
if score_field == "activation_value" and score is None:
|
||||
scores.append(None) # 保持 None,稍后特殊处理
|
||||
continue
|
||||
|
||||
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
scores.append(float(score))
|
||||
else:
|
||||
@@ -83,10 +84,10 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
|
||||
if not scores:
|
||||
return results
|
||||
|
||||
|
||||
# 过滤掉 None 值,只对有效分数进行归一化
|
||||
valid_scores = [s for s in scores if s is not None]
|
||||
|
||||
|
||||
if not valid_scores:
|
||||
# 所有分数都是 None,不进行归一化
|
||||
for item in results:
|
||||
@@ -94,7 +95,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
item[f"normalized_{score_field}"] = None
|
||||
return results
|
||||
|
||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||
for item, score in zip(results, scores):
|
||||
if score_field in item or score_field == "activation_value":
|
||||
if score is None:
|
||||
@@ -132,7 +133,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
return results
|
||||
|
||||
|
||||
|
||||
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Remove duplicate items from search results based on content.
|
||||
@@ -150,52 +150,53 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
seen_ids = set()
|
||||
seen_content = set()
|
||||
deduplicated = []
|
||||
|
||||
|
||||
for item in items:
|
||||
# Try multiple ID fields to identify unique items
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
|
||||
|
||||
# Extract content from various possible fields
|
||||
content = (
|
||||
item.get("text") or
|
||||
item.get("content") or
|
||||
item.get("statement") or
|
||||
item.get("name") or
|
||||
""
|
||||
item.get("text") or
|
||||
item.get("content") or
|
||||
item.get("statement") or
|
||||
item.get("name") or
|
||||
""
|
||||
)
|
||||
|
||||
|
||||
# Normalize content for comparison (strip whitespace and lowercase)
|
||||
normalized_content = str(content).strip().lower() if content else ""
|
||||
|
||||
|
||||
# Check if we've seen this ID or content before
|
||||
is_duplicate = False
|
||||
|
||||
|
||||
if item_id and item_id in seen_ids:
|
||||
is_duplicate = True
|
||||
elif normalized_content and normalized_content in seen_content:
|
||||
# Only check content duplication if content is not empty
|
||||
is_duplicate = True
|
||||
|
||||
|
||||
if not is_duplicate:
|
||||
# Mark as seen
|
||||
if item_id:
|
||||
seen_ids.add(item_id)
|
||||
if normalized_content: # Only track non-empty content
|
||||
seen_content.add(normalized_content)
|
||||
|
||||
|
||||
deduplicated.append(item)
|
||||
|
||||
|
||||
return deduplicated
|
||||
|
||||
|
||||
def rerank_with_activation(
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
activation_boost_factor: float = 0.8,
|
||||
now: datetime | None = None,
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
activation_boost_factor: float = 0.8,
|
||||
now: datetime | None = None,
|
||||
content_score_threshold: float = 0.5,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
两阶段排序:先按内容相关性筛选,再按激活值排序。
|
||||
@@ -222,6 +223,8 @@ def rerank_with_activation(
|
||||
forgetting_config: 遗忘引擎配置(当前未使用)
|
||||
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
|
||||
now: 当前时间(用于遗忘计算)
|
||||
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score),
|
||||
低于此阈值的结果会被过滤。默认 0.5。
|
||||
|
||||
返回:
|
||||
带评分元数据的重排序结果,按 final_score 排序
|
||||
@@ -229,26 +232,26 @@ def rerank_with_activation(
|
||||
# 验证权重范围
|
||||
if not (0 <= alpha <= 1):
|
||||
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
|
||||
|
||||
|
||||
# 初始化遗忘引擎(如果需要)
|
||||
engine = None
|
||||
if forgetting_config:
|
||||
engine = ForgettingEngine(forgetting_config)
|
||||
now_dt = now or datetime.now()
|
||||
|
||||
|
||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
|
||||
# 步骤 1: 归一化分数
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
|
||||
# 步骤 2: 按 ID 合并结果(去重)
|
||||
combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
# 添加关键词结果
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
@@ -257,7 +260,7 @@ def rerank_with_activation(
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
combined_items[item_id]["embedding_score"] = 0 # 默认值
|
||||
|
||||
|
||||
# 添加或更新向量嵌入结果
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
@@ -271,18 +274,18 @@ def rerank_with_activation(
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0 # 默认值
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
|
||||
# 步骤 3: 归一化激活度分数
|
||||
# 为所有项准备激活度值列表
|
||||
items_list = list(combined_items.values())
|
||||
items_list = normalize_scores(items_list, "activation_value")
|
||||
|
||||
|
||||
# 更新 combined_items 中的归一化激活度分数
|
||||
for item in items_list:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id and item_id in combined_items:
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
|
||||
|
||||
|
||||
# 步骤 4: 计算基础分数和最终分数
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||
@@ -290,45 +293,45 @@ def rerank_with_activation(
|
||||
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
|
||||
raw_act_norm = item.get("normalized_activation_value")
|
||||
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
|
||||
|
||||
|
||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||
base_score = content_score # 第一阶段用内容分数
|
||||
|
||||
|
||||
# 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序)
|
||||
item["activation_score"] = act_norm # 可能为 None
|
||||
item["content_score"] = content_score
|
||||
item["base_score"] = base_score
|
||||
|
||||
|
||||
# 步骤 5: 应用遗忘曲线(可选)
|
||||
if engine:
|
||||
# 计算受激活度影响的记忆强度
|
||||
importance = float(item.get("importance_score", 0.5) or 0.5)
|
||||
|
||||
|
||||
# 获取 activation_value
|
||||
activation_val = item.get("activation_value")
|
||||
|
||||
|
||||
# 只对有激活值的节点应用遗忘曲线
|
||||
if activation_val is not None and isinstance(activation_val, (int, float)):
|
||||
activation_val = float(activation_val)
|
||||
|
||||
|
||||
# 计算记忆强度:importance_score × (1 + activation_value × boost_factor)
|
||||
memory_strength = importance * (1 + activation_val * activation_boost_factor)
|
||||
|
||||
|
||||
# 计算经过的时间(天数)
|
||||
dt = _parse_datetime(item.get("created_at"))
|
||||
if dt is None:
|
||||
time_elapsed_days = 0.0
|
||||
else:
|
||||
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
|
||||
# 获取遗忘权重
|
||||
forgetting_weight = engine.calculate_weight(
|
||||
time_elapsed=time_elapsed_days,
|
||||
memory_strength=memory_strength
|
||||
)
|
||||
|
||||
|
||||
# 应用到基础分数
|
||||
item["forgetting_weight"] = forgetting_weight
|
||||
item["final_score"] = base_score * forgetting_weight
|
||||
@@ -338,7 +341,7 @@ def rerank_with_activation(
|
||||
else:
|
||||
# 不使用遗忘曲线
|
||||
item["final_score"] = base_score
|
||||
|
||||
|
||||
# 步骤 6: 两阶段排序和限制
|
||||
# 第一阶段:按内容相关性(base_score)排序,取 Top-K
|
||||
first_stage_limit = limit * 3 # 可配置,取3倍候选
|
||||
@@ -347,11 +350,11 @@ def rerank_with_activation(
|
||||
key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序
|
||||
reverse=True
|
||||
)[:first_stage_limit]
|
||||
|
||||
|
||||
# 第二阶段:分离有激活值和无激活值的节点
|
||||
items_with_activation = []
|
||||
items_without_activation = []
|
||||
|
||||
|
||||
for item in first_stage_sorted:
|
||||
activation_score = item.get("activation_score")
|
||||
# 检查是否有有效的激活值(不是 None)
|
||||
@@ -359,14 +362,14 @@ def rerank_with_activation(
|
||||
items_with_activation.append(item)
|
||||
else:
|
||||
items_without_activation.append(item)
|
||||
|
||||
|
||||
# 优先按激活值排序有激活值的节点
|
||||
sorted_with_activation = sorted(
|
||||
items_with_activation,
|
||||
key=lambda x: float(x.get("activation_score", 0) or 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
|
||||
# 如果有激活值的节点不足 limit,用无激活值的节点补充
|
||||
if len(sorted_with_activation) < limit:
|
||||
needed = limit - len(sorted_with_activation)
|
||||
@@ -374,7 +377,7 @@ def rerank_with_activation(
|
||||
sorted_items = sorted_with_activation + items_without_activation[:needed]
|
||||
else:
|
||||
sorted_items = sorted_with_activation[:limit]
|
||||
|
||||
|
||||
# 两阶段排序完成,更新 final_score 以反映实际排序依据
|
||||
# Stage 1: 按 content_score 筛选候选(已完成)
|
||||
# Stage 2: 按 activation_score 排序(已完成)
|
||||
@@ -390,16 +393,29 @@ def rerank_with_activation(
|
||||
else:
|
||||
# 无激活值:使用内容相关性分数
|
||||
item["final_score"] = item.get("base_score", 0)
|
||||
|
||||
# 最终去重确保没有重复项
|
||||
|
||||
if content_score_threshold > 0:
|
||||
before_count = len(sorted_items)
|
||||
sorted_items = [
|
||||
item for item in sorted_items
|
||||
if float(item.get("content_score", 0) or 0) >= content_score_threshold
|
||||
]
|
||||
filtered_count = before_count - len(sorted_items)
|
||||
if filtered_count > 0:
|
||||
logger.info(
|
||||
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
|
||||
f"items below content_score_threshold={content_score_threshold}"
|
||||
)
|
||||
|
||||
sorted_items = _deduplicate_results(sorted_items)
|
||||
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
|
||||
return reranked
|
||||
|
||||
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
|
||||
log_file: str = None):
|
||||
"""Log search query information using the logger.
|
||||
|
||||
Args:
|
||||
@@ -412,7 +428,7 @@ def log_search_query(query_text: str, search_type: str, end_user_id: str | None,
|
||||
"""
|
||||
# Ensure the query text is plain and clean before logging
|
||||
cleaned_query = extract_plain_query(query_text)
|
||||
|
||||
|
||||
# Log using the standard logger
|
||||
logger.info(
|
||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||
@@ -439,8 +455,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
|
||||
|
||||
|
||||
def apply_reranker_placeholder(
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Placeholder for a cross-encoder reranker.
|
||||
@@ -483,7 +499,7 @@ def apply_reranker_placeholder(
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Apply LLM-based reranking to search results.
|
||||
|
||||
|
||||
# Args:
|
||||
# results: Search results organized by category
|
||||
# query_text: Original search query
|
||||
@@ -491,7 +507,7 @@ def apply_reranker_placeholder(
|
||||
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
# top_k: Maximum number of items to rerank per category
|
||||
# batch_size: Number of items to process concurrently
|
||||
|
||||
|
||||
# Returns:
|
||||
# Reranked results with final_score and reranker_model fields
|
||||
# """
|
||||
@@ -501,18 +517,18 @@ def apply_reranker_placeholder(
|
||||
# # except Exception as e:
|
||||
# # logger.debug(f"Failed to load reranker config: {e}")
|
||||
# # rc = {}
|
||||
|
||||
|
||||
# # Check if reranking is enabled
|
||||
# enabled = rc.get("enabled", False)
|
||||
# if not enabled:
|
||||
# logger.debug("LLM reranking is disabled in configuration")
|
||||
# return results
|
||||
|
||||
|
||||
# # Load configuration parameters with defaults
|
||||
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
# top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
|
||||
|
||||
# # Initialize reranker client if not provided
|
||||
# if reranker_client is None:
|
||||
# try:
|
||||
@@ -520,10 +536,10 @@ def apply_reranker_placeholder(
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
# return results
|
||||
|
||||
|
||||
# # Get model name for metadata
|
||||
# model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
|
||||
|
||||
# # Process each category
|
||||
# reranked_results = {}
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
@@ -531,38 +547,38 @@ def apply_reranker_placeholder(
|
||||
# if not items:
|
||||
# reranked_results[category] = []
|
||||
# continue
|
||||
|
||||
|
||||
# # Select top K items by combined_score for reranking
|
||||
# sorted_items = sorted(
|
||||
# items,
|
||||
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
# reverse=True
|
||||
# )
|
||||
|
||||
|
||||
# top_items = sorted_items[:top_k]
|
||||
# remaining_items = sorted_items[top_k:]
|
||||
|
||||
|
||||
# # Extract text content from each item
|
||||
# def extract_text(item: Dict[str, Any]) -> str:
|
||||
# """Extract text content from a result item."""
|
||||
# # Try different text fields based on category
|
||||
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
# return str(text).strip()
|
||||
|
||||
|
||||
# # Batch items for concurrent processing
|
||||
# batches = []
|
||||
# for i in range(0, len(top_items), batch_size):
|
||||
# batch = top_items[i:i + batch_size]
|
||||
# batches.append(batch)
|
||||
|
||||
|
||||
# # Process batches concurrently
|
||||
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
# """Process a batch of items with LLM relevance scoring."""
|
||||
# scored_batch = []
|
||||
|
||||
|
||||
# for item in batch:
|
||||
# item_text = extract_text(item)
|
||||
|
||||
|
||||
# # Skip items with no text
|
||||
# if not item_text:
|
||||
# item_copy = item.copy()
|
||||
@@ -572,7 +588,7 @@ def apply_reranker_placeholder(
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
# continue
|
||||
|
||||
|
||||
# # Create relevance scoring prompt
|
||||
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
|
||||
@@ -585,15 +601,15 @@ def apply_reranker_placeholder(
|
||||
# - 1.0 means perfectly relevant
|
||||
|
||||
# Relevance score:"""
|
||||
|
||||
|
||||
# # Send request to LLM
|
||||
# try:
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
# response = await reranker_client.chat(messages)
|
||||
|
||||
|
||||
# # Parse LLM response to extract relevance score
|
||||
# response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
|
||||
|
||||
# # Try to extract a float from the response
|
||||
# try:
|
||||
# # Remove any non-numeric characters except decimal point
|
||||
@@ -608,11 +624,11 @@ def apply_reranker_placeholder(
|
||||
# except (ValueError, AttributeError) as e:
|
||||
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
# llm_score = None
|
||||
|
||||
|
||||
# # Calculate final score
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
|
||||
|
||||
# if llm_score is not None:
|
||||
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
# item_copy["llm_relevance_score"] = llm_score
|
||||
@@ -620,7 +636,7 @@ def apply_reranker_placeholder(
|
||||
# # Use combined_score as fallback
|
||||
# final_score = combined_score
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
|
||||
|
||||
# item_copy["final_score"] = final_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
@@ -632,14 +648,14 @@ def apply_reranker_placeholder(
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
|
||||
|
||||
# return scored_batch
|
||||
|
||||
|
||||
# # Process all batches concurrently
|
||||
# try:
|
||||
# batch_tasks = [process_batch(batch) for batch in batches]
|
||||
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# # Merge batch results
|
||||
# scored_items = []
|
||||
# for result in batch_results:
|
||||
@@ -647,7 +663,7 @@ def apply_reranker_placeholder(
|
||||
# logger.warning(f"Batch processing failed: {result}")
|
||||
# continue
|
||||
# scored_items.extend(result)
|
||||
|
||||
|
||||
# # Add remaining items (not in top K) with their combined_score as final_score
|
||||
# for item in remaining_items:
|
||||
# item_copy = item.copy()
|
||||
@@ -655,11 +671,11 @@ def apply_reranker_placeholder(
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_items.append(item_copy)
|
||||
|
||||
|
||||
# # Sort all items by final_score in descending order
|
||||
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
# reranked_results[category] = scored_items
|
||||
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# # Return original items with combined_score as final_score
|
||||
@@ -668,22 +684,22 @@ def apply_reranker_placeholder(
|
||||
# item["final_score"] = combined_score
|
||||
# item["reranker_model"] = model_name
|
||||
# reranked_results[category] = items
|
||||
|
||||
|
||||
# return reranked_results
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -699,7 +715,7 @@ async def run_hybrid_search(
|
||||
|
||||
# Clean and normalize the incoming query before use/logging
|
||||
query_text = extract_plain_query(query_text)
|
||||
|
||||
|
||||
# Validate query is not empty after cleaning
|
||||
if not query_text or not query_text.strip():
|
||||
logger.warning("Empty query after cleaning, returning empty results")
|
||||
@@ -716,7 +732,7 @@ async def run_hybrid_search(
|
||||
"error": "Empty query"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Log the search query
|
||||
log_search_query(query_text, search_type, end_user_id, limit, include)
|
||||
|
||||
@@ -747,7 +763,7 @@ async def run_hybrid_search(
|
||||
# Embedding-based search
|
||||
logger.info("[PERF] Starting embedding search...")
|
||||
embedding_start = time.time()
|
||||
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
@@ -758,8 +774,7 @@ async def run_hybrid_search(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
api_key=embedder_config_dict["api_key"],
|
||||
base_url=embedder_config_dict["base_url"],
|
||||
type="llm"
|
||||
base_url=embedder_config_dict["base_url"]
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
@@ -769,7 +784,7 @@ async def run_hybrid_search(
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
connector=connector,
|
||||
@@ -789,7 +804,7 @@ async def run_hybrid_search(
|
||||
|
||||
if keyword_task:
|
||||
keyword_results = await keyword_task
|
||||
keyword_latency = time.time() - keyword_start
|
||||
keyword_latency = time.time() - search_start_time
|
||||
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
|
||||
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
|
||||
if search_type == "keyword":
|
||||
@@ -799,7 +814,7 @@ async def run_hybrid_search(
|
||||
|
||||
if embedding_task:
|
||||
embedding_results = await embedding_task
|
||||
embedding_latency = time.time() - embedding_start
|
||||
embedding_latency = time.time() - search_start_time
|
||||
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
|
||||
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
|
||||
if search_type == "embedding":
|
||||
@@ -811,7 +826,8 @@ async def run_hybrid_search(
|
||||
if search_type == "hybrid":
|
||||
results["combined_summary"] = {
|
||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_embedding_results": sum(
|
||||
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
@@ -819,7 +835,7 @@ async def run_hybrid_search(
|
||||
# Apply two-stage reranking with ACTR activation calculation
|
||||
rerank_start = time.time()
|
||||
logger.info("[PERF] Using two-stage reranking with ACTR activation")
|
||||
|
||||
|
||||
# 加载遗忘引擎配置
|
||||
config_start = time.time()
|
||||
try:
|
||||
@@ -830,7 +846,7 @@ async def run_hybrid_search(
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
config_time = time.time() - config_start
|
||||
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
|
||||
|
||||
|
||||
# 统一使用激活度重排序(两阶段:检索 + ACTR计算)
|
||||
rerank_compute_start = time.time()
|
||||
reranked_results = rerank_with_activation(
|
||||
@@ -843,14 +859,14 @@ async def run_hybrid_search(
|
||||
)
|
||||
rerank_compute_time = time.time() - rerank_compute_start
|
||||
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
|
||||
|
||||
|
||||
rerank_latency = time.time() - rerank_start
|
||||
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
|
||||
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
|
||||
|
||||
|
||||
# Optional: apply reranker placeholder if enabled via config
|
||||
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
|
||||
|
||||
|
||||
# Apply LLM reranking if enabled
|
||||
llm_rerank_applied = False
|
||||
# if use_llm_rerank:
|
||||
@@ -863,11 +879,12 @@ async def run_hybrid_search(
|
||||
# logger.info("LLM reranking applied successfully")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
|
||||
|
||||
results["reranked_results"] = reranked_results
|
||||
results["combined_summary"] = {
|
||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_embedding_results": sum(
|
||||
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat(),
|
||||
@@ -880,13 +897,13 @@ async def run_hybrid_search(
|
||||
# Calculate total latency
|
||||
total_latency = time.time() - search_start_time
|
||||
latency_metrics["total_latency"] = round(total_latency, 4)
|
||||
|
||||
|
||||
# Add latency metrics to results
|
||||
if "combined_summary" in results:
|
||||
results["combined_summary"]["latency_metrics"] = latency_metrics
|
||||
else:
|
||||
results["latency_metrics"] = latency_metrics
|
||||
|
||||
|
||||
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
|
||||
@@ -909,8 +926,10 @@ async def run_hybrid_search(
|
||||
# Log search completion with result count
|
||||
if search_type == "hybrid":
|
||||
result_counts = {
|
||||
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
|
||||
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
|
||||
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
|
||||
keyword_results.items()},
|
||||
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
|
||||
embedding_results.items()}
|
||||
}
|
||||
else:
|
||||
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
|
||||
@@ -928,12 +947,12 @@ async def run_hybrid_search(
|
||||
|
||||
|
||||
async def search_by_temporal(
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -969,13 +988,13 @@ async def search_by_temporal(
|
||||
|
||||
|
||||
async def search_by_keyword_temporal(
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Temporal keyword search across Statements.
|
||||
@@ -1012,9 +1031,9 @@ async def search_by_keyword_temporal(
|
||||
|
||||
|
||||
async def search_chunk_by_chunk_id(
|
||||
chunk_id: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
limit: int = 1,
|
||||
chunk_id: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Search for Chunks by chunk_id.
|
||||
@@ -1027,4 +1046,3 @@ async def search_chunk_by_chunk_id(
|
||||
limit=limit
|
||||
)
|
||||
return {"chunks": chunks}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import asyncio
|
||||
import difflib # 提供字符串相似度计算工具
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
@@ -16,6 +17,8 @@ from app.core.memory.models.graph_models import (
|
||||
)
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 模块级类型统一工具函数
|
||||
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
|
||||
@@ -198,6 +201,161 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 用户和AI助手的占位名称集合(用于名称标准化)
|
||||
_USER_PLACEHOLDER_NAMES = {"用户", "我", "user", "i"}
|
||||
_ASSISTANT_PLACEHOLDER_NAMES = {"ai助手", "助手", "人工智能助手", "智能助手", "智能体", "ai assistant", "assistant"}
|
||||
|
||||
# 标准化后的规范名称和类型
|
||||
_CANONICAL_USER_NAME = "用户"
|
||||
_CANONICAL_USER_TYPE = "用户"
|
||||
_CANONICAL_ASSISTANT_NAME = "AI助手"
|
||||
_CANONICAL_ASSISTANT_TYPE = "Agent"
|
||||
|
||||
# 用户和AI助手的所有可能名称(用于判断实体是否为特殊角色实体)
|
||||
_ALL_USER_NAMES = _USER_PLACEHOLDER_NAMES
|
||||
_ALL_ASSISTANT_NAMES = _ASSISTANT_PLACEHOLDER_NAMES
|
||||
|
||||
|
||||
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
|
||||
"""判断实体是否为用户实体(name 或 entity_type 匹配)"""
|
||||
name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
etype = (getattr(ent, "entity_type", "") or "").strip()
|
||||
return name in _ALL_USER_NAMES or etype == _CANONICAL_USER_TYPE
|
||||
|
||||
|
||||
def _is_assistant_entity(ent: ExtractedEntityNode) -> bool:
|
||||
"""判断实体是否为AI助手实体(name 或 entity_type 匹配)"""
|
||||
name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
etype = (getattr(ent, "entity_type", "") or "").strip()
|
||||
return name in _ALL_ASSISTANT_NAMES or etype == _CANONICAL_ASSISTANT_TYPE
|
||||
|
||||
|
||||
def _would_merge_cross_role(a: ExtractedEntityNode, b: ExtractedEntityNode) -> bool:
|
||||
"""判断两个实体的合并是否会跨越用户/AI助手角色边界。
|
||||
|
||||
用户实体和AI助手实体永远不应该被合并在一起。
|
||||
如果一方是用户实体、另一方是AI助手实体,返回 True(阻止合并)。
|
||||
"""
|
||||
return (
|
||||
(_is_user_entity(a) and _is_assistant_entity(b))
|
||||
or (_is_assistant_entity(a) and _is_user_entity(b))
|
||||
)
|
||||
|
||||
|
||||
def _normalize_special_entity_names(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
) -> None:
|
||||
"""标准化用户和AI助手实体的名称和类型。
|
||||
|
||||
多轮对话中,LLM 对同一角色可能使用不同的名称变体(如"用户"/"我"/"User",
|
||||
"AI助手"/"助手"/"Assistant"),导致精确匹配无法合并。
|
||||
此函数在去重前将这些变体统一为规范名称,并强制绑定 entity_type,确保:
|
||||
- name="用户" 的实体 entity_type 一定为 "用户"
|
||||
- name="AI助手" 的实体 entity_type 一定为 "Agent"
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表(原地修改)
|
||||
"""
|
||||
for ent in entity_nodes:
|
||||
name = (getattr(ent, "name", "") or "").strip()
|
||||
name_lower = name.lower()
|
||||
|
||||
if name_lower in _USER_PLACEHOLDER_NAMES:
|
||||
ent.name = _CANONICAL_USER_NAME
|
||||
ent.entity_type = _CANONICAL_USER_TYPE
|
||||
elif name_lower in _ASSISTANT_PLACEHOLDER_NAMES:
|
||||
ent.name = _CANONICAL_ASSISTANT_NAME
|
||||
ent.entity_type = _CANONICAL_ASSISTANT_TYPE
|
||||
|
||||
# 第二步:清洗用户/AI助手之间的别名交叉污染(复用 clean_cross_role_aliases)
|
||||
clean_cross_role_aliases(entity_nodes)
|
||||
|
||||
|
||||
async def fetch_neo4j_assistant_aliases(neo4j_connector, end_user_id: str) -> set:
|
||||
"""从 Neo4j 查询 AI 助手实体的所有别名(小写归一化)。
|
||||
|
||||
这是助手别名查询的唯一入口,供 write_tools 和 extraction_orchestrator 共用,
|
||||
避免多处维护相同的 Cypher 和名称列表。
|
||||
|
||||
Args:
|
||||
neo4j_connector: Neo4j 连接器实例(需提供 execute_query 方法)
|
||||
end_user_id: 终端用户 ID
|
||||
|
||||
Returns:
|
||||
小写归一化后的助手别名集合
|
||||
"""
|
||||
# 查询名称列表:规范名称 + 常见变体(与 _normalize_special_entity_names 标准化后一致)
|
||||
query_names = [_CANONICAL_ASSISTANT_NAME, *_ASSISTANT_PLACEHOLDER_NAMES]
|
||||
# 去重保序
|
||||
query_names = list(dict.fromkeys(query_names))
|
||||
|
||||
cypher = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND e.name IN $names
|
||||
RETURN e.aliases AS aliases
|
||||
"""
|
||||
try:
|
||||
result = await neo4j_connector.execute_query(
|
||||
cypher, end_user_id=end_user_id, names=query_names
|
||||
)
|
||||
assistant_aliases: set = set()
|
||||
for record in (result or []):
|
||||
for alias in (record.get("aliases") or []):
|
||||
assistant_aliases.add(alias.strip().lower())
|
||||
if assistant_aliases:
|
||||
logger.debug(f"Neo4j 中 AI 助手别名: {assistant_aliases}")
|
||||
return assistant_aliases
|
||||
except Exception as e:
|
||||
logger.warning(f"查询 Neo4j AI 助手别名失败: {e}")
|
||||
return set()
|
||||
|
||||
|
||||
def clean_cross_role_aliases(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
external_assistant_aliases: set = None,
|
||||
) -> None:
|
||||
"""清洗用户实体和AI助手实体之间的别名交叉污染。
|
||||
|
||||
在 Neo4j 写入前调用,确保:
|
||||
- 用户实体的 aliases 不包含 AI 助手的别名
|
||||
- AI 助手实体的 aliases 不包含用户的别名
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表(原地修改)
|
||||
external_assistant_aliases: 外部传入的 AI 助手别名集合(如从 Neo4j 查询),
|
||||
与本轮实体中的 AI 助手别名合并使用
|
||||
"""
|
||||
# 收集本轮 AI 助手实体的所有别名
|
||||
assistant_aliases = set(external_assistant_aliases or set())
|
||||
user_aliases = set()
|
||||
|
||||
for ent in entity_nodes:
|
||||
if _is_assistant_entity(ent):
|
||||
for alias in (getattr(ent, "aliases", []) or []):
|
||||
assistant_aliases.add(alias.strip().lower())
|
||||
elif _is_user_entity(ent):
|
||||
for alias in (getattr(ent, "aliases", []) or []):
|
||||
user_aliases.add(alias.strip().lower())
|
||||
|
||||
# 从用户实体的 aliases 中移除 AI 助手别名
|
||||
if assistant_aliases:
|
||||
for ent in entity_nodes:
|
||||
if _is_user_entity(ent):
|
||||
original = getattr(ent, "aliases", []) or []
|
||||
cleaned = [a for a in original if a.strip().lower() not in assistant_aliases]
|
||||
if len(cleaned) < len(original):
|
||||
ent.aliases = cleaned
|
||||
|
||||
# 从 AI 助手实体的 aliases 中移除用户别名
|
||||
if user_aliases:
|
||||
for ent in entity_nodes:
|
||||
if _is_assistant_entity(ent):
|
||||
original = getattr(ent, "aliases", []) or []
|
||||
cleaned = [a for a in original if a.strip().lower() not in user_aliases]
|
||||
if len(cleaned) < len(original):
|
||||
ent.aliases = cleaned
|
||||
|
||||
|
||||
def accurate_match(
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||
@@ -261,6 +419,10 @@ def accurate_match(
|
||||
canonical = alias_index.get((ent_uid, ent_name))
|
||||
# 确保不是自身
|
||||
if canonical is not None and canonical.id != ent.id:
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(canonical, ent):
|
||||
i += 1
|
||||
continue
|
||||
_merge_attribute(canonical, ent)
|
||||
id_redirect[ent.id] = canonical.id
|
||||
for k, v in list(id_redirect.items()):
|
||||
@@ -704,6 +866,11 @@ def fuzzy_match(
|
||||
# 条件A(快速通道):alias_match_merge = True
|
||||
# 条件B(标准通道):s_name ≥ tn AND s_type ≥ type_threshold AND overall ≥ tover
|
||||
if alias_match_merge or (s_name >= tn and s_type >= type_threshold and overall >= tover):
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(a, b):
|
||||
j += 1
|
||||
continue
|
||||
|
||||
# ========== 第六步:执行实体合并 ==========
|
||||
|
||||
# 6.1 合并别名
|
||||
@@ -813,6 +980,12 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
b = entity_by_id.get(losing_id)
|
||||
if not a or not b: # 若不存在 a 或 b,可能已在精确或模糊阶段合并,在之前阶段合并之后,不会再处理但是处于审计的目的会记录
|
||||
continue
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(a, b):
|
||||
llm_records.append(
|
||||
f"[LLM阻断] 跨角色合并被阻止: {a.id} ({a.name}) 与 {b.id} ({b.name})"
|
||||
)
|
||||
continue
|
||||
_merge_attribute(a, b)
|
||||
# ID 重定向
|
||||
try:
|
||||
@@ -934,6 +1107,9 @@ async def deduplicate_entities_and_edges(
|
||||
返回:去重后的实体、语句→实体边、实体↔实体边。
|
||||
"""
|
||||
local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化,保留为了之后对于LLM决策追溯
|
||||
# 0) 标准化用户和AI助手实体名称(确保多轮对话中的变体名称统一)
|
||||
_normalize_special_entity_names(entity_nodes)
|
||||
|
||||
# 1) 精确匹配
|
||||
deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
deduplicate_entities_and_edges,
|
||||
clean_cross_role_aliases,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
|
||||
second_layer_dedup_and_merge_with_neo4j,
|
||||
@@ -100,6 +101,10 @@ async def dedup_layers_and_merge_and_return(
|
||||
except Exception as e:
|
||||
print(f"Second-layer dedup failed: {e}")
|
||||
|
||||
# 第二层去重后,清洗用户/AI助手之间的别名交叉污染
|
||||
# 第二层从 Neo4j 合并了旧实体,可能带入历史脏数据
|
||||
clean_cross_role_aliases(fused_entity_nodes)
|
||||
|
||||
return (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
|
||||
@@ -44,6 +44,10 @@ from app.core.memory.models.variate_config import (
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||
dedup_layers_and_merge_and_return,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
_USER_PLACEHOLDER_NAMES,
|
||||
fetch_neo4j_assistant_aliases,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation,
|
||||
generate_entity_embeddings_from_triplets,
|
||||
@@ -1341,14 +1345,20 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list: List[DialogData]
|
||||
) -> None:
|
||||
"""
|
||||
从 Neo4j 读取用户实体的最终 aliases,同步到 end_user 和 end_user_info 表
|
||||
将本轮提取的用户别名同步到 end_user 和 end_user_info 表。
|
||||
|
||||
注意:
|
||||
1. other_name 使用本次对话提取的第一个别名(保持时间顺序)
|
||||
2. aliases 从 Neo4j 读取(保持完整性)
|
||||
注意:此方法在 Neo4j 写入之前调用,因此不能依赖 Neo4j 作为别名的权威数据源。
|
||||
改为直接使用内存中去重后的 entity_nodes 的 aliases,与 PgSQL 已有的 aliases 合并。
|
||||
|
||||
策略:
|
||||
1. 从内存中的 entity_nodes 提取本轮用户别名(current_aliases)
|
||||
2. 从去重后的 entity_nodes 中提取完整别名(含 Neo4j 二层去重合并的历史别名)
|
||||
3. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases)
|
||||
4. 合并 db_aliases + deduped_aliases + current_aliases,去重保序
|
||||
5. 写回 PgSQL
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表
|
||||
entity_nodes: 去重后的实体节点列表(内存中,含二层去重合并结果)
|
||||
dialog_data_list: 对话数据列表
|
||||
"""
|
||||
try:
|
||||
@@ -1361,23 +1371,40 @@ class ExtractionOrchestrator:
|
||||
logger.warning("end_user_id 为空,跳过用户别名同步")
|
||||
return
|
||||
|
||||
# 1. 提取本次对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
||||
current_aliases = self._extract_current_aliases(entity_nodes)
|
||||
# 1. 提取本轮对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
||||
current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list)
|
||||
|
||||
# 2. 从 Neo4j 获取完整 aliases(权威数据源)
|
||||
neo4j_aliases = await self._fetch_neo4j_user_aliases(end_user_id)
|
||||
# 1.5 从去重后的 entity_nodes 中提取完整别名
|
||||
# 二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 中,
|
||||
# 这里提取出来确保 PgSQL 与 Neo4j 的别名保持同步
|
||||
deduped_aliases = self._extract_deduped_entity_aliases(entity_nodes)
|
||||
|
||||
if not neo4j_aliases:
|
||||
# Neo4j 中没有别名,使用本次对话提取的别名
|
||||
neo4j_aliases = current_aliases
|
||||
if not neo4j_aliases:
|
||||
logger.debug(f"aliases 为空,跳过同步: end_user_id={end_user_id}")
|
||||
return
|
||||
# 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源
|
||||
# (防止 LLM 未提取出 AI 助手实体时,AI 别名泄漏到用户别名中)
|
||||
neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id)
|
||||
if neo4j_assistant_aliases:
|
||||
before_count = len(current_aliases)
|
||||
current_aliases = [
|
||||
a for a in current_aliases
|
||||
if a.strip().lower() not in neo4j_assistant_aliases
|
||||
]
|
||||
if len(current_aliases) < before_count:
|
||||
logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名")
|
||||
# 同样过滤 deduped_aliases
|
||||
deduped_aliases = [
|
||||
a for a in deduped_aliases
|
||||
if a.strip().lower() not in neo4j_assistant_aliases
|
||||
]
|
||||
|
||||
logger.info(f"本次对话提取的 aliases: {current_aliases}")
|
||||
logger.info(f"Neo4j 中的完整 aliases: {neo4j_aliases}")
|
||||
if not current_aliases and not deduped_aliases:
|
||||
logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}")
|
||||
return
|
||||
|
||||
# 3. 同步到数据库
|
||||
logger.info(f"本轮对话提取的 aliases: {current_aliases}")
|
||||
if deduped_aliases:
|
||||
logger.info(f"去重后实体的完整 aliases(含历史): {deduped_aliases}")
|
||||
|
||||
# 2. 同步到数据库
|
||||
end_user_uuid = uuid.UUID(end_user_id)
|
||||
with get_db_context() as db:
|
||||
# 更新 end_user 表
|
||||
@@ -1386,7 +1413,38 @@ class ExtractionOrchestrator:
|
||||
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
|
||||
return
|
||||
|
||||
new_name = self._resolve_other_name(end_user.other_name, current_aliases, neo4j_aliases)
|
||||
# 3. 从 PgSQL 读取已有 aliases 并与本轮合并
|
||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
db_aliases = (info.aliases if info and info.aliases else [])
|
||||
# 过滤掉占位名称
|
||||
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
|
||||
|
||||
# 合并:已有 + 去重后完整别名 + 本轮新增,去重保序
|
||||
merged_aliases = list(db_aliases)
|
||||
seen_lower = {a.strip().lower() for a in merged_aliases}
|
||||
# 先合并去重后实体的完整别名(含 Neo4j 历史别名)
|
||||
for alias in deduped_aliases:
|
||||
if alias.strip().lower() not in seen_lower:
|
||||
merged_aliases.append(alias)
|
||||
seen_lower.add(alias.strip().lower())
|
||||
# 再合并本轮新提取的别名
|
||||
for alias in current_aliases:
|
||||
if alias.strip().lower() not in seen_lower:
|
||||
merged_aliases.append(alias)
|
||||
seen_lower.add(alias.strip().lower())
|
||||
|
||||
# 最终过滤:从合并结果中排除 AI 助手别名(清理历史脏数据)
|
||||
if neo4j_assistant_aliases:
|
||||
merged_aliases = [
|
||||
a for a in merged_aliases
|
||||
if a.strip().lower() not in neo4j_assistant_aliases
|
||||
]
|
||||
|
||||
logger.info(f"PgSQL 已有 aliases: {db_aliases}")
|
||||
logger.info(f"合并后 aliases: {merged_aliases}")
|
||||
|
||||
# 更新 end_user 表 other_name
|
||||
new_name = self._resolve_other_name(end_user.other_name, current_aliases, merged_aliases)
|
||||
if new_name is not None:
|
||||
end_user.other_name = new_name
|
||||
logger.info(f"更新 end_user 表 other_name → {new_name}")
|
||||
@@ -1394,26 +1452,27 @@ class ExtractionOrchestrator:
|
||||
logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}")
|
||||
|
||||
# 更新或创建 end_user_info 记录
|
||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
if info:
|
||||
new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases)
|
||||
new_name_info = self._resolve_other_name(info.other_name, current_aliases, merged_aliases)
|
||||
if new_name_info is not None:
|
||||
info.other_name = new_name_info
|
||||
logger.info(f"更新 end_user_info 表 other_name → {new_name_info}")
|
||||
if info.aliases != neo4j_aliases:
|
||||
info.aliases = neo4j_aliases
|
||||
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
|
||||
if info.aliases != merged_aliases:
|
||||
info.aliases = merged_aliases
|
||||
logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}")
|
||||
else:
|
||||
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||
first_alias = current_aliases[0].strip() if current_aliases else (
|
||||
deduped_aliases[0].strip() if deduped_aliases else ""
|
||||
)
|
||||
# 确保 first_alias 不是占位名称
|
||||
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
|
||||
if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
|
||||
db.add(EndUserInfo(
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=first_alias,
|
||||
aliases=neo4j_aliases,
|
||||
aliases=merged_aliases,
|
||||
meta_data={}
|
||||
))
|
||||
logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={neo4j_aliases}")
|
||||
logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={merged_aliases}")
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -1423,49 +1482,81 @@ class ExtractionOrchestrator:
|
||||
|
||||
|
||||
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
||||
USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'}
|
||||
# 复用 deduped_and_disamb 模块级常量,避免重复维护
|
||||
USER_PLACEHOLDER_NAMES = _USER_PLACEHOLDER_NAMES
|
||||
|
||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
|
||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode], dialog_data_list=None) -> List[str]:
|
||||
"""从用户发言的原始实体中提取本轮新增别名(绕过去重污染)
|
||||
|
||||
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。
|
||||
第一个别名将被用作 other_name。
|
||||
策略:
|
||||
仅从 dialog_data_list 中找到 speaker="user" 的 statement,
|
||||
从这些 statement 的 triplet_extraction_info 中提取用户实体的 aliases。
|
||||
这样拿到的是 LLM 对用户原话的提取结果,不受去重合并的影响。
|
||||
|
||||
注意:不再使用去重后 entity_nodes 作为兜底,因为二层去重会将 Neo4j 历史别名
|
||||
合并进来,导致历史别名被误认为"本轮提取"。历史别名的同步由
|
||||
_extract_deduped_entity_aliases 负责。
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表
|
||||
entity_nodes: 去重后的实体节点列表(未使用,保留参数兼容性)
|
||||
dialog_data_list: 对话数据列表
|
||||
|
||||
Returns:
|
||||
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称)
|
||||
别名列表(保持原始顺序,已过滤)
|
||||
"""
|
||||
if not dialog_data_list:
|
||||
return []
|
||||
|
||||
all_user_aliases = []
|
||||
seen_lower = set()
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
speaker = getattr(chunk, 'speaker', None)
|
||||
for statement in chunk.statements:
|
||||
stmt_speaker = getattr(statement, 'speaker', None) or speaker
|
||||
if stmt_speaker != "user":
|
||||
continue
|
||||
triplet_info = getattr(statement, 'triplet_extraction_info', None)
|
||||
if not triplet_info:
|
||||
continue
|
||||
for entity in (triplet_info.entities or []):
|
||||
ent_name = getattr(entity, 'name', '').strip()
|
||||
if ent_name.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
for alias in (getattr(entity, 'aliases', []) or []):
|
||||
a = alias.strip()
|
||||
if a and a.lower() not in self.USER_PLACEHOLDER_NAMES and a.lower() not in seen_lower:
|
||||
all_user_aliases.append(a)
|
||||
seen_lower.add(a.lower())
|
||||
if all_user_aliases:
|
||||
logger.debug(f"从用户原始发言提取到别名: {all_user_aliases}")
|
||||
return all_user_aliases
|
||||
|
||||
def _extract_deduped_entity_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||
"""从去重后的用户实体中提取完整别名列表。
|
||||
|
||||
二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 的用户实体中,
|
||||
因此这里提取到的别名包含了历史积累的所有别名,可用于同步到 PgSQL。
|
||||
|
||||
Args:
|
||||
entity_nodes: 去重后的实体节点列表(含二层去重合并结果)
|
||||
|
||||
Returns:
|
||||
别名列表(已过滤占位名称,去重保序)
|
||||
"""
|
||||
for entity in entity_nodes:
|
||||
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
if getattr(entity, 'name', '').strip().lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
aliases = getattr(entity, 'aliases', []) or []
|
||||
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
|
||||
return filtered
|
||||
filtered = [
|
||||
a for a in aliases
|
||||
if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES
|
||||
]
|
||||
if filtered:
|
||||
return filtered
|
||||
return []
|
||||
|
||||
|
||||
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
|
||||
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
|
||||
cypher = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I']
|
||||
RETURN e.aliases AS aliases
|
||||
LIMIT 1
|
||||
"""
|
||||
result = await Neo4jConnector().execute_query(cypher, end_user_id=end_user_id)
|
||||
if not result:
|
||||
logger.debug(f"Neo4j 中未找到用户实体: end_user_id={end_user_id}")
|
||||
return []
|
||||
aliases = result[0].get('aliases') or []
|
||||
if not aliases:
|
||||
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
||||
return []
|
||||
# 过滤掉占位名称,防止历史脏数据传播
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
return filtered
|
||||
async def _fetch_neo4j_assistant_aliases(self, end_user_id: str) -> set:
|
||||
"""从 Neo4j 查询 AI 助手实体的所有别名(用于从用户别名中排除)"""
|
||||
return await fetch_neo4j_assistant_aliases(self.connector, end_user_id)
|
||||
|
||||
def _resolve_other_name(
|
||||
self,
|
||||
@@ -1484,16 +1575,16 @@ class ExtractionOrchestrator:
|
||||
注意:返回值不允许是占位名称("用户"、"我"、"User"、"I")
|
||||
"""
|
||||
# 当前值为空或为占位名称时,需要更新
|
||||
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
if not current or not current.strip() or current.strip().lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
candidate = current_aliases[0].strip() if current_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
if current not in neo4j_aliases:
|
||||
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ class TripletExtractor:
|
||||
predicate_instructions=PREDICATE_DEFINITIONS,
|
||||
language=self._get_language(),
|
||||
ontology_types=self.ontology_types,
|
||||
speaker=getattr(statement, 'speaker', None),
|
||||
)
|
||||
|
||||
# Create messages for LLM
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering
|
||||
|
||||
# Setup Jinja2 environment
|
||||
@@ -205,6 +205,7 @@ async def render_triplet_extraction_prompt(
|
||||
predicate_instructions: dict = None,
|
||||
language: str = "zh",
|
||||
ontology_types: "OntologyTypeList | None" = None,
|
||||
speaker: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
||||
@@ -216,6 +217,7 @@ async def render_triplet_extraction_prompt(
|
||||
predicate_instructions: Optional predicate instructions
|
||||
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
|
||||
ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification
|
||||
speaker: Speaker role ("user" or "assistant") for the current statement
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -223,7 +225,7 @@ async def render_triplet_extraction_prompt(
|
||||
template = prompt_env.get_template("extract_triplet.jinja2")
|
||||
|
||||
# 准备本体类型数据
|
||||
ontology_type_section = ""
|
||||
ontology_type_section = None
|
||||
ontology_type_names = []
|
||||
type_hierarchy_hints = []
|
||||
if ontology_types and ontology_types.types:
|
||||
@@ -240,6 +242,7 @@ async def render_triplet_extraction_prompt(
|
||||
ontology_types=ontology_type_section,
|
||||
ontology_type_names=ontology_type_names,
|
||||
type_hierarchy_hints=type_hierarchy_hints,
|
||||
speaker=speaker,
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('triplet extraction', rendered_prompt)
|
||||
|
||||
@@ -23,6 +23,16 @@ Extract entities and knowledge triplets from the given statement.
|
||||
===Inputs===
|
||||
**Chunk Content:** "{{ chunk_content }}"
|
||||
**Statement:** "{{ statement }}"
|
||||
{% if speaker %}
|
||||
**Speaker:** {{ speaker }}
|
||||
{% if speaker == "assistant" %}
|
||||
{% if language == "zh" %}
|
||||
⚠️ 当前陈述句来自 **AI助手的回复**。AI助手在回复中用来称呼用户的名字是**用户的别名**,不是 AI 助手的别名。但只能提取原文中逐字出现的名字,严禁推测或创造原文中不存在的别名变体。
|
||||
{% else %}
|
||||
⚠️ This statement is from the **AI assistant's reply**. Names the AI uses to address the user are **user's aliases**, NOT the AI assistant's aliases. But only extract names that appear VERBATIM in the text — never infer or fabricate alias variants.
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% if ontology_types %}
|
||||
===Ontology Type Guidance===
|
||||
@@ -87,7 +97,17 @@ Extract entities and knowledge triplets from the given statement.
|
||||
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name)
|
||||
* "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name)
|
||||
- 空值:如果没有别名,使用 `[]`
|
||||
- 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字
|
||||
- **🚨🚨🚨 严禁幻觉:只提取对话原文中逐字出现的别名,绝对不能推测、衍生或创造任何未在原文中出现的名字。例如,看到"陈思远"不能自行添加"思远大人""远哥""小远"等变体。如果原文没有这些字,就不能出现在 aliases 中。**
|
||||
- **🚨 归属区分:必须严格区分名称的归属对象。默认情况下,用户提到的名字归属用户实体。只有出现明确的第二人称命名表达(如"叫你""给你取名")时,才将名字归属 AI/助手实体。**
|
||||
- **🚨 说话人视角:当 speaker 为 assistant 时,AI 助手用来称呼用户的名字是用户的别名,必须归入用户实体的 aliases,绝对不能归入 AI 助手实体。但同样只能提取原文中逐字出现的称呼,不能推测。**
|
||||
* "我叫陈思远,我给AI取名为远仔" → 用户 aliases=["陈思远"],AI助手 aliases=["远仔"]
|
||||
* "我叫vv" → 用户 aliases=["vv"](没有给AI取名的表达,名字归用户)
|
||||
* [speaker=assistant] "好的,VV" → 用户 aliases=["VV"](AI 在称呼用户,原文中出现了"VV")
|
||||
* [speaker=assistant] "我叫陈仔" → AI助手 aliases=["陈仔"](AI 在自我介绍,这是 AI 的别名)
|
||||
* ❌ 错误:将"远仔"放入用户的 aliases("远仔"是给AI取的名字,不是用户的名字)
|
||||
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases
|
||||
* ❌ 错误:AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
|
||||
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
|
||||
{% else %}
|
||||
- Include: nicknames, full names, abbreviations, alternative names
|
||||
- Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST**
|
||||
@@ -96,7 +116,17 @@ Extract entities and knowledge triplets from the given statement.
|
||||
* "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name)
|
||||
* "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
|
||||
- Empty: If no aliases, use `[]`
|
||||
- Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names
|
||||
- **🚨🚨🚨 NO HALLUCINATION: Only extract aliases that appear VERBATIM in the original text. NEVER infer, derive, or fabricate names not present in the text. For example, seeing "John Smith" does NOT allow adding "Johnny", "Smithy", "Mr. Smith" unless those exact strings appear in the conversation.**
|
||||
- **🚨 Ownership distinction: By default, all names mentioned by the user belong to the user entity. Only assign a name to the AI/assistant entity when an explicit second-person naming expression (e.g., "I'll call you", "your name is") is present.**
|
||||
- **🚨 Speaker perspective: When speaker is "assistant", names the AI uses to address the user are the USER's aliases and MUST go into the user entity's aliases, NEVER into the AI assistant entity's aliases. But only extract names that appear verbatim in the text, never infer.**
|
||||
* "I'm Alex, I'll call you Buddy" → User aliases=["Alex"], AI assistant aliases=["Buddy"]
|
||||
* "I'm vv" → User aliases=["vv"] (no AI-naming expression, name belongs to user)
|
||||
* [speaker=assistant] "Sure thing, VV" → User aliases=["VV"] (AI addressing the user, "VV" appears in text)
|
||||
* [speaker=assistant] "I'm Jarvis" → AI assistant aliases=["Jarvis"] (AI self-introduction, this is AI's alias)
|
||||
* ❌ Wrong: putting "Buddy" in user's aliases ("Buddy" is a name for the AI, not the user)
|
||||
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases
|
||||
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
|
||||
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
|
||||
{% endif %}
|
||||
|
||||
|
||||
@@ -122,7 +152,60 @@ Extract entities and knowledge triplets from the given statement.
|
||||
|
||||
|
||||
|
||||
4. **ALIASES ORDER:**
|
||||
4. **AI/ASSISTANT ENTITY SPECIAL HANDLING:**
|
||||
{% if language == "zh" %}
|
||||
- **🚨 默认规则:如果对话中没有出现明确指向 AI/助手的命名表达,则所有名字都归属于用户实体。不要猜测或推断某个名字是给 AI 取的。**
|
||||
- 只有当用户**明确**对 AI/助手进行命名时,才创建 AI/助手实体并将对应名字放入其 aliases
|
||||
- AI/助手实体的 name 字段:使用 "AI助手"
|
||||
- 用户给 AI 取的名字:放入 AI/助手实体的 aliases
|
||||
- **🚨 禁止将用户给 AI 取的名字放入用户实体的 aliases 中**
|
||||
- **必须出现以下明确的命名表达才能判定为给 AI 取名:**「给你取名」「叫你」「称呼你为」「给AI取名」「你的名字是」「以后叫你」「你就叫」「你不叫X了」「你现在叫」等**第二人称(你)或明确指向 AI 的命名句式**
|
||||
- **🚨 "你不叫X了"/"你不叫X,你叫Y" 句式:X 和 Y 都是 AI 的名字(旧名和新名),绝对不是用户的名字。因为句子主语是"你"(AI)。**
|
||||
- **以下情况名字归属用户,不是给 AI 取名:**「我叫」「我的名字是」「叫我」「我是」「大家叫我」「我的英文名是」「我的昵称是」等**第一人称(我)的自我介绍句式**
|
||||
- **🚨 speaker=assistant 时的特殊规则:**
|
||||
* AI 用来称呼用户的名字 → 归入**用户**实体的 aliases(但必须是原文中逐字出现的称呼,不能推测)
|
||||
* AI 自称的名字(如"我叫陈仔""我是你的助手")→ 归入**AI助手**实体的 aliases
|
||||
* 判断依据:AI 说"你叫X"或用 X 称呼用户 → X 是用户别名;AI 说"我叫X"或"我是X" → X 是 AI 别名
|
||||
- 示例:
|
||||
* "我叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
|
||||
* "我的英文名叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
|
||||
* "我叫陈思远,我给AI取名为远仔" → 用户实体: name="用户", aliases=["陈思远"];AI实体: name="AI助手", aliases=["远仔"]
|
||||
* "叫你小助,我自己叫老王" → 用户实体: name="用户", aliases=["老王"];AI实体: name="AI助手", aliases=["小助"]
|
||||
* "你不叫远仔了,你现在叫陈仔" → AI实体: name="AI助手", aliases=["陈仔"]("远仔"是AI旧名,"陈仔"是AI新名,都归AI。不要把"远仔"或"陈仔"放入用户的aliases)
|
||||
* [speaker=assistant] "好的VV,今天想干点啥?" → 用户实体: name="用户", aliases=["VV"](AI 在称呼用户,原文中出现了"VV")
|
||||
* [speaker=assistant] "你叫陈思远,我叫陈仔" → 用户实体: name="用户", aliases=["陈思远"];AI实体: name="AI助手", aliases=["陈仔"]
|
||||
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases(没有任何给 AI 取名的表达)
|
||||
* ❌ 错误:AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
|
||||
* ❌ 错误:aliases=["陈思远", "远仔"]("远仔"是给AI取的名字,不是用户的名字)
|
||||
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
|
||||
{% else %}
|
||||
- **🚨 Default rule: If there is NO explicit AI/assistant naming expression in the conversation, ALL names belong to the user entity. Do NOT guess or infer that a name is for the AI.**
|
||||
- Only create an AI/assistant entity when the user **explicitly** names the AI/assistant
|
||||
- AI/assistant entity name field: use "AI Assistant"
|
||||
- Names the user gives to the AI: put in the AI/assistant entity's aliases
|
||||
- **🚨 NEVER put names given to the AI into the user entity's aliases**
|
||||
- **An AI-naming expression MUST be present to assign a name to the AI:** "I'll call you", "your name is", "I name you", "let me call you", "you'll be called", "you're not called X anymore", "your new name is", etc. — **second-person ("you") or explicit AI-directed naming patterns**
|
||||
- **🚨 "You're not called X anymore" / "You're not X, you're Y" pattern: BOTH X and Y are AI's names (old and new). They are NOT user's names. The subject is "you" (the AI).**
|
||||
- **These patterns mean the name belongs to the USER, NOT the AI:** "I'm", "my name is", "call me", "I am", "people call me", "my English name is", "my nickname is", etc. — **first-person ("I"/"me") self-introduction patterns**
|
||||
- **🚨 Special rules when speaker=assistant:**
|
||||
* Names the AI uses to address the user → belong to the **user** entity's aliases (but only extract names that appear verbatim in the text, never infer)
|
||||
* Names the AI uses for itself (e.g., "I'm Jarvis", "I am your assistant") → belong to the **AI assistant** entity's aliases
|
||||
* Rule: AI says "you are X" or calls user X → X is user's alias; AI says "I'm X" or "I am X" → X is AI's alias
|
||||
- Examples:
|
||||
* "I'm vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
|
||||
* "My English name is vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
|
||||
* "I'm Alex, I'll call you Buddy" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Buddy"]
|
||||
* "Call yourself Jarvis, my name is Tony" → User entity: name="User", aliases=["Tony"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
|
||||
* "You're not called Jarvis anymore, your new name is Friday" → AI entity: name="AI Assistant", aliases=["Friday"] (both "Jarvis" and "Friday" are AI names, NOT user names)
|
||||
* [speaker=assistant] "Sure thing, VV" → User entity: name="User", aliases=["VV"] (AI addressing the user, "VV" appears in text)
|
||||
* [speaker=assistant] "You're Alex, and I'm Jarvis" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
|
||||
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases (no AI-naming expression exists)
|
||||
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
|
||||
* ❌ Wrong: aliases=["Alex", "Buddy"] ("Buddy" is a name for the AI, not the user)
|
||||
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
|
||||
{% endif %}
|
||||
|
||||
5. **ALIASES ORDER:**
|
||||
{% if language == "zh" %}
|
||||
- 顺序优先级:按出现顺序,先出现的在前
|
||||
{% else %}
|
||||
@@ -202,8 +285,19 @@ Output:
|
||||
{"entity_idx": 0, "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", "example": "", "aliases": ["Camera Tripod"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 4 (User vs AI alias distinction - English output):** "I'm Alex, and I'll call you Buddy"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "User", "subject_id": 0, "predicate": "NAMED", "object_name": "AI Assistant", "object_id": 1, "value": "Buddy"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "User", "type": "Person", "description": "The user", "example": "", "aliases": ["Alex"], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI Assistant", "type": "Person", "description": "The user's AI assistant", "example": "", "aliases": ["Buddy"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
{% else %}
|
||||
**Example 1 (English input → Chinese output):** "I plan to travel to Paris next week and visit the Louvre."
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
@@ -258,6 +352,39 @@ Output:
|
||||
]
|
||||
}
|
||||
|
||||
**Example 6 (用户与AI别名区分 - Chinese):** "我称呼自己为陈思远,我给AI取名为远仔"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "远仔"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远"], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["远仔"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 7 (纯用户自我介绍,无AI命名 - Chinese):** "我叫vv"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["vv"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 8 (给AI改名 - Chinese):** "你不叫远仔了,你现在叫陈仔"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "陈仔"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["陈仔"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
{% endif %}
|
||||
===End of Examples===
|
||||
|
||||
@@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.core.models.volcano_chat import VolcanoChatOpenAI
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -25,6 +26,9 @@ class RedBearModelConfig(BaseModel):
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
is_omni: bool = False # 是否为 Omni 模型
|
||||
deep_thinking: bool = False # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算
|
||||
support_thinking: bool = False # 模型是否支持 enable_thinking 参数(capability 含 thinking)
|
||||
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||
@@ -44,7 +48,7 @@ class RedBearModelFactory:
|
||||
# 打印供应商信息用于调试
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}, deep_thinking: {config.deep_thinking}")
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
@@ -58,7 +62,7 @@ class RedBearModelFactory:
|
||||
write=60.0,
|
||||
pool=10.0,
|
||||
)
|
||||
return {
|
||||
params: Dict[str, Any] = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
@@ -66,6 +70,23 @@ class RedBearModelFactory:
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
if is_streaming:
|
||||
params["stream_usage"] = True
|
||||
# 只有支持 thinking 的模型才传 enable_thinking
|
||||
if config.support_thinking:
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking:
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
else:
|
||||
model_kwargs["enable_thinking"] = False
|
||||
params["model_kwargs"] = model_kwargs
|
||||
return params
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||
@@ -78,7 +99,7 @@ class RedBearModelFactory:
|
||||
write=60.0, # 写入超时:60秒
|
||||
pool=10.0, # 连接池超时:10秒
|
||||
)
|
||||
return {
|
||||
params: Dict[str, Any] = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
@@ -86,16 +107,49 @@ class RedBearModelFactory:
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||
if config.extra_params.get("streaming"):
|
||||
params["stream_usage"] = True
|
||||
# 深度思考模式
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
if is_streaming and not config.is_omni:
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
# 火山引擎深度思考仅流式调用支持,非流式时不传 thinking 参数
|
||||
thinking_config: Dict[str, Any] = {
|
||||
"type": "enabled" if config.deep_thinking else "disabled"
|
||||
}
|
||||
if config.deep_thinking and config.thinking_budget_tokens:
|
||||
thinking_config["budget_tokens"] = config.thinking_budget_tokens
|
||||
params["extra_body"] = {"thinking": thinking_config}
|
||||
else:
|
||||
# 始终显式传递 enable_thinking,不支持该参数的模型(如 DeepSeek-R1)会直接忽略
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking and config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
params["model_kwargs"] = model_kwargs
|
||||
return params
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
# DashScope (通义千问) 使用自己的参数格式
|
||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||
# 只支持: model, dashscope_api_key, max_retries, client
|
||||
return {
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 只有支持 thinking 的模型才传 enable_thinking
|
||||
if config.support_thinking:
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking:
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
else:
|
||||
model_kwargs["enable_thinking"] = False
|
||||
params["model_kwargs"] = model_kwargs
|
||||
return params
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
# Bedrock 使用 AWS 凭证
|
||||
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
|
||||
@@ -134,6 +188,13 @@ class RedBearModelFactory:
|
||||
elif "region_name" not in params:
|
||||
params["region_name"] = "us-east-1" # 默认区域
|
||||
|
||||
# 深度思考模式:Claude 3.7 Sonnet 等支持思考的模型
|
||||
# 通过 additional_model_request_fields 传递 thinking 块,关闭时不传(Bedrock 无 disabled 选项)
|
||||
if config.deep_thinking:
|
||||
budget = config.thinking_budget_tokens or 10000
|
||||
params["additional_model_request_fields"] = {
|
||||
"thinking": {"type": "enabled", "budget_tokens": budget}
|
||||
}
|
||||
return params
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
@@ -160,7 +221,9 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
return ChatOpenAI
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
return VolcanoChatOpenAI
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
if type == ModelType.LLM:
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
|
||||
@@ -22,11 +22,40 @@ class RedBearEmbeddings(Embeddings):
|
||||
self._model = self._create_model(config)
|
||||
self._client = None
|
||||
|
||||
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
|
||||
@staticmethod
|
||||
def _create_model(config: RedBearModelConfig) -> Embeddings:
|
||||
"""根据配置创建 LangChain 模型"""
|
||||
embedding_class = get_provider_embedding_class(config.provider)
|
||||
model_params = RedBearModelFactory.get_model_params(config)
|
||||
return embedding_class(**model_params)
|
||||
provider = config.provider.lower()
|
||||
# Embedding models only need connection params, never LLM-specific ones
|
||||
# (e.g. enable_thinking, model_kwargs) — build params directly.
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
import httpx
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
"timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
|
||||
"max_retries": config.max_retries,
|
||||
"check_embedding_ctx_length": False,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
"max_retries": config.max_retries,
|
||||
}
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
}
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
params = RedBearModelFactory.get_model_params(config)
|
||||
else:
|
||||
params = RedBearModelFactory.get_model_params(config)
|
||||
return embedding_class(**params)
|
||||
|
||||
def _create_volcano_client(self, config: RedBearModelConfig):
|
||||
"""创建火山引擎客户端"""
|
||||
|
||||
@@ -11,6 +11,7 @@ models:
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: bedrock
|
||||
|
||||
- name: amazon nova
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -27,6 +28,7 @@ models:
|
||||
- stream-tool-call
|
||||
- vision
|
||||
logo: bedrock
|
||||
|
||||
- name: anthropic claude
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -35,6 +37,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -44,6 +47,7 @@ models:
|
||||
- stream-tool-call
|
||||
- document
|
||||
logo: bedrock
|
||||
|
||||
- name: cohere
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -58,6 +62,7 @@ models:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: deepseek
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -66,6 +71,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -74,6 +80,7 @@ models:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: meta
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -87,6 +94,7 @@ models:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: mistral
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -100,6 +108,7 @@ models:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: openai
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -114,6 +123,7 @@ models:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: qwen
|
||||
type: llm
|
||||
provider: bedrock
|
||||
@@ -128,6 +138,7 @@ models:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
|
||||
- name: amazon.rerank-v1:0
|
||||
type: rerank
|
||||
provider: bedrock
|
||||
@@ -139,6 +150,7 @@ models:
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: bedrock
|
||||
|
||||
- name: cohere.rerank-v3-5:0
|
||||
type: rerank
|
||||
provider: bedrock
|
||||
@@ -150,6 +162,7 @@ models:
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: bedrock
|
||||
|
||||
- name: amazon.nova-2-multimodal-embeddings-v1:0
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
@@ -163,6 +176,7 @@ models:
|
||||
- 文本嵌入模型
|
||||
- vision
|
||||
logo: bedrock
|
||||
|
||||
- name: amazon.titan-embed-text-v1
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
@@ -174,6 +188,7 @@ models:
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
|
||||
- name: amazon.titan-embed-text-v2:0
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
@@ -185,6 +200,7 @@ models:
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
|
||||
- name: cohere.embed-english-v3
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
@@ -196,6 +212,7 @@ models:
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
|
||||
- name: cohere.embed-multilingual-v3
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
|
||||
@@ -6,36 +6,42 @@ models:
|
||||
description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-r1-distill-qwen-32b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-r1
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-v3.1
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -48,6 +54,7 @@ models:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-v3.2-exp
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -60,6 +67,7 @@ models:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-v3.2
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -72,6 +80,7 @@ models:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: deepseek-v3
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -84,6 +93,7 @@ models:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: farui-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -98,6 +108,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: glm-4.7
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -112,6 +123,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qvq-max-latest
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -119,7 +131,8 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -127,6 +140,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qvq-max
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -134,7 +148,8 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -142,6 +157,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-coder-turbo-0919
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -155,13 +171,15 @@ models:
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-max-latest
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -169,6 +187,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-max-longcontext
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -183,13 +202,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-max
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -197,6 +218,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-mt-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -210,6 +232,7 @@ models:
|
||||
- 翻译模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-mt-turbo
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -223,6 +246,7 @@ models:
|
||||
- 翻译模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-0112
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -237,6 +261,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-0125
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -251,6 +276,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-0723
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -265,6 +291,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-0806
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -279,6 +306,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-0919
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -293,6 +321,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-1125
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -307,6 +336,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-1127
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -321,6 +351,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-plus-1220
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -335,6 +366,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-max
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -342,8 +374,8 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -352,6 +384,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-plus-0809
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -359,8 +392,8 @@ models:
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -369,6 +402,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-plus-2025-01-02
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -376,8 +410,8 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -386,6 +420,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-plus-2025-01-25
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -393,8 +428,8 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -403,6 +438,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-plus-latest
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -410,8 +446,8 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -420,6 +456,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen-vl-plus
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -427,8 +464,8 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -437,6 +474,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen2.5-0.5b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -451,13 +489,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-14b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -465,13 +505,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-235b-a22b-instruct-2507
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -479,13 +521,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-235b-a22b-thinking-2507
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -493,13 +537,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-235b-a22b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -507,13 +553,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-30b-a3b-instruct-2507
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -521,13 +569,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-30b-a3b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -535,13 +585,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-32b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -549,13 +601,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-4b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -563,13 +617,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-8b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -577,65 +633,75 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-coder-30b-a3b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-coder-480b-a35b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-coder-plus-2025-09-23
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-coder-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-max-2025-09-23
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -644,13 +710,15 @@ models:
|
||||
- stream-tool-call
|
||||
- 联网搜索
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-max-2026-01-23
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -659,13 +727,15 @@ models:
|
||||
- stream-tool-call
|
||||
- 联网搜索
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-max-preview
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -673,13 +743,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-max
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -688,13 +760,15 @@ models:
|
||||
- stream-tool-call
|
||||
- 联网搜索
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-next-80b-a3b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -702,13 +776,15 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-next-80b-a3b-thinking
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -716,6 +792,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-omni-flash-2025-12-01
|
||||
type: llm
|
||||
provider: dashscope
|
||||
@@ -723,9 +800,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- audio
|
||||
- vision
|
||||
- video
|
||||
- audio
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -735,6 +812,7 @@ models:
|
||||
- video
|
||||
- audio
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-235b-a22b-instruct
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -742,8 +820,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -754,6 +833,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-235b-a22b-thinking
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -761,8 +841,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -773,6 +854,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-30b-a3b-instruct
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -780,8 +862,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -792,6 +875,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-30b-a3b-thinking
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -799,8 +883,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -811,6 +896,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-flash
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -818,8 +904,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -830,6 +917,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-plus-2025-09-23
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -837,8 +925,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -847,6 +936,7 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwen3-vl-plus
|
||||
type: chat
|
||||
provider: dashscope
|
||||
@@ -854,8 +944,9 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -864,45 +955,52 @@ models:
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
|
||||
- name: qwq-32b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwq-plus-0305
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: qwq-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
|
||||
- name: gte-rerank-v2
|
||||
type: rerank
|
||||
provider: dashscope
|
||||
@@ -914,6 +1012,7 @@ models:
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: dashscope
|
||||
|
||||
- name: gte-rerank
|
||||
type: rerank
|
||||
provider: dashscope
|
||||
@@ -925,6 +1024,7 @@ models:
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: dashscope
|
||||
|
||||
- name: multimodal-embedding-v1
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
@@ -932,13 +1032,14 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
logo: dashscope
|
||||
|
||||
- name: text-embedding-v1
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
@@ -951,6 +1052,7 @@ models:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
|
||||
- name: text-embedding-v2
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
@@ -963,6 +1065,7 @@ models:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
|
||||
- name: text-embedding-v3
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
@@ -975,6 +1078,7 @@ models:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
|
||||
- name: text-embedding-v4
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
@@ -986,4 +1090,4 @@ models:
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
logo: dashscope
|
||||
|
||||
@@ -20,6 +20,7 @@ models:
|
||||
- audio
|
||||
- video
|
||||
logo: openai
|
||||
|
||||
- name: gpt-3.5-turbo-0125
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -34,6 +35,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-3.5-turbo-1106
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -48,6 +50,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-3.5-turbo-16k
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -62,6 +65,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-3.5-turbo-instruct
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -73,6 +77,7 @@ models:
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: openai
|
||||
|
||||
- name: gpt-3.5-turbo
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -87,6 +92,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-4-0125-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -101,6 +107,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-4-1106-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -115,6 +122,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-4-turbo-2024-04-09
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -131,6 +139,7 @@ models:
|
||||
- stream-tool-call
|
||||
- vision
|
||||
logo: openai
|
||||
|
||||
- name: gpt-4-turbo-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -145,6 +154,7 @@ models:
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
|
||||
- name: gpt-4-turbo
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -161,6 +171,7 @@ models:
|
||||
- stream-tool-call
|
||||
- vision
|
||||
logo: openai
|
||||
|
||||
- name: o1-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -173,6 +184,7 @@ models:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: openai
|
||||
|
||||
- name: o1
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -181,6 +193,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -190,6 +203,7 @@ models:
|
||||
- vision
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3-2025-04-16
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -198,6 +212,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -207,13 +222,15 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3-mini-2025-01-31
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -222,13 +239,15 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3-mini
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -237,6 +256,7 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3-pro-2025-06-10
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -245,6 +265,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -253,6 +274,7 @@ models:
|
||||
- vision
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3-pro
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -261,6 +283,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -269,6 +292,7 @@ models:
|
||||
- vision
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o3
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -277,6 +301,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -286,6 +311,7 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o4-mini-2025-04-16
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -294,6 +320,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -303,6 +330,7 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: o4-mini
|
||||
type: llm
|
||||
provider: openai
|
||||
@@ -311,6 +339,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -320,6 +349,7 @@ models:
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
|
||||
- name: text-embedding-3-large
|
||||
type: embedding
|
||||
provider: openai
|
||||
@@ -331,6 +361,7 @@ models:
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
|
||||
- name: text-embedding-3-small
|
||||
type: embedding
|
||||
provider: openai
|
||||
@@ -342,6 +373,7 @@ models:
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
|
||||
- name: text-embedding-ada-002
|
||||
type: embedding
|
||||
provider: openai
|
||||
|
||||
@@ -10,6 +10,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -24,6 +25,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -38,6 +40,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -52,6 +55,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -82,6 +86,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -96,6 +101,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -110,6 +116,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -124,6 +131,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -139,6 +147,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
|
||||
52
api/app/core/models/volcano_chat.py
Normal file
52
api/app/core/models/volcano_chat.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
火山引擎 ChatOpenAI 扩展
|
||||
|
||||
ChatOpenAI 在解析流式 SSE 时只取 delta.content,会丢弃 delta.reasoning_content。
|
||||
此类仅重写 _convert_chunk_to_generation_chunk,将 reasoning_content 补入 additional_kwargs。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class VolcanoChatOpenAI(ChatOpenAI):
|
||||
"""火山引擎 Chat 模型,支持深度思考内容(reasoning_content)的流式和非流式透传。"""
|
||||
|
||||
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
|
||||
result = super()._create_chat_result(response, generation_info)
|
||||
# 将非流式响应中的 reasoning_content 补入 additional_kwargs
|
||||
choices = response.choices if hasattr(response, "choices") else response.get("choices", [])
|
||||
if choices:
|
||||
message = choices[0].message if hasattr(choices[0], "message") else choices[0].get("message", {})
|
||||
reasoning = (
|
||||
getattr(message, "reasoning_content", None)
|
||||
or (message.get("reasoning_content") if isinstance(message, dict) else None)
|
||||
)
|
||||
if reasoning and result.generations:
|
||||
result.generations[0].message.additional_kwargs["reasoning_content"] = reasoning
|
||||
return result
|
||||
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
self,
|
||||
chunk: dict,
|
||||
default_chunk_class: type,
|
||||
base_generation_info: Optional[dict],
|
||||
) -> Optional[ChatGenerationChunk]:
|
||||
gen_chunk = super()._convert_chunk_to_generation_chunk(
|
||||
chunk, default_chunk_class, base_generation_info
|
||||
)
|
||||
if gen_chunk is None:
|
||||
return None
|
||||
|
||||
# 从原始 chunk 中提取 reasoning_content
|
||||
choices = chunk.get("choices") or chunk.get("chunk", {}).get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("delta") or {}
|
||||
reasoning: Any = delta.get("reasoning_content")
|
||||
if reasoning:
|
||||
gen_chunk.message.additional_kwargs["reasoning_content"] = reasoning
|
||||
|
||||
return gen_chunk
|
||||
@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
|
||||
type=ParameterType.STRING,
|
||||
description="操作类型",
|
||||
required=True,
|
||||
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
|
||||
enum=["format", "convert_timezone", "timestamp_to_datetime", "now", "datetime_to_timestamp"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
|
||||
@@ -99,7 +99,7 @@ class SimpleMCPClient:
|
||||
# 建立 SSE 连接
|
||||
response = await self._session.get(self.server_url)
|
||||
|
||||
if response.status not in (200, 202):
|
||||
if not (200 <= response.status < 300):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||
|
||||
@@ -190,9 +190,7 @@ class SimpleMCPClient:
|
||||
|
||||
try:
|
||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||
# MCP SSE 协议:POST 请求返回 200 或 202 均为正常
|
||||
# 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回
|
||||
if response.status not in (200, 202):
|
||||
if not (200 <= response.status < 300):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||
|
||||
@@ -207,7 +205,7 @@ class SimpleMCPClient:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||
if response.status not in (200, 202):
|
||||
if not (200 <= response.status < 300):
|
||||
logger.warning(f"通知发送失败: {response.status}")
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
@@ -225,7 +223,7 @@ class SimpleMCPClient:
|
||||
|
||||
try:
|
||||
async with self._session.post(self.server_url, json=init_request) as response:
|
||||
if response.status != 200:
|
||||
if not (200 <= response.status < 300):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||
|
||||
|
||||
@@ -32,13 +32,16 @@ from app.core.workflow.nodes.configs import (
|
||||
NoteNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
VariableAggregatorNodeConfig
|
||||
VariableAggregatorNodeConfig,
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.cycle_graph.config import (
|
||||
ConditionDetail as LoopConditionDetail,
|
||||
ConditionsConfig,
|
||||
CycleVariable
|
||||
)
|
||||
from app.core.workflow.nodes.list_operator.config import FilterCondition
|
||||
from app.core.workflow.nodes.enums import (
|
||||
ValueInputType,
|
||||
ComparisonOperator,
|
||||
@@ -90,6 +93,8 @@ class DifyConverter(BaseConverter):
|
||||
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
|
||||
NodeType.TOOL: self.convert_tool_node_config,
|
||||
NodeType.NOTES: self.convert_notes_config,
|
||||
NodeType.LIST_OPERATOR: self.convert_list_operator_node_config,
|
||||
NodeType.DOCUMENT_EXTRACTOR: self.convert_document_extractor_node_config,
|
||||
NodeType.CYCLE_START: lambda x: {},
|
||||
NodeType.BREAK: lambda x: {},
|
||||
}
|
||||
@@ -126,7 +131,7 @@ class DifyConverter(BaseConverter):
|
||||
selector = var_selector.split('.')
|
||||
if len(selector) not in [2, 3] and var_selector != "context":
|
||||
raise Exception(f"invalid variable selector: {var_selector}")
|
||||
if len(selector) == 3:
|
||||
if len(selector) == 3 and selector[0] in ("conversation", "sys"):
|
||||
selector = selector[1:]
|
||||
if selector[0] == "conversation":
|
||||
selector[0] = "conv"
|
||||
@@ -213,7 +218,9 @@ class DifyConverter(BaseConverter):
|
||||
"end with": ComparisonOperator.END_WITH,
|
||||
"not contains": ComparisonOperator.NOT_CONTAINS,
|
||||
"exists": ComparisonOperator.NOT_EMPTY,
|
||||
"not exists": ComparisonOperator.EMPTY
|
||||
"not exists": ComparisonOperator.EMPTY,
|
||||
"in": ComparisonOperator.IN,
|
||||
"not in": ComparisonOperator.NOT_IN,
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@@ -476,11 +483,11 @@ class DifyConverter(BaseConverter):
|
||||
node_data = node["data"]
|
||||
result = IterationNodeConfig.model_construct(
|
||||
input=self._process_list_variable_literal(node_data["iterator_selector"]),
|
||||
parallel=node_data["is_parallel"],
|
||||
parallel_count=node_data["parallel_nums"],
|
||||
parallel=node_data.get("is_parallel", False),
|
||||
parallel_count=node_data.get("parallel_nums", 4),
|
||||
output=self._process_list_variable_literal(node_data["output_selector"]),
|
||||
output_type=self.variable_type_map(node_data.get("output_type")),
|
||||
flatten=node_data["flatten_output"],
|
||||
flatten=node_data.get("flatten_output", False),
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
|
||||
@@ -489,7 +496,23 @@ class DifyConverter(BaseConverter):
|
||||
def convert_assigner_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
assignments = []
|
||||
for assignment in node_data["items"]:
|
||||
|
||||
# Support both formats:
|
||||
# 1. New format: node_data["items"] list
|
||||
# 2. Flat format: assigned_variable_selector + input_variable_selector + write_mode
|
||||
if "items" in node_data:
|
||||
raw_items = node_data["items"]
|
||||
elif "assigned_variable_selector" in node_data and "input_variable_selector" in node_data:
|
||||
raw_items = [{
|
||||
"variable_selector": node_data["assigned_variable_selector"],
|
||||
"value": node_data["input_variable_selector"],
|
||||
"input_type": ValueInputType.VARIABLE,
|
||||
"operation": node_data.get("write_mode", "over-write"),
|
||||
}]
|
||||
else:
|
||||
raw_items = []
|
||||
|
||||
for assignment in raw_items:
|
||||
if assignment.get("operation") is None or assignment.get("value") is None:
|
||||
continue
|
||||
assignments.append(
|
||||
@@ -771,3 +794,46 @@ class DifyConverter(BaseConverter):
|
||||
show_author=node_data.get("showAuthor", True)
|
||||
).model_dump()
|
||||
return result
|
||||
|
||||
def convert_list_operator_node_config(self, node: dict) -> dict:
|
||||
"""Dify list-operator — convert variable path array to {{ }} selector format."""
|
||||
node_data = node["data"]
|
||||
variable_path = node_data.get("variable", [])
|
||||
input_list = self._process_list_variable_literal(variable_path) or ""
|
||||
filter_by = node_data.get("filter_by", {"enabled": False, "conditions": []})
|
||||
# Convert each condition's comparison_operator from Dify format to native
|
||||
if filter_by.get("conditions"):
|
||||
converted_conditions = []
|
||||
for cond in filter_by["conditions"]:
|
||||
converted_conditions.append({
|
||||
**cond,
|
||||
"comparison_operator": self.convert_compare_operator(
|
||||
cond.get("comparison_operator", "")
|
||||
)
|
||||
})
|
||||
filter_by = {**filter_by, "conditions": converted_conditions}
|
||||
result = {
|
||||
"input_list": input_list,
|
||||
"filter_by": filter_by,
|
||||
"order_by": node_data.get("order_by", {"enabled": False, "key": "", "value": "asc"}),
|
||||
"limit": node_data.get("limit", {"enabled": False, "size": -1}),
|
||||
"extract_by": node_data.get("extract_by", {"enabled": False, "serial": "1"}),
|
||||
}
|
||||
self.config_validate(node["id"], node["data"]["title"], ListOperatorNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_document_extractor_node_config(self, node: dict) -> dict:
|
||||
"""Convert Dify document-extractor node to MemoryBear DocExtractorNodeConfig.
|
||||
|
||||
Dify document-extractor data fields:
|
||||
variable_selector: list[str] - file variable path
|
||||
"""
|
||||
node_data = node["data"]
|
||||
file_selector = self._process_list_variable_literal(
|
||||
node_data.get("variable_selector", [])
|
||||
) or ""
|
||||
result = DocExtractorNodeConfig.model_construct(
|
||||
file_selector=file_selector,
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], DocExtractorNodeConfig, result)
|
||||
return result
|
||||
|
||||
@@ -45,6 +45,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"variable-aggregator": NodeType.VAR_AGGREGATOR,
|
||||
"tool": NodeType.TOOL,
|
||||
"list-operator": NodeType.LIST_OPERATOR,
|
||||
"document-extractor": NodeType.DOCUMENT_EXTRACTOR,
|
||||
"": NodeType.NOTES
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ from app.core.workflow.nodes.configs import (
|
||||
MemoryReadNodeConfig,
|
||||
MemoryWriteNodeConfig,
|
||||
NoteNodeConfig,
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
@@ -51,6 +53,8 @@ class MemoryBearConverter(BaseConverter):
|
||||
NodeType.MEMORY_READ: MemoryReadNodeConfig,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
|
||||
NodeType.NOTES: NoteNodeConfig,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNodeConfig,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNodeConfig,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -59,6 +59,9 @@ class WorkflowResultBuilder:
|
||||
conversation_vars = variable_pool.get_all_conversation_vars()
|
||||
sys_vars = variable_pool.get_all_system_vars()
|
||||
|
||||
# 汇总所有 knowledge 节点的 citations
|
||||
citations = self.aggregate_citations(node_outputs)
|
||||
|
||||
return {
|
||||
"status": "completed" if success else "failed",
|
||||
"output": final_output,
|
||||
@@ -71,9 +74,25 @@ class WorkflowResultBuilder:
|
||||
"conversation_id": execution_context.conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"citations": citations,
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def aggregate_citations(node_outputs: dict) -> list:
|
||||
"""从所有 knowledge 节点的输出中汇总 citations,去重"""
|
||||
seen = set()
|
||||
citations = []
|
||||
for node_output in node_outputs.values():
|
||||
if not isinstance(node_output, dict):
|
||||
continue
|
||||
for c in node_output.get("citations", []):
|
||||
key = c.get("document_id")
|
||||
if key and key not in seen:
|
||||
seen.add(key)
|
||||
citations.append(c)
|
||||
return citations
|
||||
|
||||
@staticmethod
|
||||
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
|
||||
"""
|
||||
|
||||
@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
def merge_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
merged = dict(x)
|
||||
for k, v in y.items():
|
||||
merged[k] = merged.get(k, False) or v
|
||||
return merged
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
|
||||
@@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
|
||||
|
||||
|
||||
class LazyVariableDict:
|
||||
def __init__(self, source, literal):
|
||||
self._source: dict[str, VariableStruct[Any]] = source
|
||||
self._literal: bool = literal
|
||||
self._cache = {}
|
||||
|
||||
def keys(self):
|
||||
return self._source.keys()
|
||||
|
||||
def _resolve(self, key):
|
||||
if key in self._cache:
|
||||
return self._cache[key]
|
||||
var_struct = self._source.get(key)
|
||||
if var_struct is None:
|
||||
raise KeyError(key)
|
||||
value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
|
||||
self._cache[key] = value
|
||||
return value
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self._resolve(key)
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._resolve(key)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key.startswith('_'):
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
|
||||
return self._resolve(key)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._source
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._source)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._source)
|
||||
|
||||
|
||||
class VariableSelector:
|
||||
"""变量选择器
|
||||
@@ -117,8 +162,7 @@ class VariablePool:
|
||||
|
||||
@staticmethod
|
||||
def transform_selector(selector):
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
||||
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
|
||||
selector = VariableSelector.from_string(variable_literal).path
|
||||
if len(selector) != 2:
|
||||
raise ValueError(f"Selector not valid - {selector}")
|
||||
@@ -274,7 +318,7 @@ class VariablePool:
|
||||
namespace: str,
|
||||
key: str,
|
||||
value: Any,
|
||||
var_type: VariableType,
|
||||
var_type: VariableType | None,
|
||||
mut: bool
|
||||
):
|
||||
if self.has(f"{namespace}.{key}"):
|
||||
@@ -303,6 +347,16 @@ class VariablePool:
|
||||
"""
|
||||
return self._get_variable_struct(selector) is not None
|
||||
|
||||
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
|
||||
return LazyVariableDict(self.variables.get(namespace, {}), literal)
|
||||
|
||||
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
|
||||
return {
|
||||
ns: LazyVariableDict(vars_dict, literal)
|
||||
for ns, vars_dict in self.variables.items()
|
||||
if ns not in ("sys", "conv")
|
||||
}
|
||||
|
||||
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
|
||||
@@ -439,6 +493,23 @@ class VariablePoolInitializer:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
# Convert FileInput-format dicts to full FileObject dicts
|
||||
if var_type == VariableType.FILE:
|
||||
if not var_value:
|
||||
continue
|
||||
var_value = await self._resolve_file_default(var_value)
|
||||
if not var_value:
|
||||
continue
|
||||
elif var_type == VariableType.ARRAY_FILE:
|
||||
if not var_value:
|
||||
var_value = []
|
||||
else:
|
||||
resolved = []
|
||||
for item in var_value:
|
||||
f = await self._resolve_file_default(item)
|
||||
if f:
|
||||
resolved.append(f)
|
||||
var_value = resolved
|
||||
await variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
@@ -447,6 +518,17 @@ class VariablePoolInitializer:
|
||||
mut=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_file_default(file_def: dict) -> dict | None:
|
||||
"""Accept only already-resolved FileObject dicts (is_file=True).
|
||||
FileInput-format dicts are converted at save time by WorkflowService._resolve_variables_file_defaults.
|
||||
"""
|
||||
if not isinstance(file_def, dict):
|
||||
return None
|
||||
if file_def.get("is_file"):
|
||||
return file_def
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _init_system_vars(
|
||||
variable_pool: VariablePool,
|
||||
@@ -479,5 +561,3 @@ class VariablePoolInitializer:
|
||||
var_type=var_type,
|
||||
mut=False
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -395,7 +395,8 @@ class BaseNode(ABC):
|
||||
"output": output,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": None
|
||||
"error": None,
|
||||
**self._extract_extra_fields(business_result),
|
||||
}
|
||||
final_output = {
|
||||
"node_outputs": {self.node_id: node_output},
|
||||
@@ -498,6 +499,13 @@ class BaseNode(ABC):
|
||||
# Default implementation returns the business result directly
|
||||
return business_result
|
||||
|
||||
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||
"""Extracts extra fields to merge into node_output (e.g. citations).
|
||||
|
||||
Subclasses may override to inject additional metadata.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
"""Extracts token usage information from the business result.
|
||||
|
||||
@@ -552,9 +560,9 @@ class BaseNode(ABC):
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
|
||||
node_outputs=variable_pool.get_all_node_outputs(literal=True),
|
||||
system_vars=variable_pool.get_all_system_vars(literal=True),
|
||||
conv_vars=variable_pool.lazy_namespace("conv", literal=True),
|
||||
node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
|
||||
system_vars=variable_pool.lazy_namespace("sys", literal=True),
|
||||
strict=strict
|
||||
)
|
||||
|
||||
@@ -579,9 +587,9 @@ class BaseNode(ABC):
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
conv_var=variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=variable_pool.get_all_node_outputs(),
|
||||
system_vars=variable_pool.get_all_system_vars()
|
||||
conv_var=variable_pool.lazy_namespace("conv"),
|
||||
node_outputs=variable_pool.lazy_all_node_outputs(),
|
||||
system_vars=variable_pool.lazy_namespace("sys")
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -13,7 +13,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import BaseNode
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,7 +70,8 @@ class CodeNode(BaseNode):
|
||||
for output in self.typed_config.output_variables:
|
||||
value = exec_result.get(output.name)
|
||||
if value is None:
|
||||
raise RuntimeError(f"Return value {output.name} does not exist")
|
||||
result[output.name] = DEFAULT_VALUE(output.type)
|
||||
continue
|
||||
match output.type:
|
||||
case VariableType.STRING:
|
||||
if not isinstance(value, str):
|
||||
|
||||
@@ -24,6 +24,8 @@ from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
||||
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig
|
||||
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -49,5 +51,7 @@ __all__ = [
|
||||
"MemoryReadNodeConfig",
|
||||
"MemoryWriteNodeConfig",
|
||||
"CodeNodeConfig",
|
||||
"NoteNodeConfig"
|
||||
"NoteNodeConfig",
|
||||
"ListOperatorNodeConfig",
|
||||
"DocExtractorNodeConfig",
|
||||
]
|
||||
|
||||
@@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
|
||||
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_expression
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -85,12 +84,7 @@ class LoopRuntime:
|
||||
|
||||
for variable in self.typed_config.cycle_vars:
|
||||
if variable.input_type == ValueInputType.VARIABLE:
|
||||
value = evaluate_expression(
|
||||
expression=variable.value,
|
||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
||||
system_vars=self.variable_pool.get_all_system_vars(),
|
||||
)
|
||||
value = self.variable_pool.get_value(variable.value)
|
||||
else:
|
||||
value = TypeTransformer.transform(variable.value, variable.type)
|
||||
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
|
||||
@@ -98,12 +92,7 @@ class LoopRuntime:
|
||||
**self.state
|
||||
)
|
||||
loopstate["node_outputs"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
||||
system_vars=self.variable_pool.get_all_system_vars(),
|
||||
)
|
||||
variable.name: self.variable_pool.get_value(variable.value)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
|
||||
@@ -14,12 +14,22 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def _file_object_to_file_input(f: FileObject) -> FileInput:
|
||||
"""Convert workflow FileObject to multimodal FileInput."""
|
||||
file_type = f.origin_file_type or ""
|
||||
# Prefer mime_type for more accurate type detection
|
||||
if not file_type and f.mime_type:
|
||||
file_type = f.mime_type
|
||||
resolved_type = FileType.trans(f.type) if isinstance(f.type, str) else f.type
|
||||
if resolved_type != FileType.DOCUMENT:
|
||||
raise ValueError(
|
||||
f"Document extractor only supports document files, got type '{f.type}' "
|
||||
f"(name={f.name or f.file_id or f.url})"
|
||||
)
|
||||
return FileInput(
|
||||
type=FileType.DOCUMENT,
|
||||
type=resolved_type,
|
||||
transfer_method=TransferMethod(f.transfer_method),
|
||||
url=f.url or None,
|
||||
upload_file_id=f.file_id or None,
|
||||
file_type=f.origin_file_type or "",
|
||||
file_type=file_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -81,6 +91,7 @@ class DocExtractorNode(BaseNode):
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
svc = MultimodalService(db)
|
||||
for f in files:
|
||||
label = f.name or f.url or f.file_id
|
||||
try:
|
||||
file_input = _file_object_to_file_input(f)
|
||||
# Ensure URL is populated for local files
|
||||
@@ -89,11 +100,11 @@ class DocExtractorNode(BaseNode):
|
||||
# Reuse cached bytes if already fetched
|
||||
if f.get_content():
|
||||
file_input.set_content(f.get_content())
|
||||
text = await svc._extract_document_text(file_input)
|
||||
text = await svc.extract_document_text(file_input)
|
||||
chunks.append(text)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
|
||||
f"Node {self.node_id}: failed to extract file '{label}': {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
chunks.append("")
|
||||
|
||||
@@ -24,6 +24,7 @@ class NodeType(StrEnum):
|
||||
MEMORY_READ = "memory-read"
|
||||
MEMORY_WRITE = "memory-write"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
NOTES = "notes"
|
||||
@@ -45,6 +46,8 @@ class ComparisonOperator(StrEnum):
|
||||
LE = "le"
|
||||
GT = "gt"
|
||||
GE = "ge"
|
||||
IN = "in"
|
||||
NOT_IN = "not_in"
|
||||
|
||||
|
||||
class LogicOperator(StrEnum):
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
from app.models import knowledge_model, ModelType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -24,13 +28,26 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
self.vector_service: ElasticSearchVector | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"output": VariableType.ARRAY_STRING
|
||||
}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
"""下游节点只拿 chunks 列表"""
|
||||
if isinstance(business_result, dict) and "chunks" in business_result:
|
||||
return business_result["chunks"]
|
||||
return business_result
|
||||
|
||||
def _extract_citations(self, business_result: Any) -> list:
|
||||
if isinstance(business_result, dict):
|
||||
return business_result.get("citations", [])
|
||||
return []
|
||||
|
||||
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||
return {"citations": self._extract_citations(business_result)}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
return {
|
||||
"query": self._render_template(self.typed_config.query, variable_pool),
|
||||
@@ -85,46 +102,54 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
unique.append(doc)
|
||||
return unique
|
||||
|
||||
def _get_existing_kb_ids(self, db, kb_ids):
|
||||
def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||
"""
|
||||
Resolve all accessible and valid knowledge base IDs for retrieval.
|
||||
|
||||
This includes:
|
||||
- Private knowledge bases owned by the user
|
||||
- Shared knowledge bases
|
||||
- Source knowledge bases mapped via knowledge sharing relationships
|
||||
|
||||
Reorder the list of document blocks and return the top_k results most relevant to the query
|
||||
Args:
|
||||
db: Database session.
|
||||
kb_ids (list[UUID]): Knowledge base IDs from node configuration.
|
||||
query: query string
|
||||
docs: List of document chunk to be rearranged
|
||||
top_k: The number of top-level documents returned
|
||||
|
||||
Returns:
|
||||
list[UUID]: Final list of valid knowledge base IDs.
|
||||
Rearranged document chunk list (sorted in descending order of relevance)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input document list is empty or top_k is invalid
|
||||
"""
|
||||
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
|
||||
|
||||
existing_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
|
||||
|
||||
share_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
if share_ids:
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
|
||||
reranker = self.get_reranker_model()
|
||||
# parameter validation
|
||||
if not docs:
|
||||
raise ValueError("retrieval chunks be empty")
|
||||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
try:
|
||||
# Convert to LangChain Document object
|
||||
documents = [
|
||||
Document(
|
||||
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
|
||||
metadata=doc.metadata or {} # Deal with possible None metadata
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters
|
||||
|
||||
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
|
||||
reranked_docs = list(reranker.compress_documents(documents, query))
|
||||
|
||||
# Sort in descending order based on relevance score
|
||||
reranked_docs.sort(
|
||||
key=lambda x: x.metadata.get("relevance_score", 0),
|
||||
reverse=True
|
||||
)
|
||||
existing_ids.extend(items)
|
||||
return existing_ids
|
||||
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
|
||||
result = []
|
||||
for item in reranked_docs[:top_k]:
|
||||
for doc in docs:
|
||||
if doc.page_content == item.page_content:
|
||||
doc.metadata["score"] = item.metadata["relevance_score"]
|
||||
result.append(doc)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
|
||||
|
||||
def get_reranker_model(self) -> RedBearRerank:
|
||||
"""
|
||||
@@ -164,41 +189,77 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
return reranker
|
||||
|
||||
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
||||
async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
|
||||
rs = []
|
||||
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||
tasks = []
|
||||
for child in children:
|
||||
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||
continue
|
||||
kb_config.kb_id = child.id
|
||||
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
||||
return
|
||||
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
child_kb_config = kb_config.model_copy()
|
||||
child_kb_config.kb_id = child.id
|
||||
tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
|
||||
if tasks:
|
||||
result = await asyncio.gather(*tasks)
|
||||
for _ in result:
|
||||
rs.extend(_)
|
||||
return rs
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
)
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
)
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight)
|
||||
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
rs1_task = asyncio.to_thread(
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
rs2_task = asyncio.to_thread(
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
|
||||
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
return
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
self.rerank,
|
||||
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
|
||||
)
|
||||
)
|
||||
else:
|
||||
rs.extend(sorted(
|
||||
unique_rs,
|
||||
@@ -207,6 +268,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)[:kb_config.top_k])
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
return rs
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
@@ -238,17 +300,24 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
|
||||
rs = []
|
||||
tasks = []
|
||||
for kb_config in knowledge_bases:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
||||
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
|
||||
if tasks:
|
||||
result = await asyncio.gather(*tasks)
|
||||
for _ in result:
|
||||
rs.extend(_)
|
||||
|
||||
if not rs:
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
final_rs = await asyncio.to_thread(
|
||||
self.rerank,
|
||||
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
|
||||
)
|
||||
else:
|
||||
final_rs = sorted(
|
||||
rs,
|
||||
@@ -259,4 +328,20 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
return [chunk.page_content for chunk in final_rs]
|
||||
citations = []
|
||||
seen_doc_ids = set()
|
||||
for chunk in final_rs:
|
||||
meta = chunk.metadata or {}
|
||||
doc_id = meta.get("document_id") or meta.get("doc_id")
|
||||
if doc_id and doc_id not in seen_doc_ids:
|
||||
seen_doc_ids.add(doc_id)
|
||||
citations.append({
|
||||
"document_id": str(doc_id),
|
||||
"file_name": meta.get("file_name", ""),
|
||||
"knowledge_id": str(meta.get("knowledge_id", kb_config.kb_id)),
|
||||
"score": meta.get("score", 0.0),
|
||||
})
|
||||
return {
|
||||
"chunks": [chunk.page_content for chunk in final_rs],
|
||||
"citations": citations,
|
||||
}
|
||||
|
||||
3
api/app/core/workflow/nodes/list_operator/__init__.py
Normal file
3
api/app/core/workflow/nodes/list_operator/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .node import ListOperatorNode
|
||||
|
||||
__all__ = ["ListOperatorNode"]
|
||||
49
api/app/core/workflow/nodes/list_operator/config.py
Normal file
49
api/app/core/workflow/nodes/list_operator/config.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
|
||||
|
||||
class FilterCondition(BaseModel):
|
||||
key: str = ""
|
||||
comparison_operator: ComparisonOperator = ComparisonOperator.CONTAINS
|
||||
value: str | list[str] | bool = ""
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
|
||||
enabled: bool = False
|
||||
conditions: list[FilterCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrderByConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
key: str = ""
|
||||
value: str = "asc" # "asc" | "desc"
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
enabled: bool = False
|
||||
size: int = -1
|
||||
|
||||
|
||||
class ExtractConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
serial: str = "1" # 1-based index string, e.g. "1" = first
|
||||
|
||||
@field_validator("serial", mode="before")
|
||||
@classmethod
|
||||
def coerce_serial(cls, v):
|
||||
return str(v)
|
||||
|
||||
|
||||
class ListOperatorNodeConfig(BaseNodeConfig):
|
||||
"""
|
||||
List Operator node config.
|
||||
Operation order: filter -> extract -> order -> limit
|
||||
"""
|
||||
input_list: str = Field(..., description="Variable selector, e.g. {{ sys.files }} or {{ conv.uploaded_files }}")
|
||||
filter_by: FilterBy = Field(default_factory=FilterBy)
|
||||
order_by: OrderByConfig = Field(default_factory=OrderByConfig)
|
||||
limit: Limit = Field(default_factory=Limit)
|
||||
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
150
api/app/core/workflow/nodes/list_operator/node.py
Normal file
150
api/app/core/workflow/nodes/list_operator/node.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig, FilterCondition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# File object fields that hold string values
|
||||
_FILE_STRING_KEYS = {"type", "name", "url", "extension", "mime_type", "transfer_method", "origin_file_type", "file_id"}
|
||||
_FILE_NUMBER_KEYS = {"size"}
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode):
|
||||
def __init__(self, node_config: dict, workflow_config: dict, down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: ListOperatorNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"result": VariableType.ANY,
|
||||
"first_record": VariableType.ANY,
|
||||
"last_record": VariableType.ANY,
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = ListOperatorNodeConfig(**self.config)
|
||||
cfg = self.typed_config
|
||||
|
||||
# Resolve input variable from path selector
|
||||
items: list = self.get_variable(cfg.input_list, variable_pool)
|
||||
if not isinstance(items, list):
|
||||
raise TypeError(f"Variable '{cfg.input_list}' must be an array, got {type(items)}")
|
||||
|
||||
result = list(items)
|
||||
|
||||
# 1. Filter
|
||||
if cfg.filter_by.enabled and cfg.filter_by.conditions:
|
||||
for condition in cfg.filter_by.conditions:
|
||||
result = [item for item in result if self._match_condition(item, condition, variable_pool)]
|
||||
|
||||
# 2. Extract (take single item by 1-based serial index)
|
||||
if cfg.extract_by.enabled:
|
||||
serial_str = self._resolve_value(cfg.extract_by.serial, variable_pool)
|
||||
idx = int(serial_str) - 1
|
||||
if idx < 0 or idx >= len(result):
|
||||
raise ValueError(f"extract_by.serial={cfg.extract_by.serial} out of range (list length={len(result)})")
|
||||
result = [result[idx]]
|
||||
|
||||
# 3. Order
|
||||
if cfg.order_by.enabled:
|
||||
reverse = cfg.order_by.value == "desc"
|
||||
key_fn = self._make_sort_key(cfg.order_by.key)
|
||||
result = sorted(result, key=key_fn, reverse=reverse)
|
||||
|
||||
# 4. Limit (take first N)
|
||||
if cfg.limit.enabled and cfg.limit.size > 0:
|
||||
result = result[:cfg.limit.size]
|
||||
|
||||
return {
|
||||
"result": result,
|
||||
"first_record": result[0] if result else None,
|
||||
"last_record": result[-1] if result else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _resolve_value(value: str, variable_pool: VariablePool) -> Any:
|
||||
"""If value is a {{ namespace.key }} variable selector, resolve it from the pool.
|
||||
Otherwise return the raw string."""
|
||||
import re
|
||||
m = re.fullmatch(r"\{\{\s*(\w+\.\w+)\s*}}", value.strip())
|
||||
if m:
|
||||
resolved = variable_pool.get_value(value, default=value, strict=False)
|
||||
return resolved
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _make_sort_key(key: str):
|
||||
def key_fn(item):
|
||||
if isinstance(item, dict):
|
||||
return item.get(key) or ""
|
||||
return item
|
||||
return key_fn
|
||||
|
||||
def _match_condition(self, item: Any, cond: FilterCondition, variable_pool: VariablePool) -> bool:
|
||||
op = cond.comparison_operator
|
||||
value = cond.value
|
||||
|
||||
# Resolve value if it's a variable reference {{ namespace.key }}
|
||||
if isinstance(value, str):
|
||||
value = self._resolve_value(value, variable_pool)
|
||||
|
||||
# Resolve left value
|
||||
if isinstance(item, dict):
|
||||
left = item.get(cond.key)
|
||||
else:
|
||||
left = item # primitive array: compare element directly
|
||||
|
||||
# Determine if this field should be compared as a string
|
||||
is_string_field = isinstance(item, dict) and cond.key in _FILE_STRING_KEYS
|
||||
|
||||
# Numeric operators
|
||||
if op == ComparisonOperator.EQ:
|
||||
if is_string_field:
|
||||
return str(left) == str(value)
|
||||
return self._safe_num(left) == self._safe_num(value)
|
||||
if op == ComparisonOperator.NE:
|
||||
if is_string_field:
|
||||
return str(left) != str(value)
|
||||
return self._safe_num(left) != self._safe_num(value)
|
||||
if op == ComparisonOperator.LT:
|
||||
return self._safe_num(left) < self._safe_num(value)
|
||||
if op == ComparisonOperator.LE:
|
||||
return self._safe_num(left) <= self._safe_num(value)
|
||||
if op == ComparisonOperator.GT:
|
||||
return self._safe_num(left) > self._safe_num(value)
|
||||
if op == ComparisonOperator.GE:
|
||||
return self._safe_num(left) >= self._safe_num(value)
|
||||
|
||||
# String / sequence operators
|
||||
left_str = str(left) if left is not None else ""
|
||||
if op == ComparisonOperator.CONTAINS:
|
||||
return str(value) in left_str
|
||||
if op == ComparisonOperator.NOT_CONTAINS:
|
||||
return str(value) not in left_str
|
||||
if op == ComparisonOperator.START_WITH:
|
||||
return left_str.startswith(str(value))
|
||||
if op == ComparisonOperator.END_WITH:
|
||||
return left_str.endswith(str(value))
|
||||
if op == ComparisonOperator.IN:
|
||||
return left_str in (value if isinstance(value, list) else [str(value)])
|
||||
if op == ComparisonOperator.NOT_IN:
|
||||
return left_str not in (value if isinstance(value, list) else [str(value)])
|
||||
if op == ComparisonOperator.EMPTY:
|
||||
return not left
|
||||
if op == ComparisonOperator.NOT_EMPTY:
|
||||
return bool(left)
|
||||
|
||||
raise ValueError(f"Unsupported operator: {op}")
|
||||
|
||||
@staticmethod
|
||||
def _safe_num(v) -> float:
|
||||
try:
|
||||
return float(v)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
@@ -213,9 +213,10 @@ class LLMNode(BaseNode):
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
# 使用简单的 prompt 格式(向后兼容)——包装为标准消息列表以兼容所有 provider
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
self.messages = self._render_template(prompt_template, variable_pool)
|
||||
rendered = self._render_template(prompt_template, variable_pool)
|
||||
self.messages = [{"role": "user", "content": rendered}]
|
||||
|
||||
return llm
|
||||
|
||||
@@ -245,7 +246,10 @@ class LLMNode(BaseNode):
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return AIMessage(content=content, response_metadata=response.response_metadata)
|
||||
return AIMessage(content=content, response_metadata={
|
||||
**response.response_metadata,
|
||||
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
|
||||
})
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)"""
|
||||
@@ -304,15 +308,16 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
last_meta_data = {}
|
||||
last_usage_metadata = {}
|
||||
async for chunk in llm.astream(self.messages):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = self.process_model_output(chunk.content)
|
||||
else:
|
||||
content = str(chunk)
|
||||
if hasattr(chunk, 'response_metadata'):
|
||||
if chunk.response_metadata:
|
||||
last_meta_data = chunk.response_metadata
|
||||
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
|
||||
last_meta_data = chunk.response_metadata
|
||||
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
|
||||
last_usage_metadata = chunk.usage_metadata
|
||||
|
||||
# 只有当内容不为空时才处理
|
||||
if content:
|
||||
@@ -335,7 +340,10 @@ class LLMNode(BaseNode):
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
final_message = AIMessage(
|
||||
content=full_response,
|
||||
response_metadata=last_meta_data
|
||||
response_metadata={
|
||||
**last_meta_data,
|
||||
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
|
||||
}
|
||||
)
|
||||
|
||||
# yield 完成标记
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
from app.core.workflow.nodes.document_extractor import DocExtractorNode
|
||||
from app.core.workflow.nodes.list_operator import ListOperatorNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,7 +52,8 @@ WorkflowNode = Union[
|
||||
MemoryReadNode,
|
||||
MemoryWriteNode,
|
||||
CodeNode,
|
||||
DocExtractorNode
|
||||
DocExtractorNode,
|
||||
ListOperatorNode
|
||||
]
|
||||
|
||||
|
||||
@@ -83,7 +85,8 @@ class NodeFactory:
|
||||
NodeType.MEMORY_READ: MemoryReadNode,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNode
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -45,6 +45,12 @@ class ParameterExtractorNode(BaseNode):
|
||||
"model_id": str(self.typed_config.model_id),
|
||||
}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
final_output = {}
|
||||
for param in self.typed_config.params:
|
||||
final_output[param.name] = business_result.get(param.name) or DEFAULT_VALUE(self.output_types[param.name])
|
||||
return final_output
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {}
|
||||
for param in self.typed_config.params:
|
||||
@@ -109,6 +115,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
capability = api_config.capability
|
||||
model_type = config.type
|
||||
|
||||
llm = RedBearLLM(
|
||||
@@ -201,7 +208,10 @@ class ParameterExtractorNode(BaseNode):
|
||||
])
|
||||
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
self.response_metadata = model_resp.response_metadata
|
||||
self.response_metadata = {
|
||||
**model_resp.response_metadata,
|
||||
"token_usage": getattr(model_resp, 'usage_metadata', None) or model_resp.response_metadata.get('token_usage')
|
||||
}
|
||||
model_message = self.process_model_output(model_resp.content)
|
||||
result = json_repair.repair_json(model_message, return_objects=True)
|
||||
logger.info(f"node: {self.node_id} get params:{result}")
|
||||
|
||||
@@ -62,6 +62,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
api_key = api_config.api_key
|
||||
base_url = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
capability = api_config.capability
|
||||
model_type = config.type
|
||||
|
||||
return RedBearLLM(
|
||||
@@ -135,7 +136,10 @@ class QuestionClassifierNode(BaseNode):
|
||||
|
||||
response = await llm.ainvoke(messages)
|
||||
result = self.process_model_output(response.content)
|
||||
self.response_metadata = response.response_metadata
|
||||
self.response_metadata = {
|
||||
**response.response_metadata,
|
||||
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
|
||||
}
|
||||
|
||||
if result in category_names:
|
||||
category = result
|
||||
|
||||
@@ -4,32 +4,33 @@ from typing import Any
|
||||
|
||||
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
||||
|
||||
from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
|
||||
|
||||
|
||||
class ExpressionEvaluator:
|
||||
"""Safe expression evaluator for workflow variables and node outputs."""
|
||||
|
||||
|
||||
# Reserved namespaces
|
||||
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||
|
||||
@classmethod
|
||||
def normalize_template(cls, template: str) -> str:
|
||||
pattern = re.compile(
|
||||
r"\{\{\s*(\d+)\.(\w+)\s*}}"
|
||||
)
|
||||
return pattern.sub(
|
||||
return _NORMALIZE_PATTERN.sub(
|
||||
r'{{ node["\1"].\2 }}',
|
||||
template
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def evaluate(
|
||||
cls,
|
||||
expression: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
cls,
|
||||
expression: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""
|
||||
Safely evaluate an expression using workflow variables.
|
||||
@@ -49,48 +50,47 @@ class ExpressionEvaluator:
|
||||
# Remove Jinja2-style brackets if present
|
||||
expression = expression.strip()
|
||||
expression = cls.normalize_template(expression)
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", expression).strip()
|
||||
expression = VARIABLE_PATTERN.sub(r"\1", expression).strip()
|
||||
|
||||
# Build context for evaluation
|
||||
context = {
|
||||
"conv": conv_vars, # conversation variables
|
||||
"node": node_outputs, # node outputs
|
||||
"sys": system_vars or {}, # system variables
|
||||
"conv": conv_vars, # conversation variables
|
||||
"node": node_outputs, # node outputs
|
||||
"sys": system_vars or {}, # system variables
|
||||
}
|
||||
|
||||
context.update(conv_vars)
|
||||
context["nodes"] = node_outputs
|
||||
# context.update(conv_vars)
|
||||
# context["nodes"] = node_outputs
|
||||
context.update(node_outputs)
|
||||
|
||||
|
||||
try:
|
||||
# simpleeval supports safe operations:
|
||||
# arithmetic, comparisons, logical ops, attribute/dict/list access
|
||||
result = simple_eval(expression, names=context)
|
||||
return result
|
||||
|
||||
|
||||
except NameNotDefined as e:
|
||||
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
|
||||
raise ValueError(f"Undefined variable: {e}")
|
||||
|
||||
|
||||
except InvalidExpression as e:
|
||||
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
|
||||
raise ValueError(f"Invalid expression syntax: {e}")
|
||||
|
||||
|
||||
except SyntaxError as e:
|
||||
logger.error(f"Syntax error in expression: {expression}, error: {e}")
|
||||
raise ValueError(f"Syntax error: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
|
||||
raise ValueError(f"Expression evaluation failed: {e}")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def evaluate_bool(
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Evaluate a boolean expression (for conditions).
|
||||
@@ -108,7 +108,7 @@ class ExpressionEvaluator:
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def validate_variable_names(variables: list[dict]) -> list[str]:
|
||||
"""
|
||||
@@ -121,7 +121,7 @@ class ExpressionEvaluator:
|
||||
list[str]: List of error messages. Empty if all names are valid.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
|
||||
for var in variables:
|
||||
var_name = var.get("name", "")
|
||||
|
||||
@@ -134,16 +134,16 @@ class ExpressionEvaluator:
|
||||
errors.append(
|
||||
f"Variable name '{var_name}' is not a valid Python identifier"
|
||||
)
|
||||
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def evaluate_expression(
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any]
|
||||
expression: str,
|
||||
conv_var: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict
|
||||
) -> Any:
|
||||
"""Evaluate an expression (convenience function)."""
|
||||
return ExpressionEvaluator.evaluate(
|
||||
@@ -152,11 +152,11 @@ def evaluate_expression(
|
||||
|
||||
|
||||
def evaluate_condition(
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
expression: str,
|
||||
conv_var: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict
|
||||
) -> Any:
|
||||
"""Evaluate a boolean condition expression (convenience function)."""
|
||||
return ExpressionEvaluator.evaluate_bool(
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/3/10 13:36
|
||||
import mimetypes
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
TRANSFORM_FILE_TYPE = {
|
||||
'text/plain': 'document/text',
|
||||
'text/markdown': 'document/markdown',
|
||||
@@ -52,5 +55,143 @@ ALLOWED_FILE_TYPES = [
|
||||
def mime_to_file_type(mime_type):
|
||||
if mime_type not in ALLOWED_FILE_TYPES:
|
||||
return None
|
||||
|
||||
return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
|
||||
|
||||
|
||||
def build_file_object_dict_from_url(url: str, file_type: str, origin_file_type: str) -> dict[str, Any]:
|
||||
"""Build a FileObject dict for a remote_url file using only URL parsing (no HTTP request).
|
||||
Used as fallback when HTTP request fails.
|
||||
"""
|
||||
raw_path = url.split("?")[0]
|
||||
name = unquote(os.path.basename(urlparse(url).path)) or None
|
||||
_, ext = os.path.splitext(name or "")
|
||||
extension = ext.lstrip(".").lower() if ext else None
|
||||
guessed_mime = mimetypes.guess_type(url)[0]
|
||||
return {
|
||||
"type": file_type,
|
||||
"url": url,
|
||||
"transfer_method": "remote_url",
|
||||
"origin_file_type": origin_file_type,
|
||||
"file_id": None,
|
||||
"name": name,
|
||||
"size": None,
|
||||
"extension": extension,
|
||||
"mime_type": guessed_mime or origin_file_type,
|
||||
"is_file": True,
|
||||
}
|
||||
|
||||
|
||||
async def fetch_remote_file_meta(
|
||||
url: str,
|
||||
file_type: str,
|
||||
origin_file_type: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch remote file metadata via HEAD (fallback GET) and build a FileObject dict.
|
||||
Falls back to URL-only parsing if the HTTP request fails.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
name = extension = None
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.head(url, follow_redirects=True)
|
||||
if resp.status_code != 200:
|
||||
resp = await client.get(url, follow_redirects=True)
|
||||
|
||||
cl = resp.headers.get("Content-Length")
|
||||
size = int(cl) if cl else None
|
||||
|
||||
ct = resp.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
mime_type = ct or origin_file_type
|
||||
|
||||
cd = resp.headers.get("Content-Disposition", "")
|
||||
if "filename=" in cd:
|
||||
name = cd.split("filename=")[-1].strip('"').strip("'")
|
||||
if not name:
|
||||
name = unquote(os.path.basename(urlparse(url).path)) or None
|
||||
|
||||
if name:
|
||||
_, ext = os.path.splitext(name)
|
||||
extension = ext.lstrip(".").lower() if ext else None
|
||||
if not extension and mime_type:
|
||||
ext = mimetypes.guess_extension(mime_type)
|
||||
extension = ext.lstrip(".").lower() if ext else None
|
||||
except Exception:
|
||||
return build_file_object_dict_from_url(url, file_type, origin_file_type)
|
||||
|
||||
return build_file_object_dict_from_meta(
|
||||
file_type=file_type,
|
||||
transfer_method="remote_url",
|
||||
origin_file_type=origin_file_type,
|
||||
file_id=None,
|
||||
url=url,
|
||||
file_name=name,
|
||||
file_size=size,
|
||||
file_ext=extension,
|
||||
content_type=mime_type,
|
||||
)
|
||||
|
||||
|
||||
def build_file_object_dict_from_meta(
|
||||
file_type: str,
|
||||
transfer_method: str,
|
||||
origin_file_type: str,
|
||||
file_id: str,
|
||||
url: str,
|
||||
file_name: str | None,
|
||||
file_size: int | None,
|
||||
file_ext: str | None,
|
||||
content_type: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a FileObject dict from already-fetched FileMetadata fields."""
|
||||
ext = (file_ext or "").lstrip(".")
|
||||
return {
|
||||
"type": file_type,
|
||||
"url": url,
|
||||
"transfer_method": transfer_method,
|
||||
"origin_file_type": content_type or origin_file_type,
|
||||
"file_id": file_id,
|
||||
"name": file_name,
|
||||
"size": file_size,
|
||||
"extension": ext.lower() if ext else None,
|
||||
"mime_type": content_type,
|
||||
"is_file": True,
|
||||
}
|
||||
|
||||
|
||||
def resolve_local_file_object_dict(
|
||||
db,
|
||||
upload_file_id: str | uuid.UUID,
|
||||
file_type: str,
|
||||
origin_file_type: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Query FileMetadata and build a FileObject dict for a local_file.
|
||||
Returns None if the file is not found or not completed.
|
||||
"""
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.core.config import settings
|
||||
|
||||
try:
|
||||
fid = uuid.UUID(str(upload_file_id))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
meta = db.query(FileMetadata).filter(
|
||||
FileMetadata.id == fid,
|
||||
FileMetadata.status == "completed"
|
||||
).first()
|
||||
if not meta:
|
||||
return None
|
||||
|
||||
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{fid}"
|
||||
return build_file_object_dict_from_meta(
|
||||
file_type=file_type,
|
||||
transfer_method="local_file",
|
||||
origin_file_type=origin_file_type,
|
||||
file_id=str(fid),
|
||||
url=url,
|
||||
file_name=meta.file_name,
|
||||
file_size=meta.file_size,
|
||||
file_ext=meta.file_ext,
|
||||
content_type=meta.content_type,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
模板渲染器
|
||||
Template Renderer
|
||||
|
||||
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
|
||||
Provides safe template rendering using Jinja2, supporting variable references
|
||||
and expressions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -10,11 +11,15 @@ from typing import Any
|
||||
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||
|
||||
from app.core.workflow.engine.variable_pool import LazyVariableDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
|
||||
|
||||
|
||||
class SafeUndefined(Undefined):
|
||||
"""访问未定义属性不会报错,返回空字符串"""
|
||||
"""Return empty string instead of raising error when accessing undefined variables"""
|
||||
__slots__ = ()
|
||||
|
||||
def _fail_with_undefined_error(self, *args, **kwargs):
|
||||
@@ -26,26 +31,22 @@ class SafeUndefined(Undefined):
|
||||
|
||||
|
||||
class TemplateRenderer:
|
||||
"""模板渲染器"""
|
||||
|
||||
def __init__(self, strict: bool = True):
|
||||
"""初始化渲染器
|
||||
|
||||
"""Initialize renderer
|
||||
|
||||
Args:
|
||||
strict: 是否使用严格模式(未定义变量会抛出异常)
|
||||
strict: Whether to enable strict mode (raise error on undefined variables)
|
||||
"""
|
||||
self.strict = strict
|
||||
self.env = Environment(
|
||||
undefined=StrictUndefined if strict else SafeUndefined,
|
||||
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
||||
autoescape=False # Disable auto-escaping since we handle plain text instead of HTML
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def normalize_template(template: str) -> str:
|
||||
pattern = re.compile(
|
||||
r"\{\{\s*(\d+)\.(\w+)\s*}}"
|
||||
)
|
||||
return pattern.sub(
|
||||
"""Normalize template syntax (convert numeric node reference to dict access)"""
|
||||
return _NORMALIZE_PATTERN.sub(
|
||||
r'{{ node["\1"].\2 }}',
|
||||
template
|
||||
)
|
||||
@@ -53,24 +54,24 @@ class TemplateRenderer:
|
||||
def render(
|
||||
self,
|
||||
template: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
conv_vars: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict | None = None
|
||||
) -> str:
|
||||
"""渲染模板
|
||||
|
||||
"""Render template
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
conv_vars: 会话变量
|
||||
node_outputs: 节点输出结果
|
||||
system_vars: 系统变量
|
||||
|
||||
template: Template string
|
||||
conv_vars: Conversation variables
|
||||
node_outputs: Node outputs
|
||||
system_vars: System variables
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
|
||||
Rendered string
|
||||
|
||||
Raises:
|
||||
ValueError: 模板语法错误或变量未定义
|
||||
|
||||
ValueError: If template syntax is invalid or variables are undefined
|
||||
|
||||
Examples:
|
||||
>>> renderer = TemplateRenderer()
|
||||
>>> renderer.render(
|
||||
@@ -80,122 +81,119 @@ class TemplateRenderer:
|
||||
... {}
|
||||
... )
|
||||
'Hello World!'
|
||||
|
||||
|
||||
>>> renderer.render(
|
||||
... "分析结果: {{node.analyze.output}}",
|
||||
... "Analysis result: {{node.analyze.output}}",
|
||||
... {},
|
||||
... {"analyze": {"output": "正面情绪"}},
|
||||
... {"analyze": {"output": "positive sentiment"}},
|
||||
... {}
|
||||
... )
|
||||
'分析结果: 正面情绪'
|
||||
'Analysis result: positive sentiment'
|
||||
"""
|
||||
# 构建命名空间上下文
|
||||
# Build namespace context
|
||||
context = {
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": system_vars, # 系统变量:{{sys.execution_id}}
|
||||
"conv": conv_vars, # Conversation variables: {{conv.user_name}}
|
||||
"node": node_outputs, # Node outputs: {{node.node_1.output}}
|
||||
"sys": system_vars, # System variables: {{sys.execution_id}}
|
||||
}
|
||||
|
||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||
# 将所有节点输出添加到顶层上下文
|
||||
# Allow direct access to node outputs by node ID: {{llm_qa.output}}
|
||||
if node_outputs:
|
||||
context.update(node_outputs)
|
||||
|
||||
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||
if conv_vars:
|
||||
context.update(conv_vars)
|
||||
|
||||
context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
# # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||
# if conv_vars:
|
||||
# context.update(conv_vars)
|
||||
#
|
||||
# context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
template = self.normalize_template(template)
|
||||
try:
|
||||
tmpl = self.env.from_string(template)
|
||||
return tmpl.render(**context)
|
||||
|
||||
except TemplateSyntaxError as e:
|
||||
logger.error(f"模板语法错误: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板语法错误: {e}")
|
||||
|
||||
logger.error(f"Template syntax error: {template}, error: {e}")
|
||||
raise ValueError(f"Template syntax error: {e}")
|
||||
except UndefinedError as e:
|
||||
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
|
||||
raise ValueError(f"未定义的变量: {e}")
|
||||
|
||||
logger.error(f"Undefined variable in template: {template}, error: {e}")
|
||||
raise ValueError(f"Undefined variable: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"模板渲染异常: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板渲染失败: {e}")
|
||||
logger.error(f"Template rendering error: {template}, error: {e}")
|
||||
raise ValueError(f"Template rendering failed: {e}")
|
||||
|
||||
def validate(self, template: str) -> list[str]:
|
||||
"""验证模板语法
|
||||
|
||||
"""Validate template syntax
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
template: Template string
|
||||
|
||||
Returns:
|
||||
错误列表,如果为空则验证通过
|
||||
|
||||
List of errors (empty if valid)
|
||||
|
||||
Examples:
|
||||
>>> renderer = TemplateRenderer()
|
||||
>>> renderer.validate("Hello {{var.name}}!")
|
||||
[]
|
||||
|
||||
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
|
||||
['模板语法错误: ...']
|
||||
|
||||
>>> renderer.validate("Hello {{var.name") # Missing closing tag
|
||||
['Template syntax error: ...']
|
||||
"""
|
||||
errors = []
|
||||
|
||||
try:
|
||||
self.env.from_string(template)
|
||||
except TemplateSyntaxError as e:
|
||||
errors.append(f"模板语法错误: {e}")
|
||||
errors.append(f"Template syntax error: {e}")
|
||||
except Exception as e:
|
||||
errors.append(f"模板验证失败: {e}")
|
||||
errors.append(f"Template validation failed: {e}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# 全局渲染器实例(严格模式)
|
||||
# Global renderer instances (strict / lenient)
|
||||
_strict_renderer = TemplateRenderer(strict=True)
|
||||
_lenient_renderer = TemplateRenderer(strict=False)
|
||||
|
||||
|
||||
def render_template(
|
||||
template: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any],
|
||||
conv_vars: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict,
|
||||
strict: bool = True
|
||||
) -> str:
|
||||
"""渲染模板(便捷函数)
|
||||
|
||||
"""Render template (convenience function)
|
||||
|
||||
Args:
|
||||
strict: 严格模式
|
||||
template: 模板字符串
|
||||
conv_vars: 会话变量
|
||||
node_outputs: 节点输出
|
||||
system_vars: 系统变量
|
||||
|
||||
strict: Whether to use strict mode
|
||||
template: Template string
|
||||
conv_vars: Conversation variables
|
||||
node_outputs: Node outputs
|
||||
system_vars: System variables
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
|
||||
Rendered string
|
||||
|
||||
Examples:
|
||||
>>> render_template(
|
||||
... "请分析: {{var.text}}",
|
||||
... {"text": "这是一段文本"},
|
||||
... "Analyze: {{var.text}}",
|
||||
... {"text": "This is a text"},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
'请分析: 这是一段文本'
|
||||
'Analyze: This is a text'
|
||||
"""
|
||||
renderer = _strict_renderer if strict else _lenient_renderer
|
||||
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||
|
||||
|
||||
def validate_template(template: str) -> list[str]:
|
||||
"""验证模板语法(便捷函数)
|
||||
|
||||
"""Validate template syntax (convenience function)
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
template: Template string
|
||||
|
||||
Returns:
|
||||
错误列表
|
||||
List of errors
|
||||
"""
|
||||
return _strict_renderer.validate(template)
|
||||
|
||||
@@ -301,7 +301,7 @@ class WorkflowValidator:
|
||||
for node in nodes:
|
||||
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
||||
f"节点 {node.get('name')} 缺少名称(发布时必须提供)"
|
||||
)
|
||||
|
||||
# 2. 验证所有非 start/end 节点都有配置
|
||||
@@ -311,7 +311,7 @@ class WorkflowValidator:
|
||||
config = node.get("config")
|
||||
if not config or not isinstance(config, dict):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
|
||||
f"节点 {node.get('name')} 缺少配置(发布时必须提供)"
|
||||
)
|
||||
|
||||
# 3. 验证必填变量
|
||||
|
||||
@@ -91,7 +91,7 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
|
||||
case VariableType.OBJECT:
|
||||
return {}
|
||||
case VariableType.FILE:
|
||||
return None
|
||||
return {}
|
||||
case VariableType.ARRAY_STRING:
|
||||
return []
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
@@ -113,6 +113,12 @@ class FileObject(BaseModel):
|
||||
origin_file_type: str
|
||||
file_id: str | None
|
||||
|
||||
# Extended file metadata
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
is_file: bool
|
||||
|
||||
|
||||
@@ -66,20 +66,10 @@ class FileVariable(BaseVariable):
|
||||
type = 'file'
|
||||
|
||||
def valid_value(self, value) -> FileObject:
|
||||
|
||||
if isinstance(value, dict):
|
||||
if not value.get("is_file"):
|
||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
return FileObject(
|
||||
**{
|
||||
"type": str(value.get('type')),
|
||||
"transfer_method": value.get("transfer_method"),
|
||||
"url": value.get('url'),
|
||||
"file_id": value.get("file_id"),
|
||||
"origin_file_type": value.get("origin_file_type"),
|
||||
"is_file": True
|
||||
}
|
||||
)
|
||||
return FileObject(**value)
|
||||
if isinstance(value, FileObject):
|
||||
return value
|
||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
@@ -88,7 +78,7 @@ class FileVariable(BaseVariable):
|
||||
return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})'
|
||||
|
||||
def get_value(self) -> Any:
|
||||
return self.value.model_dump()
|
||||
return self.value.model_dump(exclude={"content_cache"})
|
||||
|
||||
async def get_content(self):
|
||||
total_bytes = 0
|
||||
@@ -186,6 +176,8 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
return BooleanVariable(value)
|
||||
case VariableType.OBJECT:
|
||||
return DictVariable(value)
|
||||
case VariableType.FILE:
|
||||
return FileVariable(value)
|
||||
case VariableType.ARRAY_STRING:
|
||||
return make_array(StringVariable, value)
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
|
||||
@@ -62,6 +62,7 @@ async def lifespan(app: FastAPI):
|
||||
else:
|
||||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||
await create_all_indexes()
|
||||
logger.info("All neo4j indexes and constraints created successfully!")
|
||||
logger.info("应用程序启动完成")
|
||||
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ class ModelConfig(BaseModel):
|
||||
|
||||
# 模型配置参数
|
||||
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
|
||||
comment="模型能力列表(如['vision', 'audio', 'video'])")
|
||||
comment="模型能力列表(如['vision', 'audio', 'video', 'thinking'])")
|
||||
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)")
|
||||
config = Column(JSON, comment="模型配置参数")
|
||||
# - temperature : 控制生成文本的随机性。值越高,输出越随机、越有创造性;值越低,输出越确定、越保守。
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, time
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, Table, MetaData
|
||||
from uuid import UUID
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
@@ -192,10 +192,63 @@ class HomePageRepository:
|
||||
|
||||
return workspaces, app_count_dict, user_count_dict
|
||||
|
||||
@staticmethod
|
||||
def get_latest_version_introduction(db: Session) -> tuple[Optional[str], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
从数据库获取最新已发布的版本说明
|
||||
使用反射方式读取表结构,不依赖 premium 模型类
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
(版本号,版本说明字典) 的元组
|
||||
如果数据库中没有已发布的版本,返回 (None, None)
|
||||
"""
|
||||
try:
|
||||
metadata = MetaData()
|
||||
|
||||
version_notes = Table('version_notes', metadata, autoload_with=db.bind)
|
||||
|
||||
# 获取最新已发布的版本(按发布时间倒序,日期相同时按版本号倒序)
|
||||
query = db.query(version_notes).filter(
|
||||
version_notes.c.is_published == True
|
||||
).order_by(
|
||||
version_notes.c.release_date.desc(),
|
||||
version_notes.c.version.desc()
|
||||
)
|
||||
|
||||
note = query.first()
|
||||
|
||||
if not note:
|
||||
return None, None
|
||||
|
||||
version_info = {
|
||||
"introduction": {
|
||||
"codeName": note.code_name or "",
|
||||
"releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0,
|
||||
"upgradePosition": note.upgrade_position or "",
|
||||
"coreUpgrades": note.core_upgrades or []
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": note.code_name_en or note.code_name or "",
|
||||
"releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0,
|
||||
"upgradePosition": note.upgrade_position_en or note.upgrade_position or "",
|
||||
"coreUpgrades": note.core_upgrades_en or []
|
||||
}
|
||||
}
|
||||
|
||||
return note.version, version_info
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def get_version_introduction(db: Session, version: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
从数据库获取版本说明(优先读取已发布的版本)
|
||||
从数据库获取指定版本说明(优先读取已发布的版本)
|
||||
使用反射方式读取表结构,不依赖 premium 模型类
|
||||
|
||||
Args:
|
||||
@@ -207,11 +260,8 @@ class HomePageRepository:
|
||||
如果数据库中没有该版本,返回 None
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import Table, MetaData
|
||||
|
||||
metadata = MetaData()
|
||||
version_notes = Table('version_notes', metadata, autoload_with=db.engine)
|
||||
version_note_items = Table('version_note_items', metadata, autoload_with=db.engine)
|
||||
|
||||
note = db.query(version_notes).filter(
|
||||
version_notes.c.version == version,
|
||||
@@ -221,31 +271,18 @@ class HomePageRepository:
|
||||
if not note:
|
||||
return None
|
||||
|
||||
items = db.query(version_note_items).filter(
|
||||
version_note_items.c.note_id == note.id
|
||||
).order_by(version_note_items.c.sort_order).all()
|
||||
|
||||
core_upgrades = []
|
||||
for item in items:
|
||||
title = item.title
|
||||
content = item.content
|
||||
if content:
|
||||
core_upgrades.append(f"{title}<br>{content}")
|
||||
else:
|
||||
core_upgrades.append(title)
|
||||
|
||||
return {
|
||||
"introduction": {
|
||||
"codeName": "",
|
||||
"releaseDate": note.release_date.isoformat() if note.release_date else "",
|
||||
"upgradePosition": "",
|
||||
"coreUpgrades": core_upgrades
|
||||
"codeName": note.code_name or "",
|
||||
"releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0,
|
||||
"upgradePosition": note.upgrade_position or "",
|
||||
"coreUpgrades": note.core_upgrades or []
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": "",
|
||||
"releaseDate": note.release_date.isoformat() if note.release_date else "",
|
||||
"upgradePosition": "",
|
||||
"coreUpgrades": core_upgrades
|
||||
"codeName": note.code_name_en or note.code_name or "",
|
||||
"releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0,
|
||||
"upgradePosition": note.upgrade_position_en or note.upgrade_position or "",
|
||||
"coreUpgrades": note.core_upgrades_en or []
|
||||
}
|
||||
}
|
||||
except Exception:
|
||||
|
||||
@@ -78,6 +78,15 @@ class MemoryConfigRepository:
|
||||
OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count
|
||||
"""
|
||||
|
||||
# 批量查询多个用户的记忆数量(简化版本,只返回total)
|
||||
SEARCH_FOR_ALL_BATCH = """
|
||||
MATCH (n) WHERE n.end_user_id IN $end_user_ids
|
||||
RETURN
|
||||
n.end_user_id as user_id,
|
||||
count(n) as total
|
||||
ORDER BY user_id
|
||||
"""
|
||||
|
||||
# Extracted entity details within group/app/user
|
||||
SEARCH_FOR_DETIALS = """
|
||||
MATCH (n:ExtractedEntity)
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import asyncio
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def create_fulltext_indexes():
|
||||
"""Create full-text indexes for keyword search with BM25 scoring."""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
|
||||
|
||||
# 创建 Statements 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# # 创建 Dialogues 索引
|
||||
# await connector.execute_query("""
|
||||
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
|
||||
@@ -21,27 +21,35 @@ async def create_fulltext_indexes():
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# 创建 Chunks 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# 创建 MemorySummary 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
""")
|
||||
# 创建 Community 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
|
||||
# 创建 Perceptual 感知记忆索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_vector_indexes():
|
||||
"""Create vector indexes for fast embedding similarity search.
|
||||
|
||||
@@ -50,8 +58,7 @@ async def create_vector_indexes():
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
|
||||
|
||||
|
||||
# Statement embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
|
||||
@@ -62,8 +69,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
|
||||
# Chunk embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
|
||||
@@ -75,7 +81,6 @@ async def create_vector_indexes():
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
# Entity name embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
|
||||
@@ -86,8 +91,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
|
||||
# Memory summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
|
||||
@@ -98,7 +102,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
# Community summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||
@@ -108,8 +112,8 @@ async def create_vector_indexes():
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# Dialogue embedding index (optional)
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
|
||||
@@ -120,15 +124,27 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
# Perceptual summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS
|
||||
FOR (p:Perceptual)
|
||||
ON p.summary_embedding
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_unique_constraints():
|
||||
"""Create uniqueness constraints for core node identifiers.
|
||||
Ensures concurrent MERGE operations remain safe and prevents duplicates.
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
try:
|
||||
# Dialogue.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -136,7 +152,7 @@ async def create_unique_constraints():
|
||||
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# Statement.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -144,7 +160,7 @@ async def create_unique_constraints():
|
||||
FOR (s:Statement) REQUIRE s.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# Chunk.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -152,13 +168,13 @@ async def create_unique_constraints():
|
||||
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_all_indexes():
|
||||
"""Create all indexes and constraints in one go."""
|
||||
await create_fulltext_indexes()
|
||||
await create_vector_indexes()
|
||||
await create_unique_constraints()
|
||||
print("✓ All indexes and constraints created successfully!")
|
||||
|
||||
|
||||
@@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
|
||||
r.created_at = edge.created_at
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
|
||||
WHERE p.end_user_id = $end_user_id
|
||||
RETURN p.id AS id,
|
||||
p.end_user_id AS end_user_id,
|
||||
p.perceptual_type AS perceptual_type,
|
||||
p.file_path AS file_path,
|
||||
p.file_name AS file_name,
|
||||
p.file_ext AS file_ext,
|
||||
p.summary AS summary,
|
||||
p.keywords AS keywords,
|
||||
p.topic AS topic,
|
||||
p.domain AS domain,
|
||||
p.created_at AS created_at,
|
||||
p.file_type AS file_type,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
PERCEPTUAL_EMBEDDING_SEARCH = """
|
||||
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
|
||||
YIELD node AS p, score
|
||||
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
|
||||
RETURN p.id AS id,
|
||||
p.end_user_id AS end_user_id,
|
||||
p.perceptual_type AS perceptual_type,
|
||||
p.file_path AS file_path,
|
||||
p.file_name AS file_name,
|
||||
p.file_ext AS file_ext,
|
||||
p.summary AS summary,
|
||||
p.keywords AS keywords,
|
||||
p.topic AS topic,
|
||||
p.domain AS domain,
|
||||
p.created_at AS created_at,
|
||||
p.file_type AS file_type,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
EXPAND_COMMUNITY_STATEMENTS,
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
PERCEPTUAL_EMBEDDING_SEARCH,
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _update_activation_values_batch(
|
||||
connector: Neo4jConnector,
|
||||
nodes: List[Dict[str, Any]],
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
max_retries: int = 3
|
||||
connector: Neo4jConnector,
|
||||
nodes: List[Dict[str, Any]],
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
max_retries: int = 3
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量更新节点的激活值
|
||||
@@ -58,7 +60,7 @@ async def _update_activation_values_batch(
|
||||
"""
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
|
||||
AccessHistoryManager,
|
||||
@@ -66,7 +68,7 @@ async def _update_activation_values_batch(
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
|
||||
ACTRCalculator,
|
||||
)
|
||||
|
||||
|
||||
# 创建计算器和管理器实例
|
||||
actr_calculator = ACTRCalculator()
|
||||
access_manager = AccessHistoryManager(
|
||||
@@ -74,7 +76,7 @@ async def _update_activation_values_batch(
|
||||
actr_calculator=actr_calculator,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
|
||||
# 提取节点ID列表并去重(保持原始顺序)
|
||||
seen_ids = set()
|
||||
unique_node_ids = []
|
||||
@@ -83,7 +85,7 @@ async def _update_activation_values_batch(
|
||||
if node_id and node_id not in seen_ids:
|
||||
seen_ids.add(node_id)
|
||||
unique_node_ids.append(node_id)
|
||||
|
||||
|
||||
if not unique_node_ids:
|
||||
logger.warning(f"批量更新激活值:没有有效的节点ID")
|
||||
return nodes
|
||||
@@ -95,7 +97,7 @@ async def _update_activation_values_batch(
|
||||
f"批量更新激活值:检测到重复节点,具有有效ID的节点数量={id_nodes_count}, "
|
||||
f"去重后唯一ID数量={len(unique_node_ids)}"
|
||||
)
|
||||
|
||||
|
||||
# 批量记录访问
|
||||
try:
|
||||
updated_nodes = await access_manager.record_batch_access(
|
||||
@@ -103,14 +105,14 @@ async def _update_activation_values_batch(
|
||||
node_label=node_label,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"批量更新激活值成功: {node_label}, "
|
||||
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
|
||||
)
|
||||
|
||||
|
||||
return updated_nodes
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"批量更新激活值失败: {node_label}, 错误: {str(e)}"
|
||||
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
|
||||
|
||||
|
||||
async def _update_search_results_activation(
|
||||
connector: Neo4jConnector,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
end_user_id: Optional[str] = None
|
||||
connector: Neo4jConnector,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
更新搜索结果中所有知识节点的激活值
|
||||
@@ -144,11 +146,11 @@ async def _update_search_results_activation(
|
||||
'entities': 'ExtractedEntity',
|
||||
'summaries': 'MemorySummary'
|
||||
}
|
||||
|
||||
|
||||
# 并行更新所有类型的节点
|
||||
update_tasks = []
|
||||
update_keys = []
|
||||
|
||||
|
||||
for key, label in knowledge_node_types.items():
|
||||
if key in results and results[key]:
|
||||
update_tasks.append(
|
||||
@@ -160,13 +162,13 @@ async def _update_search_results_activation(
|
||||
)
|
||||
)
|
||||
update_keys.append(key)
|
||||
|
||||
|
||||
if not update_tasks:
|
||||
return results
|
||||
|
||||
|
||||
# 并行执行所有更新
|
||||
update_results = await asyncio.gather(*update_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 更新结果字典,保留原始搜索分数
|
||||
updated_results = results.copy()
|
||||
for key, update_result in zip(update_keys, update_results):
|
||||
@@ -175,10 +177,10 @@ async def _update_search_results_activation(
|
||||
# 保留原始的 score 字段(BM25/Embedding 分数)
|
||||
original_nodes = results[key]
|
||||
updated_nodes = update_result
|
||||
|
||||
|
||||
# 创建 ID 到更新节点的映射(用于快速查找激活值数据)
|
||||
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
|
||||
|
||||
|
||||
# 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充
|
||||
merged_nodes = []
|
||||
for original_node in original_nodes:
|
||||
@@ -186,7 +188,7 @@ async def _update_search_results_activation(
|
||||
if node_id and node_id in updated_map:
|
||||
# 从原始节点开始,用更新后的激活值数据覆盖
|
||||
merged_node = original_node.copy()
|
||||
|
||||
|
||||
# 更新激活值相关字段
|
||||
activation_fields = {
|
||||
'activation_value',
|
||||
@@ -196,35 +198,35 @@ async def _update_search_results_activation(
|
||||
'importance_score',
|
||||
'version',
|
||||
'statement', # Statement 节点的内容字段
|
||||
'content' # MemorySummary 节点的内容字段
|
||||
'content' # MemorySummary 节点的内容字段
|
||||
}
|
||||
|
||||
|
||||
# 只更新激活值相关字段,保留原始节点的其他字段
|
||||
for field in activation_fields:
|
||||
if field in updated_map[node_id]:
|
||||
merged_node[field] = updated_map[node_id][field]
|
||||
|
||||
|
||||
merged_nodes.append(merged_node)
|
||||
else:
|
||||
# 如果没有更新数据,保留原始节点
|
||||
merged_nodes.append(original_node)
|
||||
|
||||
|
||||
updated_results[key] = merged_nodes
|
||||
else:
|
||||
# 更新失败,记录错误但保留原始结果
|
||||
logger.warning(
|
||||
f"更新 {key} 激活值失败: {str(update_result)}"
|
||||
)
|
||||
|
||||
|
||||
return updated_results
|
||||
|
||||
|
||||
async def search_graph(
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = None,
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
|
||||
@@ -249,41 +251,45 @@ async def search_graph(
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
|
||||
# Prepare tasks for parallel execution
|
||||
tasks = []
|
||||
task_keys = []
|
||||
|
||||
|
||||
if "statements" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("statements")
|
||||
|
||||
|
||||
if "entities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("entities")
|
||||
|
||||
|
||||
if "chunks" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("chunks")
|
||||
|
||||
|
||||
if "summaries" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -293,15 +299,16 @@ async def search_graph(
|
||||
if "communities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("communities")
|
||||
|
||||
|
||||
# Execute all queries in parallel
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# Build results dictionary
|
||||
results = {}
|
||||
for key, result in zip(task_keys, task_results):
|
||||
@@ -310,14 +317,14 @@ async def search_graph(
|
||||
results[key] = []
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
|
||||
# Deduplicate results before updating activation values
|
||||
# This prevents duplicates from propagating through the pipeline
|
||||
from app.core.memory.src.search import _deduplicate_results
|
||||
for key in results:
|
||||
if isinstance(results[key], list):
|
||||
results[key] = _deduplicate_results(results[key])
|
||||
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
# Skip activation updates if only searching summaries (optimization)
|
||||
needs_activation_update = any(
|
||||
@@ -331,17 +338,17 @@ async def search_graph(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_embedding(
|
||||
connector: Neo4jConnector,
|
||||
embedder_client,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = ["statements", "chunks", "entities","summaries"],
|
||||
connector: Neo4jConnector,
|
||||
embedder_client,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = ["statements", "chunks", "entities", "summaries"],
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Embedding-based semantic search across Statements, Chunks, and Entities.
|
||||
@@ -355,13 +362,13 @@ async def search_graph_by_embedding(
|
||||
- Returns up to 'limit' per included type
|
||||
"""
|
||||
import time
|
||||
|
||||
|
||||
# Get embedding for the query
|
||||
embed_start = time.time()
|
||||
embeddings = await embedder_client.response([query_text])
|
||||
embed_time = time.time() - embed_start
|
||||
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
|
||||
logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
|
||||
if not embeddings or not embeddings[0]:
|
||||
logger.warning(
|
||||
f"search_graph_by_embedding: embedding 生成失败或为空,"
|
||||
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
|
||||
if "statements" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
STATEMENT_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
|
||||
if "chunks" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
|
||||
if "entities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
|
||||
if "summaries" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
|
||||
if "communities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
COMMUNITY_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -428,8 +440,8 @@ async def search_graph_by_embedding(
|
||||
query_start = time.time()
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
query_time = time.time() - query_start
|
||||
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||
|
||||
logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||
|
||||
# Build results dictionary
|
||||
results: Dict[str, List[Dict[str, Any]]] = {
|
||||
"statements": [],
|
||||
@@ -438,7 +450,7 @@ async def search_graph_by_embedding(
|
||||
"summaries": [],
|
||||
"communities": [],
|
||||
}
|
||||
|
||||
|
||||
for key, result in zip(task_keys, task_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}")
|
||||
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
|
||||
logger.info(f"[PERF] Skipping activation updates (only summaries)")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
entities: List[Dict[str, Any]],
|
||||
use_contains_fallback: bool = True,
|
||||
batch_size: int = 500,
|
||||
max_concurrency: int = 5,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
entities: List[Dict[str, Any]],
|
||||
use_contains_fallback: bool = True,
|
||||
batch_size: int = 500,
|
||||
max_concurrency: int = 5,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries):
|
||||
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
||||
|
||||
|
||||
async def search_graph_by_keyword_temporal(
|
||||
connector: Neo4jConnector,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
connector: Neo4jConnector,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
Temporal keyword search across Statements.
|
||||
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
if not query_text:
|
||||
print(f"query_text不能为空")
|
||||
logger.warning(f"query_text不能为空")
|
||||
return {"statements": []}
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
|
||||
invalid_date=invalid_date,
|
||||
limit=limit,
|
||||
)
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
|
||||
|
||||
|
||||
async def search_graph_by_temporal(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -643,15 +653,15 @@ async def search_graph_by_temporal(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_dialog_id(
|
||||
connector: Neo4jConnector,
|
||||
dialog_id: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
dialog_id: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Dialogues.
|
||||
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
|
||||
- Returns up to 'limit' dialogues
|
||||
"""
|
||||
if not dialog_id:
|
||||
print(f"dialog_id不能为空")
|
||||
logger.warning(f"dialog_id不能为空")
|
||||
return {"dialogues": []}
|
||||
|
||||
dialogues = await connector.execute_query(
|
||||
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
|
||||
|
||||
|
||||
async def search_graph_by_chunk_id(
|
||||
connector: Neo4jConnector,
|
||||
chunk_id : str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
chunk_id: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
if not chunk_id:
|
||||
print(f"chunk_id不能为空")
|
||||
logger.warning(f"chunk_id不能为空")
|
||||
return {"chunks": []}
|
||||
chunks = await connector.execute_query(
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
|
||||
|
||||
|
||||
async def search_graph_community_expand(
|
||||
connector: Neo4jConnector,
|
||||
community_ids: List[str],
|
||||
end_user_id: str,
|
||||
limit: int = 10,
|
||||
connector: Neo4jConnector,
|
||||
community_ids: List[str],
|
||||
end_user_id: str,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
三期:社区展开检索 —— 主题 → 细节两级检索。
|
||||
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
|
||||
|
||||
|
||||
async def search_graph_by_created_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -767,16 +776,11 @@ async def search_graph_by_created_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -784,16 +788,16 @@ async def search_graph_by_created_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -807,16 +811,11 @@ async def search_graph_by_valid_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -824,16 +823,16 @@ async def search_graph_by_valid_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_g_created_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -847,16 +846,11 @@ async def search_graph_g_created_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_G_CREATED_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -864,16 +858,16 @@ async def search_graph_g_created_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_g_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_G_VALID_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -904,16 +892,16 @@ async def search_graph_g_valid_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_l_created_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -927,16 +915,11 @@ async def search_graph_l_created_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -944,16 +927,16 @@ async def search_graph_l_created_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_l_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -967,16 +950,11 @@ async def search_graph_l_valid_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -984,5 +962,89 @@ async def search_graph_l_valid_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_perceptual(
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Search Perceptual memory nodes using fulltext keyword search.
|
||||
|
||||
Matches against summary, topic, and domain fields via the perceptualFulltext index.
|
||||
|
||||
Args:
|
||||
connector: Neo4j connector
|
||||
q: Query text
|
||||
end_user_id: Optional user filter
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
|
||||
"""
|
||||
try:
|
||||
perceptuals = await connector.execute_query(
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"search_perceptual: keyword search failed: {e}")
|
||||
perceptuals = []
|
||||
|
||||
# Deduplicate
|
||||
from app.core.memory.src.search import _deduplicate_results
|
||||
perceptuals = _deduplicate_results(perceptuals)
|
||||
|
||||
return {"perceptuals": perceptuals}
|
||||
|
||||
|
||||
async def search_perceptual_by_embedding(
|
||||
connector: Neo4jConnector,
|
||||
embedder_client,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Search Perceptual memory nodes using embedding-based semantic search.
|
||||
|
||||
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
|
||||
|
||||
Args:
|
||||
connector: Neo4j connector
|
||||
embedder_client: Embedding client with async response() method
|
||||
query_text: Query text to embed
|
||||
end_user_id: Optional user filter
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
|
||||
"""
|
||||
embeddings = await embedder_client.response([query_text])
|
||||
if not embeddings or not embeddings[0]:
|
||||
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
|
||||
return {"perceptuals": []}
|
||||
|
||||
embedding = embeddings[0]
|
||||
|
||||
try:
|
||||
perceptuals = await connector.execute_query(
|
||||
PERCEPTUAL_EMBEDDING_SEARCH,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
|
||||
perceptuals = []
|
||||
|
||||
from app.core.memory.src.search import _deduplicate_results
|
||||
perceptuals = _deduplicate_results(perceptuals)
|
||||
|
||||
return {"perceptuals": perceptuals}
|
||||
|
||||
@@ -11,10 +11,28 @@ Classes:
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from neo4j import AsyncGraphDatabase, basic_auth
|
||||
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def _convert_neo4j_types(value: Any) -> Any:
|
||||
"""递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。"""
|
||||
if isinstance(value, Neo4jDateTime):
|
||||
return value.to_native().isoformat() if value.tzinfo else value.iso_format()
|
||||
if isinstance(value, Neo4jDate):
|
||||
return value.iso_format()
|
||||
if isinstance(value, Neo4jTime):
|
||||
return value.iso_format()
|
||||
if isinstance(value, Neo4jDuration):
|
||||
return str(value)
|
||||
if isinstance(value, dict):
|
||||
return {k: _convert_neo4j_types(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_convert_neo4j_types(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
class Neo4jConnector:
|
||||
"""Neo4j数据库连接器
|
||||
|
||||
@@ -59,11 +77,12 @@ class Neo4jConnector:
|
||||
"""
|
||||
await self.driver.close()
|
||||
|
||||
async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
|
||||
async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
|
||||
"""执行Cypher查询
|
||||
|
||||
Args:
|
||||
query: Cypher查询语句
|
||||
json_format: json格式化
|
||||
**kwargs: 查询参数,将作为参数传递给Cypher查询
|
||||
|
||||
Returns:
|
||||
@@ -78,7 +97,10 @@ class Neo4jConnector:
|
||||
**kwargs
|
||||
)
|
||||
records, summary, keys = result
|
||||
return [record.data() for record in records]
|
||||
if json_format:
|
||||
return [_convert_neo4j_types(record.data()) for record in records]
|
||||
else:
|
||||
return [record.data() for record in records]
|
||||
|
||||
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
|
||||
"""在写事务中执行操作
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Annotated
|
||||
from typing import Any, Annotated, Literal
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import desc, select
|
||||
from fastapi import Depends
|
||||
|
||||
from app.models.workflow_model import (
|
||||
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
|
||||
Returns:
|
||||
执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
stmt = select(WorkflowExecution).filter(
|
||||
WorkflowExecution.app_id == app_id
|
||||
).order_by(
|
||||
desc(WorkflowExecution.started_at)
|
||||
).limit(limit).offset(offset).all()
|
||||
).limit(limit).offset(offset)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
def get_by_conversation_id(
|
||||
self,
|
||||
conversation_id: uuid.UUID
|
||||
conversation_id: uuid.UUID,
|
||||
status: Literal["running", "completed", "failed"] = None,
|
||||
limit_count: int = 50
|
||||
) -> list[WorkflowExecution]:
|
||||
"""根据会话 ID 获取执行记录列表
|
||||
|
||||
Args:
|
||||
limit_count:
|
||||
conversation_id: 会话 ID
|
||||
status: 状态(可选)
|
||||
|
||||
Returns:
|
||||
执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
stmt = select(WorkflowExecution).filter(
|
||||
WorkflowExecution.conversation_id == conversation_id
|
||||
).order_by(
|
||||
desc(WorkflowExecution.started_at)
|
||||
).all()
|
||||
)
|
||||
if status:
|
||||
stmt = stmt.filter(WorkflowExecution.status == status)
|
||||
stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
def count_by_app_id(self, app_id: uuid.UUID) -> int:
|
||||
"""统计应用的执行次数
|
||||
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
|
||||
Returns:
|
||||
节点执行记录列表(按执行顺序排序)
|
||||
"""
|
||||
return self.db.query(WorkflowNodeExecution).filter(
|
||||
stmt = select(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.execution_id == execution_id
|
||||
).order_by(
|
||||
WorkflowNodeExecution.execution_order
|
||||
).all()
|
||||
)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
def get_by_node_id(
|
||||
self,
|
||||
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
|
||||
Returns:
|
||||
节点执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowNodeExecution).filter(
|
||||
stmt = select(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.execution_id == execution_id,
|
||||
WorkflowNodeExecution.node_id == node_id
|
||||
).order_by(
|
||||
WorkflowNodeExecution.retry_count
|
||||
).all()
|
||||
)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
@@ -241,6 +241,8 @@ class ModelParameters(BaseModel):
|
||||
presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="存在惩罚")
|
||||
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
|
||||
stop: Optional[List[str]] = Field(default=None, description="停止序列")
|
||||
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
|
||||
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
|
||||
|
||||
|
||||
class VariableDefinition(BaseModel):
|
||||
@@ -612,6 +614,7 @@ class AppChatRequest(BaseModel):
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
thinking: bool = Field(default=False, description="是否启用深度思考(需Agent配置支持)")
|
||||
files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
@@ -641,6 +644,7 @@ class CitationSource(BaseModel):
|
||||
class DraftRunResponse(BaseModel):
|
||||
"""试运行响应(非流式)"""
|
||||
message: str = Field(..., description="AI 回复消息")
|
||||
reasoning_content: Optional[str] = Field(default=None, description="深度思考内容")
|
||||
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
||||
usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
|
||||
elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
|
||||
@@ -648,6 +652,12 @@ class DraftRunResponse(BaseModel):
|
||||
citations: List[CitationSource] = Field(default_factory=list, description="引用来源")
|
||||
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
data = super().model_dump(**kwargs)
|
||||
if not data.get("reasoning_content"):
|
||||
data.pop("reasoning_content", None)
|
||||
return data
|
||||
|
||||
|
||||
class OpeningResponse(BaseModel):
|
||||
"""应用开场白响应"""
|
||||
|
||||
@@ -31,7 +31,8 @@ class ChatRequest(BaseModel):
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
web_search: bool = Field(default=False, description="是否启用网络搜索")
|
||||
memory: bool = Field(default=True, description="是否启用记忆功能")
|
||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||
thinking: bool = Field(default=False, description="是否启用深度思考(需Agent配置支持)")
|
||||
files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
# ---------- Output Schemas ----------
|
||||
|
||||
@@ -17,6 +17,7 @@ class Write_UserInput(BaseModel):
|
||||
end_user_id: str
|
||||
config_id: Optional[str] = None
|
||||
|
||||
|
||||
class AgentMemory_Long_Term(ABC):
|
||||
"""长期记忆配置常量"""
|
||||
STORAGE_NEO4J = "neo4j"
|
||||
@@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC):
|
||||
STRATEGY_CHUNK = "chunk"
|
||||
STRATEGY_TIME = "time"
|
||||
DEFAULT_SCOPE = 6
|
||||
TIME_SCOPE=5
|
||||
class AgentMemoryDataset(ABC):
|
||||
PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余']
|
||||
NAME='用户'
|
||||
TIME_SCOPE = 5
|
||||
|
||||
|
||||
class AgentMemoryDataset(ABC):
|
||||
PRONOUN = ['我', '本人', '在下', '自己', '咱', '鄙人', '吴', '余']
|
||||
NAME = '用户'
|
||||
|
||||
@@ -138,21 +138,13 @@ class CreateEndUserRequest(BaseModel):
|
||||
"""Request schema for creating an end user.
|
||||
|
||||
Attributes:
|
||||
workspace_id: Workspace ID (required)
|
||||
other_id: External user identifier (required)
|
||||
other_name: Display name for the end user
|
||||
memory_config_id: Optional memory config ID. If not provided, uses workspace default.
|
||||
"""
|
||||
workspace_id: str = Field(..., description="Workspace ID (required)")
|
||||
other_id: str = Field(..., description="External user identifier (required)")
|
||||
other_name: Optional[str] = Field("", description="Display name")
|
||||
|
||||
@field_validator("workspace_id")
|
||||
@classmethod
|
||||
def validate_workspace_id(cls, v: str) -> str:
|
||||
"""Validate that workspace_id is not empty."""
|
||||
if not v or not v.strip():
|
||||
raise ValueError("workspace_id is required and cannot be empty")
|
||||
return v.strip()
|
||||
memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.")
|
||||
|
||||
@field_validator("other_id")
|
||||
@classmethod
|
||||
@@ -171,11 +163,13 @@ class CreateEndUserResponse(BaseModel):
|
||||
other_id: External user identifier
|
||||
other_name: Display name
|
||||
workspace_id: Workspace the user belongs to
|
||||
memory_config_id: Connected memory config ID
|
||||
"""
|
||||
id: str = Field(..., description="End user UUID")
|
||||
other_id: str = Field(..., description="External user identifier")
|
||||
other_name: str = Field("", description="Display name")
|
||||
workspace_id: str = Field(..., description="Workspace ID")
|
||||
memory_config_id: Optional[str] = Field(None, description="Connected memory config ID")
|
||||
|
||||
|
||||
class MemoryConfigItem(BaseModel):
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||
from app.models import WorkflowConfig
|
||||
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.schemas import FileType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -43,18 +44,17 @@ class AppChatService:
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
user_id: Optional[str] = None,
|
||||
files: list[FileInput],
|
||||
user_id: str,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None
|
||||
workspace_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
|
||||
# 应用 features 配置
|
||||
features_config: dict = config.features or {}
|
||||
@@ -93,7 +93,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
|
||||
user_id)
|
||||
tools.extend(kb_tools)
|
||||
memory_flag = False
|
||||
if memory:
|
||||
@@ -116,7 +117,9 @@ class AppChatService:
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
|
||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||
capability=api_key_obj.capability or [],
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
@@ -168,11 +171,6 @@ class AppChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
@@ -209,7 +207,8 @@ class AppChatService:
|
||||
"model": api_key_obj.model_name,
|
||||
"usage": result.get("usage", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
|
||||
"audio_url": None,
|
||||
"citations": filtered_citations
|
||||
"citations": filtered_citations,
|
||||
"reasoning_content": result.get("reasoning_content")
|
||||
}
|
||||
if files:
|
||||
for f in files:
|
||||
@@ -229,6 +228,26 @@ class AppChatService:
|
||||
# 保存消息
|
||||
if audio_url:
|
||||
assistant_meta["audio_url"] = audio_url
|
||||
if memory_flag:
|
||||
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||
memory_config_id: str = connected_config.get("memory_config_id")
|
||||
file_list = []
|
||||
for file in files:
|
||||
file_dict = file.model_dump()
|
||||
file_dict["upload_file_id"] = str(file_dict["upload_file_id"]) if file_dict["upload_file_id"] else None
|
||||
file_list.append(file_dict)
|
||||
messages = [
|
||||
{"role": "user", "content": message, "files": file_list},
|
||||
{"role": "assistant", "content": result["content"]}
|
||||
]
|
||||
if memory_config_id:
|
||||
await write_long_term(
|
||||
storage_type,
|
||||
user_id,
|
||||
messages,
|
||||
user_rag_memory_id,
|
||||
memory_config_id
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -247,6 +266,7 @@ class AppChatService:
|
||||
"conversation_id": conversation_id,
|
||||
"message_id": str(message_id),
|
||||
"message": result["content"],
|
||||
"reasoning_content": result.get("reasoning_content"),
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
@@ -264,20 +284,19 @@ class AppChatService:
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
files: list[FileInput],
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None
|
||||
workspace_id: Optional[str] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
message_id = uuid.uuid4()
|
||||
|
||||
# 应用 features 配置
|
||||
@@ -319,7 +338,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
|
||||
config.knowledge_retrieval, user_id)
|
||||
tools.extend(kb_tools)
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
@@ -343,7 +363,10 @@ class AppChatService:
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
streaming=True
|
||||
streaming=True,
|
||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||
capability=api_key_obj.capability or [],
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
@@ -392,6 +415,7 @@ class AppChatService:
|
||||
|
||||
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||
full_content = ""
|
||||
full_reasoning = ""
|
||||
total_tokens = 0
|
||||
|
||||
text_queue: asyncio.Queue = asyncio.Queue()
|
||||
@@ -411,15 +435,13 @@ class AppChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
elif isinstance(chunk, dict) and chunk.get("type") == "reasoning":
|
||||
full_reasoning += chunk['content']
|
||||
yield f"event: reasoning\ndata: {json.dumps({'content': chunk['content']}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
full_content += chunk
|
||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
@@ -459,14 +481,15 @@ class AppChatService:
|
||||
|
||||
# 保存消息
|
||||
human_meta = {
|
||||
"files":[],
|
||||
"files": [],
|
||||
"history_files": {}
|
||||
}
|
||||
assistant_meta = {
|
||||
"model": api_key_obj.model_name,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens},
|
||||
"audio_url": None,
|
||||
"citations": filtered_citations
|
||||
"citations": filtered_citations,
|
||||
"reasoning_content": full_reasoning or None
|
||||
}
|
||||
|
||||
if files:
|
||||
@@ -484,6 +507,27 @@ class AppChatService:
|
||||
|
||||
if stream_audio_url:
|
||||
assistant_meta["audio_url"] = stream_audio_url
|
||||
|
||||
if memory_flag:
|
||||
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||
memory_config_id: str = connected_config.get("memory_config_id")
|
||||
file_list = []
|
||||
for file in files:
|
||||
file_dict = file.model_dump()
|
||||
file_dict["upload_file_id"] = str(file_dict["upload_file_id"]) if file_dict["upload_file_id"] else None
|
||||
file_list.append(file_dict)
|
||||
messages = [
|
||||
{"role": "user", "content": message, "files": file_list},
|
||||
{"role": "assistant", "content": full_content}
|
||||
]
|
||||
if memory_config_id:
|
||||
await write_long_term(
|
||||
storage_type,
|
||||
user_id,
|
||||
messages,
|
||||
user_rag_memory_id,
|
||||
memory_config_id
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -618,7 +662,6 @@ class AppChatService:
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
|
||||
# 3. 流式执行任务
|
||||
async for event in orchestrator.execute_stream(
|
||||
message=message,
|
||||
@@ -631,13 +674,13 @@ class AppChatService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
if "sub_usage" in event:
|
||||
# 拦截 sub_usage 事件,累加 token
|
||||
if "event: sub_usage" in event:
|
||||
if "data:" in event:
|
||||
try:
|
||||
data_line = event.split("data: ", 1)[1].strip()
|
||||
data = json.loads(data_line)
|
||||
if "total_tokens" in data:
|
||||
total_tokens += data["total_tokens"]
|
||||
total_tokens += data.get("total_tokens", 0)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
|
||||
@@ -13,7 +13,7 @@ import uuid
|
||||
from typing import Annotated, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy import and_, delete, func, or_, select
|
||||
from sqlalchemy import and_, delete, func, or_, select, update as sa_update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -401,7 +401,7 @@ class AppService:
|
||||
def _create_workflow_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
data: app_schema.WorkflowConfigCreate,
|
||||
data,
|
||||
now: datetime.datetime
|
||||
):
|
||||
workflow_cfg = WorkflowConfig(
|
||||
@@ -678,7 +678,9 @@ class AppService:
|
||||
self._create_multi_agent_config(app.id, data.multi_agent_config, now)
|
||||
|
||||
if app.type == "workflow" and data.workflow_config:
|
||||
self._create_workflow_config(app.id, data.workflow_config, now)
|
||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||
wf_data = WorkflowConfigCreate(**data.workflow_config) if isinstance(data.workflow_config, dict) else data.workflow_config
|
||||
self._create_workflow_config(app.id, wf_data, now)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(app)
|
||||
@@ -757,6 +759,17 @@ class AppService:
|
||||
|
||||
# 逻辑删除应用
|
||||
app.is_active = False
|
||||
|
||||
# 更新 app_shares 表中该应用的所有共享记录为失效状态,并更新 updated_at 时间
|
||||
stmt = sa_update(AppShare).where(
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.is_active.is_(True)
|
||||
).values(
|
||||
is_active=False,
|
||||
updated_at=datetime.datetime.now()
|
||||
)
|
||||
self.db.execute(stmt)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
@@ -1347,6 +1360,7 @@ class AppService:
|
||||
variables=cfg.get("variables", []),
|
||||
execution_config=cfg.get("execution_config", {}),
|
||||
triggers=cfg.get("triggers", []),
|
||||
features=cfg.get("features", {}),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
|
||||
@@ -534,6 +534,7 @@ class ConversationService:
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
capability = api_config.capability
|
||||
model_type = config.type
|
||||
|
||||
llm = RedBearLLM(
|
||||
@@ -542,7 +543,8 @@ class ConversationService:
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni
|
||||
is_omni=is_omni,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig, ModelType
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput, Citation
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.schemas import FileType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -459,7 +458,7 @@ class AgentRunService:
|
||||
|
||||
statement = opening["statement"]
|
||||
suggested_questions = opening["suggested_questions"]
|
||||
|
||||
|
||||
# 如果有变量,进行替换(仅支持 {{var_name}} 格式)
|
||||
if variables:
|
||||
for var_name, var_value in variables.items():
|
||||
@@ -596,6 +595,9 @@ class AgentRunService:
|
||||
max_tokens=effective_params.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
deep_thinking=effective_params.get("deep_thinking", False),
|
||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||
capability=api_key_config.get("capability", []),
|
||||
)
|
||||
|
||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||
@@ -661,11 +663,6 @@ class AgentRunService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=context,
|
||||
end_user_id=user_id,
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
@@ -695,7 +692,8 @@ class AgentRunService:
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
})
|
||||
}),
|
||||
"reasoning_content": result.get("reasoning_content")
|
||||
},
|
||||
files=files,
|
||||
processed_files=processed_files,
|
||||
@@ -707,6 +705,7 @@ class AgentRunService:
|
||||
|
||||
response = {
|
||||
"message": result["content"],
|
||||
"reasoning_content": result.get("reasoning_content"),
|
||||
"conversation_id": conversation_id,
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
@@ -844,7 +843,10 @@ class AgentRunService:
|
||||
max_tokens=effective_params.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
streaming=True
|
||||
streaming=True,
|
||||
deep_thinking=effective_params.get("deep_thinking", False),
|
||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||
capability=api_key_config.get("capability", []),
|
||||
)
|
||||
|
||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||
@@ -904,6 +906,7 @@ class AgentRunService:
|
||||
|
||||
# 9. 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||
full_content = ""
|
||||
full_reasoning = ""
|
||||
total_tokens = 0
|
||||
|
||||
# 启动流式 TTS(文本边输出边合成)
|
||||
@@ -918,15 +921,13 @@ class AgentRunService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=context,
|
||||
end_user_id=user_id,
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
elif isinstance(chunk, dict) and chunk.get("type") == "reasoning":
|
||||
full_reasoning += chunk["content"]
|
||||
yield self._format_sse_event("reasoning", {"content": chunk["content"]})
|
||||
else:
|
||||
full_content += chunk
|
||||
yield self._format_sse_event("message", {"content": chunk})
|
||||
@@ -955,7 +956,8 @@ class AgentRunService:
|
||||
app_id=agent_config.app_id,
|
||||
user_id=user_id,
|
||||
meta_data={
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens},
|
||||
"reasoning_content": full_reasoning or None
|
||||
},
|
||||
files=files,
|
||||
processed_files=processed_files,
|
||||
@@ -1676,7 +1678,7 @@ class AgentRunService:
|
||||
"""从 text_queue 取文本按句子切分后喂给 synthesizer"""
|
||||
import re
|
||||
buf = ""
|
||||
sentence_end = re.compile(r'[\u3002\uff01\uff1f\.!?\n]')
|
||||
sentence_end = re.compile(r'[\u3002\uff01\uff1f.!?\n]')
|
||||
while True:
|
||||
chunk = await text_queue.get()
|
||||
if chunk is None:
|
||||
@@ -1905,6 +1907,7 @@ class AgentRunService:
|
||||
"conversation_id": result['conversation_id'],
|
||||
"parameters_used": model_info["parameters"],
|
||||
"message": result.get("message"),
|
||||
"reasoning_content": result.get("reasoning_content"),
|
||||
"usage": usage,
|
||||
"elapsed_time": elapsed,
|
||||
"tokens_per_second": (
|
||||
@@ -2023,7 +2026,7 @@ class AgentRunService:
|
||||
# 需要从 ModelApiKey 获取实际的模型名称,或者在 ModelConfig 中添加 model 字段
|
||||
return None
|
||||
|
||||
def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]) -> AgentConfig:
|
||||
def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]) -> tuple[AgentConfig, Any]:
|
||||
"""创建一个带有覆盖参数的 agent_config(浅拷贝,只修改 model_parameters)
|
||||
|
||||
Args:
|
||||
@@ -2121,6 +2124,7 @@ class AgentRunService:
|
||||
|
||||
start_time = time.time()
|
||||
full_content = ""
|
||||
full_reasoning = ""
|
||||
returned_conversation_id = model_conversation_id
|
||||
audio_url = None
|
||||
audio_status = None
|
||||
@@ -2179,6 +2183,18 @@ class AgentRunService:
|
||||
"content": chunk
|
||||
}))
|
||||
|
||||
# 转发深度思考事件(带模型标识)
|
||||
if event_type == "reasoning" and event_data:
|
||||
reasoning_chunk = event_data.get("content", "")
|
||||
full_reasoning += reasoning_chunk
|
||||
await event_queue.put(self._format_sse_event("model_reasoning", {
|
||||
"model_index": idx,
|
||||
"model_config_id": model_config_id,
|
||||
"label": model_label,
|
||||
"conversation_id": returned_conversation_id,
|
||||
"content": event_data.get("content", "")
|
||||
}))
|
||||
|
||||
# 从 end 事件中提取 features 输出字段
|
||||
if event_type == "end" and event_data:
|
||||
audio_url = event_data.get("audio_url")
|
||||
@@ -2210,6 +2226,7 @@ class AgentRunService:
|
||||
"conversation_id": returned_conversation_id,
|
||||
"parameters_used": model_info["parameters"],
|
||||
"message": full_content,
|
||||
"reasoning_content": full_reasoning or None,
|
||||
"elapsed_time": elapsed,
|
||||
"audio_url": audio_url,
|
||||
"audio_status": audio_status,
|
||||
@@ -2362,6 +2379,7 @@ class AgentRunService:
|
||||
"label": r["label"],
|
||||
"conversation_id": r.get("conversation_id"),
|
||||
"message": r.get("message"),
|
||||
"reasoning_content": r.get("reasoning_content"),
|
||||
"elapsed_time": r.get("elapsed_time", 0),
|
||||
"audio_url": r.get("audio_url"),
|
||||
"audio_status": r.get("audio_status"),
|
||||
|
||||
@@ -415,6 +415,7 @@ class LLMRouter:
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
support_thinking="thinking" in (api_key_config.capability or []),
|
||||
temperature=0.3,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
@@ -393,6 +393,7 @@ class MasterAgentRouter:
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
support_thinking="thinking" in (api_key_config.capability or []),
|
||||
extra_params = extra_params
|
||||
)
|
||||
|
||||
@@ -403,6 +404,17 @@ class MasterAgentRouter:
|
||||
response = await llm.ainvoke(prompt)
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取 token 消耗
|
||||
self._last_routing_tokens = 0
|
||||
if hasattr(response, 'usage_metadata') and response.usage_metadata:
|
||||
um = response.usage_metadata
|
||||
self._last_routing_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
|
||||
elif hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {})
|
||||
if isinstance(token_usage, dict):
|
||||
self._last_routing_tokens = token_usage.get("total_tokens", 0)
|
||||
logger.info(f"Master Agent 路由 token 消耗: {self._last_routing_tokens}")
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
return response.content
|
||||
|
||||
@@ -37,6 +37,7 @@ from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -49,10 +50,6 @@ from app.services.memory_konwledges_server import (
|
||||
)
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
@@ -68,24 +65,22 @@ class MemoryAgentService:
|
||||
if str(messages) == 'success':
|
||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||
# 记录成功的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration, details={"message_length": len(message)})
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration, details={"message_length": len(message)})
|
||||
return context
|
||||
else:
|
||||
logger.warning(f"Write operation failed for group {end_user_id}")
|
||||
|
||||
# 记录失败的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=f"写入失败: {messages[:100]}"
|
||||
)
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=f"写入失败: {messages[:100]}"
|
||||
)
|
||||
|
||||
raise ValueError(f"写入失败: {messages}")
|
||||
|
||||
@@ -338,10 +333,9 @@ class MemoryAgentService:
|
||||
logger.error(error_msg)
|
||||
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -401,10 +395,10 @@ class MemoryAgentService:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
async def read_memory(
|
||||
@@ -468,12 +462,6 @@ class MemoryAgentService:
|
||||
|
||||
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
|
||||
|
||||
# 导入审计日志记录器
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
# Use a separate database session to avoid transaction failures
|
||||
@@ -492,16 +480,15 @@ class MemoryAgentService:
|
||||
logger.error(error_msg)
|
||||
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -515,10 +502,13 @@ class MemoryAgentService:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||
"end_user_id": end_user_id
|
||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||
"memory_config": memory_config}
|
||||
initial_state = {
|
||||
"messages": [HumanMessage(content=message)],
|
||||
"search_switch": search_switch,
|
||||
"end_user_id": end_user_id
|
||||
, "storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"memory_config": memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
@@ -530,7 +520,7 @@ class MemoryAgentService:
|
||||
for node_name, node_data in update_event.items():
|
||||
# if 'save_neo4j' == node_name:
|
||||
# massages = node_data
|
||||
print(f"处理节点: {node_name}")
|
||||
logger.info(f"处理节点: {node_name}")
|
||||
|
||||
# 处理不同Summary节点的返回结构
|
||||
if 'Summary' in node_name:
|
||||
@@ -557,6 +547,11 @@ class MemoryAgentService:
|
||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||
_intermediate_outputs.extend(retrieve_node)
|
||||
|
||||
# Perceptual_Retrieve 节点
|
||||
perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None)
|
||||
if perceptual_node and perceptual_node != [] and perceptual_node != {}:
|
||||
_intermediate_outputs.append(perceptual_node)
|
||||
|
||||
# Verify 节点
|
||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
@@ -633,15 +628,15 @@ class MemoryAgentService:
|
||||
total_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
return {
|
||||
"answer": summary,
|
||||
@@ -651,16 +646,16 @@ class MemoryAgentService:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Read operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||
|
||||
@@ -280,6 +280,53 @@ class MemoryAPIService:
|
||||
code=BizCode.MEMORY_READ_FAILED
|
||||
)
|
||||
|
||||
def create_end_user(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
other_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create or retrieve an end user for the workspace.
|
||||
|
||||
Uses get_or_create semantics: if an end user with the same other_id
|
||||
already exists in the workspace, returns the existing one.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace ID from API key authorization
|
||||
other_id: External user identifier
|
||||
|
||||
Returns:
|
||||
Dict with id, other_id, other_name, and workspace_id
|
||||
|
||||
Raises:
|
||||
BusinessException: If creation fails
|
||||
"""
|
||||
logger.info(f"Creating end user - other_id: {other_id}, workspace_id: {workspace_id}")
|
||||
|
||||
try:
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
end_user_repo = EndUserRepository(self.db)
|
||||
end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=None,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
return {
|
||||
"id": str(end_user.id),
|
||||
"other_id": end_user.other_id or "",
|
||||
"other_name": end_user.other_name or "",
|
||||
"workspace_id": str(end_user.workspace_id),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create end user for workspace {workspace_id}: {e}")
|
||||
raise BusinessException(
|
||||
message=f"Failed to create end user: {str(e)}",
|
||||
code=BizCode.INTERNAL_ERROR
|
||||
)
|
||||
|
||||
def list_memory_configs(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import desc, nullslast, or_, and_, cast, String
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.end_user_model import EndUser, EndUser as EndUserModel
|
||||
from app.models.memory_increment_model import MemoryIncrement
|
||||
|
||||
from app.repositories import (
|
||||
@@ -49,44 +50,40 @@ def get_current_workspace_type(
|
||||
|
||||
|
||||
def get_workspace_end_users(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> List[EndUser]:
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
||||
|
||||
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
||||
"""
|
||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
# 查询应用(ORM)
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
|
||||
|
||||
if not apps_orm:
|
||||
business_logger.info("工作空间下没有应用")
|
||||
return []
|
||||
|
||||
|
||||
# 提取所有 app_id
|
||||
# app_ids = [app.id for app in apps_orm]
|
||||
|
||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||
from app.models.end_user_model import EndUser as EndUserModel
|
||||
from sqlalchemy import desc, nullslast
|
||||
end_users_orm = db.query(EndUserModel).filter(
|
||||
EndUserModel.workspace_id == workspace_id
|
||||
).order_by(
|
||||
nullslast(desc(EndUserModel.created_at)),
|
||||
desc(EndUserModel.id)
|
||||
).all()
|
||||
|
||||
|
||||
# 转换为 Pydantic 模型(只在需要时转换)
|
||||
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
|
||||
|
||||
|
||||
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return end_users
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -94,6 +91,85 @@ def get_workspace_end_users(
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_end_users_paginated(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
keyword: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
|
||||
|
||||
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
||||
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID
|
||||
current_user: 当前用户
|
||||
page: 页码(从1开始)
|
||||
pagesize: 每页数量
|
||||
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
||||
|
||||
Returns:
|
||||
dict: 包含 items(宿主列表)和 total(总记录数)的字典
|
||||
"""
|
||||
business_logger.info(f"获取工作空间宿主列表(分页): workspace_id={workspace_id}, keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 构建基础查询
|
||||
base_query = db.query(EndUserModel).filter(
|
||||
EndUserModel.workspace_id == workspace_id
|
||||
)
|
||||
|
||||
# 构建搜索条件(过滤空字符串和None)
|
||||
keyword = keyword.strip() if keyword else None
|
||||
|
||||
if keyword:
|
||||
keyword_pattern = f"%{keyword}%"
|
||||
# other_name 匹配始终生效;id 匹配仅对 other_name 为空的记录生效
|
||||
base_query = base_query.filter(
|
||||
or_(
|
||||
EndUserModel.other_name.ilike(keyword_pattern),
|
||||
and_(
|
||||
or_(
|
||||
EndUserModel.other_name.is_(None),
|
||||
EndUserModel.other_name == "",
|
||||
),
|
||||
cast(EndUserModel.id, String).ilike(keyword_pattern),
|
||||
),
|
||||
)
|
||||
)
|
||||
business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_name;other_name 为空时匹配 id)")
|
||||
|
||||
# 获取总记录数
|
||||
total = base_query.count()
|
||||
|
||||
if total == 0:
|
||||
business_logger.info("工作空间下没有宿主")
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
# 分页查询
|
||||
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||
end_users_orm = base_query.order_by(
|
||||
nullslast(desc(EndUserModel.created_at)),
|
||||
desc(EndUserModel.id)
|
||||
).offset((page - 1) * pagesize).limit(pagesize).all()
|
||||
|
||||
# 转换为 Pydantic 模型
|
||||
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
|
||||
|
||||
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
||||
return {"items": end_users, "total": total}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_memory_increment(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
@@ -277,15 +353,13 @@ async def get_workspace_total_memory_count(
|
||||
"details": []
|
||||
}
|
||||
|
||||
# 2. 对每个 host_id 调用 search_all 获取 total
|
||||
# 2. 使用 search_all_batch 批量查询所有宿主的记忆数量
|
||||
from app.services import memory_storage_service
|
||||
|
||||
total_count = 0
|
||||
details = []
|
||||
|
||||
# 如果提供了 end_user_id,只查询该用户
|
||||
if end_user_id:
|
||||
search_result = await memory_storage_service.search_all(end_user_id=end_user_id)
|
||||
batch_result = await memory_storage_service.search_all_batch([end_user_id])
|
||||
count = batch_result.get(end_user_id, 0)
|
||||
# 查询用户名称
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
repo = EndUserRepository(db)
|
||||
@@ -293,42 +367,31 @@ async def get_workspace_total_memory_count(
|
||||
user_name = end_user.other_name if end_user else None
|
||||
|
||||
return {
|
||||
"total_memory_count": search_result.get("total", 0),
|
||||
"total_memory_count": count,
|
||||
"host_count": 1,
|
||||
"details": [{
|
||||
"end_user_id": end_user_id,
|
||||
"count": search_result.get("total", 0),
|
||||
"count": count,
|
||||
"name": user_name
|
||||
}]
|
||||
}
|
||||
|
||||
for host in hosts:
|
||||
try:
|
||||
end_user_id_str = str(host.id)
|
||||
|
||||
search_result = await memory_storage_service.search_all(
|
||||
end_user_id=end_user_id_str
|
||||
)
|
||||
|
||||
host_total = search_result.get("total", 0)
|
||||
total_count += host_total
|
||||
|
||||
details.append({
|
||||
"end_user_id": end_user_id_str,
|
||||
"count": host_total,
|
||||
"name": host.other_name # 使用 other_name 字段
|
||||
})
|
||||
|
||||
business_logger.debug(f"EndUser {end_user_id_str} ({host.other_name}) 记忆数: {host_total}")
|
||||
|
||||
except Exception as e:
|
||||
business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}")
|
||||
# 失败的 host 记为 0
|
||||
details.append({
|
||||
"end_user_id": str(host.id),
|
||||
"count": 0,
|
||||
"name": host.other_name # 使用 other_name 字段
|
||||
})
|
||||
# 批量查询所有宿主记忆数量(一次 Neo4j 查询)
|
||||
end_user_ids = [str(host.id) for host in hosts]
|
||||
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
||||
|
||||
# 构建 host name 映射
|
||||
host_name_map = {str(host.id): host.other_name for host in hosts}
|
||||
|
||||
total_count = sum(batch_result.values())
|
||||
details = [
|
||||
{
|
||||
"end_user_id": uid,
|
||||
"count": batch_result.get(uid, 0),
|
||||
"name": host_name_map.get(uid)
|
||||
}
|
||||
for uid in end_user_ids
|
||||
]
|
||||
|
||||
result = {
|
||||
"total_memory_count": total_count,
|
||||
@@ -443,6 +506,180 @@ def get_rag_user_kb_total_chunk(
|
||||
business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_dashboard_yesterday_changes(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
storage_type: str,
|
||||
today_data: dict
|
||||
) -> dict:
|
||||
"""
|
||||
计算各指标相比昨天的变化百分比。
|
||||
|
||||
- total_app_change / total_knowledge_change:只看活跃记录,
|
||||
百分比 = (截止今日活跃总量 - 截止昨日活跃总量) / 截止昨日活跃总量
|
||||
- total_memory_change / total_api_call_change:
|
||||
百分比 = (今日总量 - 昨日总量) / 昨日总量
|
||||
|
||||
昨日总量为 0 时返回 None。返回值为浮点数,例如 0.5 表示增长 50%。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID
|
||||
storage_type: 存储类型 'neo4j' | 'rag'
|
||||
today_data: 当前数据,包含 total_memory, total_app, total_knowledge, total_api_call
|
||||
|
||||
Returns:
|
||||
{
|
||||
"total_memory_change": float | None,
|
||||
"total_app_change": float | None,
|
||||
"total_knowledge_change": float | None,
|
||||
"total_api_call_change": float | None
|
||||
}
|
||||
"""
|
||||
from datetime import datetime
|
||||
from sqlalchemy import func
|
||||
from app.models.api_key_model import ApiKey, ApiKeyLog
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.models.app_model import App
|
||||
from app.models.appshare_model import AppShare
|
||||
|
||||
business_logger.info(f"计算昨日对比百分比: workspace_id={workspace_id}, storage_type={storage_type}")
|
||||
|
||||
now_local = datetime.now()
|
||||
today_start = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
changes = {
|
||||
"total_memory_change": None,
|
||||
"total_app_change": None,
|
||||
"total_knowledge_change": None,
|
||||
"total_api_call_change": None,
|
||||
}
|
||||
|
||||
def _calc_percentage(today_val, yesterday_val):
|
||||
"""计算百分比,昨日为0时返回None"""
|
||||
if yesterday_val is None or yesterday_val == 0:
|
||||
return None
|
||||
return round((today_val - yesterday_val) / yesterday_val, 4)
|
||||
|
||||
# --- total_api_call_change: (截止今日累计总数 - 截止昨日累计总数) / 截止昨日累计总数 ---
|
||||
try:
|
||||
api_key_ids = [
|
||||
row[0] for row in db.query(ApiKey.id).filter(
|
||||
ApiKey.workspace_id == workspace_id
|
||||
).all()
|
||||
]
|
||||
if api_key_ids:
|
||||
# 截止今日的累计调用总数
|
||||
total_api_until_now = db.query(func.count(ApiKeyLog.id)).filter(
|
||||
ApiKeyLog.api_key_id.in_(api_key_ids),
|
||||
ApiKeyLog.created_at < now_local
|
||||
).scalar() or 0
|
||||
# 截止昨日的累计调用总数(today_start 即昨日结束)
|
||||
total_api_until_yesterday = db.query(func.count(ApiKeyLog.id)).filter(
|
||||
ApiKeyLog.api_key_id.in_(api_key_ids),
|
||||
ApiKeyLog.created_at < today_start
|
||||
).scalar() or 0
|
||||
changes["total_api_call_change"] = _calc_percentage(total_api_until_now, total_api_until_yesterday)
|
||||
else:
|
||||
changes["total_api_call_change"] = None
|
||||
except Exception as e:
|
||||
business_logger.warning(f"计算API调用昨日对比失败: {str(e)}")
|
||||
|
||||
# --- total_knowledge_change: 只看活跃(status=1)且为顶层知识库(parent_id=workspace_id),百分比 = (今日活跃总量 - 昨日活跃总量) / 昨日活跃总量 ---
|
||||
try:
|
||||
# 截止今日的活跃知识库总量(当前 status=1,parent_id=workspace_id)
|
||||
today_knowledge = db.query(func.count(Knowledge.id)).filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1,
|
||||
Knowledge.parent_id == Knowledge.workspace_id
|
||||
).scalar() or 0
|
||||
# 截止昨日的活跃知识库总量(昨日之前创建的、当前仍 status=1,parent_id=workspace_id)
|
||||
yesterday_knowledge = db.query(func.count(Knowledge.id)).filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1,
|
||||
Knowledge.parent_id == Knowledge.workspace_id,
|
||||
Knowledge.created_at < today_start
|
||||
).scalar() or 0
|
||||
|
||||
changes["total_knowledge_change"] = _calc_percentage(today_knowledge, yesterday_knowledge)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"计算知识库昨日对比失败: {str(e)}")
|
||||
|
||||
# --- total_app_change: 只看活跃(is_active=True),百分比 = (今日活跃总量 - 昨日活跃总量) / 昨日活跃总量 ---
|
||||
try:
|
||||
# === 自有app ===
|
||||
today_own_apps = db.query(func.count(App.id)).filter(
|
||||
App.workspace_id == workspace_id,
|
||||
App.is_active == True
|
||||
).scalar() or 0
|
||||
yesterday_own_apps = db.query(func.count(App.id)).filter(
|
||||
App.workspace_id == workspace_id,
|
||||
App.is_active == True,
|
||||
App.created_at < today_start
|
||||
).scalar() or 0
|
||||
|
||||
# === 被分享app ===
|
||||
today_shared_apps = db.query(func.count(AppShare.id)).filter(
|
||||
AppShare.target_workspace_id == workspace_id,
|
||||
AppShare.is_active == True
|
||||
).scalar() or 0
|
||||
yesterday_shared_apps = db.query(func.count(AppShare.id)).filter(
|
||||
AppShare.target_workspace_id == workspace_id,
|
||||
AppShare.is_active == True,
|
||||
AppShare.created_at < today_start
|
||||
).scalar() or 0
|
||||
|
||||
today_total_app = today_own_apps + today_shared_apps
|
||||
yesterday_total_app = yesterday_own_apps + yesterday_shared_apps
|
||||
|
||||
changes["total_app_change"] = _calc_percentage(today_total_app, yesterday_total_app)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"计算应用数量昨日对比失败: {str(e)}")
|
||||
|
||||
# --- total_memory_change: (今日总量 - 昨日总量) / 昨日总量 ---
|
||||
try:
|
||||
today_memory = today_data.get("total_memory")
|
||||
if today_memory is None:
|
||||
changes["total_memory_change"] = None
|
||||
elif storage_type == "neo4j":
|
||||
last_record = db.query(MemoryIncrement).filter(
|
||||
MemoryIncrement.workspace_id == workspace_id,
|
||||
MemoryIncrement.created_at < today_start
|
||||
).order_by(desc(MemoryIncrement.created_at)).first()
|
||||
if last_record is None or last_record.total_num == 0:
|
||||
changes["total_memory_change"] = None
|
||||
else:
|
||||
changes["total_memory_change"] = _calc_percentage(today_memory, last_record.total_num)
|
||||
elif storage_type == "rag":
|
||||
from app.models.document_model import Document
|
||||
from app.models.end_user_model import EndUser as _EndUser
|
||||
from app.models.app_model import App as _App
|
||||
|
||||
end_user_ids = [
|
||||
str(eid) for (eid,) in db.query(_EndUser.id)
|
||||
.join(_App, _EndUser.app_id == _App.id)
|
||||
.filter(_App.workspace_id == workspace_id)
|
||||
.all()
|
||||
]
|
||||
if not end_user_ids:
|
||||
changes["total_memory_change"] = None
|
||||
else:
|
||||
file_names = [f"{uid}.txt" for uid in end_user_ids]
|
||||
yesterday_chunk = int(db.query(func.sum(Document.chunk_num)).filter(
|
||||
Document.file_name.in_(file_names),
|
||||
Document.created_at < today_start
|
||||
).scalar() or 0)
|
||||
if yesterday_chunk == 0:
|
||||
changes["total_memory_change"] = None
|
||||
else:
|
||||
changes["total_memory_change"] = _calc_percentage(today_memory, yesterday_chunk)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"计算记忆总量昨日对比失败: {str(e)}")
|
||||
|
||||
business_logger.info(f"昨日对比百分比计算完成: {changes}")
|
||||
return changes
|
||||
|
||||
|
||||
def get_current_user_total_chunk(
|
||||
end_user_id: str,
|
||||
db: Session,
|
||||
@@ -638,7 +875,24 @@ def get_rag_content(
|
||||
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 4. 返回结果
|
||||
# 4. 将所有 page_content 拼接后按角色分割为对话列表
|
||||
merged_text = "\n".join(page_contents)
|
||||
conversations = []
|
||||
if merged_text.strip():
|
||||
import re
|
||||
# 在任意位置匹配 "user:" 或 "assistant:",不限于行首
|
||||
parts = re.split(r'(user|assistant):', merged_text)
|
||||
# parts 结构: ['', 'user', ' content...', 'assistant', ' content...', ...]
|
||||
i = 1
|
||||
while i < len(parts) - 1:
|
||||
role = parts[i].strip()
|
||||
content = parts[i + 1].strip()
|
||||
# 将 content 中的 \n 还原为真实换行
|
||||
content = content.replace("\\n", "\n")
|
||||
if role in ("user", "assistant") and content:
|
||||
conversations.append({"role": role, "content": content})
|
||||
i += 2
|
||||
|
||||
result = {
|
||||
"page": {
|
||||
"page": page,
|
||||
@@ -646,10 +900,10 @@ def get_rag_content(
|
||||
"total": global_total,
|
||||
"hasnext": offset_end < global_total,
|
||||
},
|
||||
"items": page_contents
|
||||
"items": conversations
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)} 条")
|
||||
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(conversations)} 条对话")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
@@ -788,4 +1042,60 @@ async def generate_rag_profile(
|
||||
"tags_count": len(tags),
|
||||
"personas_count": len(personas),
|
||||
"insight_generated": bool(insight_sections.get("memory_insight")),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_dashboard_common_stats(db: Session, workspace_id) -> dict:
|
||||
"""
|
||||
获取 dashboard 中 neo4j/rag 分支共享的统计数据:
|
||||
total_app、total_knowledge、total_api_call
|
||||
|
||||
Returns:
|
||||
dict: {"total_app": int, "total_knowledge": int, "total_api_call": int}
|
||||
"""
|
||||
result = {"total_app": 0, "total_knowledge": 0, "total_api_call": 0}
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量(包含自有 + 被分享给本工作空间的app)
|
||||
try:
|
||||
from app.services import app_service as _app_svc
|
||||
_, total_app = _app_svc.AppService(db).list_apps(
|
||||
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||
)
|
||||
result["total_app"] = total_app
|
||||
except Exception as e:
|
||||
business_logger.warning(f"获取应用数量失败: {e}")
|
||||
|
||||
# total_knowledge: 统计顶层知识库(parent_id = workspace_id)
|
||||
try:
|
||||
from sqlalchemy import func as _func
|
||||
from app.models.knowledge_model import Knowledge as _Knowledge
|
||||
total_knowledge = db.query(_func.count(_Knowledge.id)).filter(
|
||||
_Knowledge.workspace_id == workspace_id,
|
||||
_Knowledge.status == 1,
|
||||
_Knowledge.parent_id == _Knowledge.workspace_id
|
||||
).scalar() or 0
|
||||
result["total_knowledge"] = total_knowledge
|
||||
except Exception as e:
|
||||
business_logger.warning(f"获取知识库数量失败: {e}")
|
||||
|
||||
# total_api_call: 截止当前的历史累计调用总数
|
||||
try:
|
||||
from sqlalchemy import func as _api_func
|
||||
from app.models.api_key_model import ApiKey as _ApiKey, ApiKeyLog as _ApiKeyLog
|
||||
|
||||
_api_key_ids = [
|
||||
row[0] for row in db.query(_ApiKey.id).filter(
|
||||
_ApiKey.workspace_id == workspace_id
|
||||
).all()
|
||||
]
|
||||
if _api_key_ids:
|
||||
total_api_calls = db.query(_api_func.count(_ApiKeyLog.id)).filter(
|
||||
_ApiKeyLog.api_key_id.in_(_api_key_ids)
|
||||
).scalar() or 0
|
||||
else:
|
||||
total_api_calls = 0
|
||||
result["total_api_call"] = total_api_calls
|
||||
except Exception as e:
|
||||
business_logger.warning(f"获取API调用统计失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
@@ -232,7 +232,8 @@ class MemoryPerceptualService:
|
||||
provider=model_config.provider,
|
||||
api_key=model_config.api_key,
|
||||
base_url=model_config.api_base,
|
||||
is_omni=model_config.is_omni
|
||||
is_omni=model_config.is_omni,
|
||||
support_thinking="thinking" in (model_config.capability or []),
|
||||
)
|
||||
)
|
||||
return llm, model_config
|
||||
@@ -243,28 +244,9 @@ class MemoryPerceptualService:
|
||||
memory_config: MemoryConfig,
|
||||
file: FileInput
|
||||
):
|
||||
memories = self.repository.get_by_url(file.url)
|
||||
if memories:
|
||||
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||
memory_cache = memories[0]
|
||||
memory = self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||
file_path=memory_cache.file_path,
|
||||
file_name=memory_cache.file_name,
|
||||
file_ext=memory_cache.file_ext,
|
||||
summary=memory_cache.summary,
|
||||
meta_data=memory_cache.meta_data
|
||||
)
|
||||
self.db.commit()
|
||||
return memory
|
||||
else:
|
||||
for memory in memories:
|
||||
if memory.end_user_id == uuid.UUID(end_user_id):
|
||||
return memory
|
||||
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||
if model_config is None or llm is None:
|
||||
return None
|
||||
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
@@ -286,15 +268,20 @@ class MemoryPerceptualService:
|
||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
except FileNotFoundError as e:
|
||||
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
|
||||
return None
|
||||
messages = [
|
||||
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
|
||||
{"role": RoleType.USER.value, "content": [
|
||||
{"type": "text", "text": "Summarize the following file"}, file_message
|
||||
]}
|
||||
]
|
||||
result = await llm.ainvoke(messages)
|
||||
try:
|
||||
result = await llm.ainvoke(messages)
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
|
||||
return None
|
||||
content = result.content
|
||||
final_output = ""
|
||||
if isinstance(content, list):
|
||||
|
||||
@@ -613,37 +613,6 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
return data
|
||||
|
||||
|
||||
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
MemoryConfigRepository.SEARCH_FOR_ALL,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
|
||||
# 检查结果是否为空或长度不足
|
||||
if not result or len(result) < 4:
|
||||
data = {
|
||||
"total": 0,
|
||||
"counts": {
|
||||
"dialogue": 0,
|
||||
"chunk": 0,
|
||||
"statement": 0,
|
||||
"entity": 0,
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
data = {
|
||||
"total": result[-1]["Count"],
|
||||
"counts": {
|
||||
"dialogue": result[0]["Count"],
|
||||
"chunk": result[1]["Count"],
|
||||
"statement": result[2]["Count"],
|
||||
"entity": result[3]["Count"],
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""统一知识库类型分布接口。
|
||||
|
||||
@@ -695,6 +664,37 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
|
||||
return result
|
||||
|
||||
|
||||
async def search_all_batch(end_user_ids: List[str]) -> Dict[str, int]:
|
||||
"""批量查询多个用户的记忆数量(简化版本,只返回total)
|
||||
|
||||
Args:
|
||||
end_user_ids: 用户ID列表
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: 以user_id为key的记忆数量字典
|
||||
格式: {"user_id": total_count}
|
||||
"""
|
||||
if not end_user_ids:
|
||||
return {}
|
||||
|
||||
result = await _neo4j_connector.execute_query(
|
||||
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
|
||||
end_user_ids=end_user_ids,
|
||||
)
|
||||
|
||||
# 转换结果为字典格式,字典格式在查询中无需遍历结果集,直接返回
|
||||
data = {}
|
||||
for row in result:
|
||||
data[row["user_id"]] = row["total"]
|
||||
|
||||
# 为没有数据的用户填充默认值,转换字典格式还为无数据填充默认值
|
||||
for user_id in end_user_ids:
|
||||
if user_id not in data:
|
||||
data[user_id] = 0
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_hot_memory_tags(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
|
||||
@@ -45,12 +45,20 @@ class ModelParameterMerger:
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"n": 1,
|
||||
"stop": None
|
||||
"stop": None,
|
||||
"deep_thinking": False,
|
||||
"thinking_budget_tokens": None
|
||||
}
|
||||
|
||||
# 合并参数:默认值 -> 模型配置 -> Agent 配置
|
||||
merged = default_params.copy()
|
||||
|
||||
# Pydantic 对象转为 dict
|
||||
if model_config_params and hasattr(model_config_params, 'model_dump'):
|
||||
model_config_params = model_config_params.model_dump()
|
||||
if agent_config_params and hasattr(agent_config_params, 'model_dump'):
|
||||
agent_config_params = agent_config_params.model_dump()
|
||||
|
||||
# 应用模型配置参数
|
||||
if model_config_params:
|
||||
for key in default_params:
|
||||
|
||||
@@ -69,7 +69,8 @@ class ModelConfigService:
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None,
|
||||
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
||||
if not model:
|
||||
@@ -77,7 +78,8 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
|
||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
|
||||
ModelConfig]:
|
||||
"""按名称模糊匹配获取模型配置列表"""
|
||||
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
|
||||
|
||||
@@ -91,7 +93,8 @@ class ModelConfigService:
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
is_omni: bool = False,
|
||||
capability: Optional[list] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""验证模型配置是否有效
|
||||
|
||||
@@ -122,6 +125,7 @@ class ModelConfigService:
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
@@ -158,13 +162,13 @@ class ModelConfigService:
|
||||
# 统一使用 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||
embedding = RedBearEmbeddings(model_config)
|
||||
test_texts = [test_message, "测试文本"]
|
||||
|
||||
|
||||
# 火山引擎使用 embed_batch,其他使用 embed_documents
|
||||
if provider.lower() == "volcano":
|
||||
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
|
||||
else:
|
||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
@@ -200,11 +204,11 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "image":
|
||||
# 图片生成模型验证
|
||||
from app.core.models.generation import RedBearImageGenerator
|
||||
|
||||
|
||||
generator = RedBearImageGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda",
|
||||
@@ -212,7 +216,7 @@ class ModelConfigService:
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"成功生成图片,结果: {result}")
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "图片生成模型配置验证成功",
|
||||
@@ -224,21 +228,21 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "video":
|
||||
# 视频生成模型验证
|
||||
from app.core.models.generation import RedBearVideoGenerator
|
||||
|
||||
|
||||
generator = RedBearVideoGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda playing in bamboo forest",
|
||||
duration=5
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 视频生成是异步任务,返回任务ID
|
||||
task_id = result.get("task_id") if isinstance(result, dict) else None
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "视频生成模型配置验证成功",
|
||||
@@ -265,7 +269,6 @@ class ModelConfigService:
|
||||
# 提取详细的错误信息
|
||||
error_message = str(e)
|
||||
error_type = type(e).__name__
|
||||
print("=========error_message:",error_message.lower())
|
||||
# 特殊处理常见的错误类型
|
||||
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
|
||||
# 区域/国家限制(适用于所有提供商)
|
||||
@@ -319,7 +322,8 @@ class ModelConfigService:
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_data.type,
|
||||
test_message="Hello",
|
||||
is_omni=model_data.is_omni
|
||||
is_omni=model_data.is_omni,
|
||||
capability=model_data.capability
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
@@ -354,14 +358,16 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
|
||||
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""更新模型配置"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||
@@ -370,25 +376,27 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
|
||||
tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建组合模型"""
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
# 检查 API Key 关联的模型配置类型
|
||||
for model_config in api_key.model_configs:
|
||||
# chat 和 llm 类型可以兼容
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = model_data.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
@@ -399,7 +407,7 @@ class ModelConfigService:
|
||||
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
|
||||
# BizCode.INVALID_PARAMETER
|
||||
# )
|
||||
|
||||
|
||||
# 创建组合模型
|
||||
model_config_data = {
|
||||
"tenant_id": tenant_id,
|
||||
@@ -418,49 +426,51 @@ class ModelConfigService:
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush()
|
||||
|
||||
|
||||
# 关联 API Keys
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
model.api_keys.append(api_key)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
|
||||
tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""更新组合模型"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
if not existing_model.is_composite:
|
||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
||||
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
for model_config in api_key.model_configs:
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = existing_model.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
|
||||
# 更新基本信息
|
||||
existing_model.name = model_data.name
|
||||
# existing_model.type = model_data.type
|
||||
@@ -471,14 +481,14 @@ class ModelConfigService:
|
||||
existing_model.is_public = model_data.is_public
|
||||
if "load_balance_strategy" in model_data.model_fields_set:
|
||||
existing_model.load_balance_strategy = model_data.load_balance_strategy
|
||||
|
||||
|
||||
# 更新 API Keys 关联
|
||||
existing_model.api_keys.clear()
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
existing_model.api_keys.append(api_key)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_model)
|
||||
return existing_model
|
||||
@@ -532,7 +542,7 @@ class ModelApiKeyService:
|
||||
"""根据provider为多个ModelConfig创建API Key"""
|
||||
created_keys = []
|
||||
failed_models = [] # 记录验证失败的模型
|
||||
|
||||
|
||||
for model_config_id in data.model_config_ids:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
@@ -540,10 +550,10 @@ class ModelApiKeyService:
|
||||
|
||||
data.is_omni = model_config.is_omni
|
||||
data.capability = model_config.capability
|
||||
|
||||
|
||||
# 从ModelBase获取model_name
|
||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||
|
||||
|
||||
# 检查是否存在API Key(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
@@ -553,7 +563,7 @@ class ModelApiKeyService:
|
||||
ModelApiKey.model_name == model_name,
|
||||
ModelConfig.tenant_id == model_config.tenant_id
|
||||
).first()
|
||||
|
||||
|
||||
if existing_key:
|
||||
# 如果已存在,重新激活并更新
|
||||
if existing_key.is_active:
|
||||
@@ -566,14 +576,14 @@ class ModelApiKeyService:
|
||||
existing_key.model_name = model_name
|
||||
existing_key.capability = data.capability
|
||||
existing_key.is_omni = data.is_omni
|
||||
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
|
||||
created_keys.append(existing_key)
|
||||
continue
|
||||
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
@@ -583,13 +593,14 @@ class ModelApiKeyService:
|
||||
api_base=data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello",
|
||||
is_omni=data.is_omni
|
||||
is_omni=data.is_omni,
|
||||
capability=model_config.capability
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
# 记录验证失败的模型,但不抛出异常
|
||||
failed_models.append(model_name)
|
||||
continue
|
||||
|
||||
|
||||
# 创建API Key
|
||||
api_key_data = ModelApiKeyCreate(
|
||||
model_config_ids=[model_config_id],
|
||||
@@ -606,12 +617,12 @@ class ModelApiKeyService:
|
||||
)
|
||||
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
|
||||
created_keys.append(api_key_obj)
|
||||
|
||||
|
||||
if created_keys:
|
||||
db.commit()
|
||||
for key in created_keys:
|
||||
db.refresh(key)
|
||||
|
||||
|
||||
return created_keys, failed_models
|
||||
|
||||
@staticmethod
|
||||
@@ -626,7 +637,7 @@ class ModelApiKeyService:
|
||||
api_key_data.is_omni = model_config.is_omni
|
||||
if api_key_data.capability is None:
|
||||
api_key_data.capability = model_config.capability
|
||||
|
||||
|
||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
@@ -650,15 +661,15 @@ class ModelApiKeyService:
|
||||
existing_key.model_name = api_key_data.model_name
|
||||
existing_key.capability = api_key_data.capability
|
||||
existing_key.is_omni = api_key_data.is_omni
|
||||
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_key)
|
||||
return existing_key
|
||||
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
@@ -668,7 +679,8 @@ class ModelApiKeyService:
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello",
|
||||
is_omni=api_key_data.is_omni
|
||||
is_omni=api_key_data.is_omni,
|
||||
capability=model_config.capability
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
@@ -691,7 +703,7 @@ class ModelApiKeyService:
|
||||
# 获取关联的模型配置以获取模型类型
|
||||
if existing_api_key.model_configs:
|
||||
model_config = existing_api_key.model_configs[0]
|
||||
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name or existing_api_key.model_name,
|
||||
@@ -700,7 +712,8 @@ class ModelApiKeyService:
|
||||
api_base=api_key_data.api_base or existing_api_key.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello",
|
||||
is_omni=model_config.is_omni
|
||||
is_omni=model_config.is_omni,
|
||||
capability=model_config.capability
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
@@ -729,15 +742,15 @@ class ModelApiKeyService:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
return None
|
||||
|
||||
|
||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||
if not api_keys:
|
||||
return None
|
||||
|
||||
|
||||
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
|
||||
|
||||
# 否则返回第一个
|
||||
return api_keys[0]
|
||||
|
||||
@@ -760,20 +773,19 @@ class ModelApiKeyService:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
|
||||
class ModelBaseService:
|
||||
"""基础模型服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
|
||||
models = ModelBaseRepository.get_list(db, query)
|
||||
|
||||
|
||||
provider_groups = {}
|
||||
for m in models:
|
||||
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
|
||||
if tenant_id:
|
||||
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
|
||||
|
||||
|
||||
provider = m.provider
|
||||
if provider not in provider_groups:
|
||||
provider_groups[provider] = {
|
||||
@@ -781,7 +793,7 @@ class ModelBaseService:
|
||||
"models": []
|
||||
}
|
||||
provider_groups[provider]["models"].append(model_dict)
|
||||
|
||||
|
||||
return list(provider_groups.values())
|
||||
|
||||
@staticmethod
|
||||
@@ -823,10 +835,10 @@ class ModelBaseService:
|
||||
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
|
||||
if not model_base:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
|
||||
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
|
||||
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
model_config_data = {
|
||||
"model_id": model_base_id,
|
||||
"tenant_id": tenant_id,
|
||||
|
||||
@@ -287,6 +287,11 @@ class MultiAgentOrchestrator:
|
||||
sub_conversation_id = None
|
||||
total_tokens = 0
|
||||
|
||||
# 累加 Master Agent 路由决策消耗的 token
|
||||
total_tokens += task_analysis.get("routing_tokens", 0)
|
||||
# 累加 Master Agent 整合消耗的 token
|
||||
total_tokens += getattr(self, '_last_merge_tokens', 0)
|
||||
|
||||
if isinstance(results, dict):
|
||||
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
|
||||
# 提取 token 信息
|
||||
@@ -358,12 +363,16 @@ class MultiAgentOrchestrator:
|
||||
variables=variables
|
||||
)
|
||||
|
||||
# 获取路由决策消耗的 token
|
||||
routing_tokens = getattr(self.router, '_last_routing_tokens', 0)
|
||||
|
||||
logger.info(
|
||||
"Master Agent 分析完成",
|
||||
extra={
|
||||
"selected_agent": routing_decision.get("selected_agent_id"),
|
||||
"confidence": routing_decision.get("confidence"),
|
||||
"strategy": routing_decision.get("strategy")
|
||||
"strategy": routing_decision.get("strategy"),
|
||||
"routing_tokens": routing_tokens
|
||||
}
|
||||
)
|
||||
|
||||
@@ -372,7 +381,8 @@ class MultiAgentOrchestrator:
|
||||
"variables": variables or {},
|
||||
"sub_agents": self.config.sub_agents,
|
||||
"initial_context": variables or {},
|
||||
"routing_decision": routing_decision
|
||||
"routing_decision": routing_decision,
|
||||
"routing_tokens": routing_tokens
|
||||
}
|
||||
|
||||
async def _execute_sequential(
|
||||
@@ -1032,6 +1042,11 @@ class MultiAgentOrchestrator:
|
||||
|
||||
# 5. 流式执行子 Agent
|
||||
sub_conversation_id = None
|
||||
# Master Agent 路由决策消耗的 token,通过 sub_usage 事件发送给上层
|
||||
routing_tokens = task_analysis.get("routing_tokens", 0)
|
||||
if routing_tokens > 0:
|
||||
yield self._format_sse_event("sub_usage", {"total_tokens": routing_tokens})
|
||||
|
||||
async for event in self._execute_sub_agent_stream(
|
||||
agent_data["config"],
|
||||
message,
|
||||
@@ -1054,6 +1069,7 @@ class MultiAgentOrchestrator:
|
||||
except:
|
||||
pass
|
||||
|
||||
# 直接透传所有事件(包括 sub_usage),累加统一由上层处理
|
||||
yield event
|
||||
|
||||
# 6. 如果有会话 ID,发送一个包含它的事件
|
||||
@@ -2600,6 +2616,7 @@ class MultiAgentOrchestrator:
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
support_thinking="thinking" in (api_key_config.capability or []),
|
||||
temperature=0.7, # 整合任务使用中等温度
|
||||
max_tokens=2000
|
||||
)
|
||||
@@ -2612,6 +2629,17 @@ class MultiAgentOrchestrator:
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取整合消耗的 token
|
||||
merge_tokens = 0
|
||||
if hasattr(response, 'usage_metadata') and response.usage_metadata:
|
||||
um = response.usage_metadata
|
||||
merge_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
|
||||
elif hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {})
|
||||
if isinstance(token_usage, dict):
|
||||
merge_tokens = token_usage.get("total_tokens", 0)
|
||||
self._last_merge_tokens = merge_tokens
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
merged_response = response.content
|
||||
@@ -2621,7 +2649,8 @@ class MultiAgentOrchestrator:
|
||||
logger.info(
|
||||
"Master Agent 整合完成",
|
||||
extra={
|
||||
"merged_length": len(merged_response)
|
||||
"merged_length": len(merged_response),
|
||||
"merge_tokens": merge_tokens
|
||||
}
|
||||
)
|
||||
|
||||
@@ -2766,6 +2795,7 @@ class MultiAgentOrchestrator:
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
support_thinking="thinking" in (api_key_config.capability or []),
|
||||
temperature=0.7,
|
||||
max_tokens=2000,
|
||||
extra_params={"streaming": True} # 启用流式输出
|
||||
|
||||
@@ -441,13 +441,13 @@ class MultimodalService:
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
return True, {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
"text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
||||
text = await self._extract_document_text(file)
|
||||
text = await self.extract_document_text(file)
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
@@ -545,7 +545,7 @@ class MultimodalService:
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
return f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
async def _extract_document_text(self, file: FileInput) -> str:
|
||||
async def extract_document_text(self, file: FileInput) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
|
||||
|
||||
@@ -185,7 +185,8 @@ class PromptOptimizerService:
|
||||
provider=api_config.provider,
|
||||
api_key=api_config.api_key,
|
||||
base_url=api_config.api_base,
|
||||
is_omni=api_config.is_omni
|
||||
is_omni=api_config.is_omni,
|
||||
support_thinking="thinking" in (api_config.capability or []),
|
||||
), type=ModelType(model_config.type))
|
||||
try:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
"""基于分享链接的聊天服务"""
|
||||
import uuid
|
||||
import time
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models import MultiAgentConfig
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.repositories import knowledge_repository
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
from app.repositories import knowledge_repository
|
||||
import json
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -118,6 +116,7 @@ class SharedChatService:
|
||||
|
||||
return conversation
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def chat(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -136,10 +135,7 @@ class SharedChatService:
|
||||
config_id = actual_config_id
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id = None
|
||||
@@ -252,7 +248,9 @@ class SharedChatService:
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
|
||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||
capability=api_key_obj.capability or [],
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
@@ -273,11 +271,6 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
@@ -324,6 +317,7 @@ class SharedChatService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -341,8 +335,6 @@ class SharedChatService:
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
import json
|
||||
|
||||
start_time = time.time()
|
||||
@@ -460,7 +452,10 @@ class SharedChatService:
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
streaming=True
|
||||
streaming=True,
|
||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||
capability=api_key_obj.capability or [],
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
@@ -486,14 +481,11 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
elif isinstance(chunk, dict) and chunk.get("type") == "reasoning":
|
||||
yield f"event: reasoning\ndata: {json.dumps({'content': chunk['content']}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
@@ -585,6 +577,7 @@ class SharedChatService:
|
||||
|
||||
return conversations, total
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def multi_agent_chat(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -680,6 +673,7 @@ class SharedChatService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def multi_agent_chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user