Merge branch 'release/v0.2.10'

Ke Sun
2026-04-08 21:44:27 +08:00
250 changed files with 9559 additions and 3941 deletions


@@ -14,7 +14,6 @@ from . import (
document_controller,
emotion_config_controller,
emotion_controller,
end_user_controller,
file_controller,
file_storage_controller,
home_page_controller,
@@ -99,6 +98,5 @@ manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router)
manager_router.include_router(i18n_controller.router)
manager_router.include_router(end_user_controller.router)
__all__ = ["manager_router"]


@@ -292,10 +292,19 @@ def get_opening(
):
"""Return the opening statement text and preset questions for the frontend chat UI to display on initialization"""
workspace_id = current_user.current_workspace_id
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
features = cfg.features or {}
if hasattr(features, "model_dump"):
features = features.model_dump()
# Get features based on the app type
from app.models.app_model import App as AppModel
app = db.get(AppModel, app_id)
if app and app.type == "workflow":
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
features = cfg.features or {}
else:
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
features = cfg.features or {}
if hasattr(features, "model_dump"):
features = features.model_dump()
opening = features.get("opening_statement", {})
return success(data=app_schema.OpeningResponse(
enabled=opening.get("enabled", False),
@@ -1070,6 +1079,14 @@ async def update_workflow_config(
current_user: Annotated[User, Depends(get_current_user)]
):
workspace_id = current_user.current_workspace_id
if payload.variables:
from app.services.workflow_service import WorkflowService
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
[v.model_dump() for v in payload.variables]
)
# Patch default values back into VariableDefinition objects
for var_def, resolved_def in zip(payload.variables, resolved):
var_def.default = resolved_def.get("default", var_def.default)
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=WorkflowConfigSchema.model_validate(cfg))


@@ -53,22 +53,24 @@ async def login_for_access_token(
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
auth_logger.info(f"User authenticated successfully: {user.email} (ID: {user.id})")
if form_data.invite:
auth_service.bind_workspace_with_invite(db=db,
user=user,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id)
auth_service.bind_workspace_with_invite(
db=db,
user=user,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
)
except BusinessException as e:
# User does not exist but an invite code is present; try to register
if e.code == BizCode.USER_NOT_FOUND:
auth_logger.info(f"User not found; registering with invite code: {form_data.email}")
user = auth_service.register_user_with_invite(
db=db,
email=form_data.email,
username=form_data.username,
password=form_data.password,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
)
db=db,
email=form_data.email,
username=form_data.username,
password=form_data.password,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
)
elif e.code == BizCode.PASSWORD_ERROR:
# User exists but the password is wrong
auth_logger.warning(f"Failed to accept invite: password verification error: {form_data.email}")


@@ -314,8 +314,10 @@ async def parse_documents(
)
# 4. Check if the file exists
api_logger.debug(f"Constructed file path: {file_path}")
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
if not os.path.exists(file_path):
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"


@@ -1,48 +0,0 @@
"""End User management API - no authentication required"""
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
from app.schemas.memory_api_schema import (
CreateEndUserRequest,
CreateEndUserResponse,
)
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
router = APIRouter(prefix="/end_users", tags=["End Users"])
logger = get_business_logger()
@router.post("")
async def create_end_user(
data: CreateEndUserRequest,
db: Session = Depends(get_db),
):
"""
Create an end user.
Creates a new end user for the given workspace.
If an end user with the same other_id already exists in the workspace,
returns the existing one.
"""
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
end_user_repo = EndUserRepository(db)
end_user = end_user_repo.get_or_create_end_user(
app_id=None,
workspace_id=data.workspace_id,
other_id=data.other_id,
)
logger.info(f"End user ready: {end_user.id}")
result = {
"id": str(end_user.id),
"other_id": end_user.other_id or "",
"other_name": end_user.other_name or "",
"workspace_id": str(end_user.workspace_id),
}
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")


@@ -3,9 +3,10 @@ from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.response_utils import success
from app.db import get_db
from app.db import get_db, SessionLocal
from app.dependencies import get_current_user
from app.models.user_model import User
from app.repositories.home_page_repository import HomePageRepository
from app.schemas.response_schema import ApiResponse
from app.services.home_page_service import HomePageService
@@ -31,9 +32,32 @@ def get_workspace_list(
@router.get("/version", response_model=ApiResponse)
def get_system_version():
"""Get the system version number and description."""
current_version = settings.SYSTEM_VERSION
version_info = HomePageService.load_version_introduction(current_version)
"""Get the system version number and description."""
current_version = None
version_info = None
# 1. Prefer the latest published version from the database
try:
db = SessionLocal()
try:
current_version, version_info = HomePageRepository.get_latest_version_introduction(db)
finally:
db.close()
except Exception:
# DB unavailable; fall through to the fallbacks below
pass
# 2. Fallback: use the version number from environment settings
if not current_version:
current_version = settings.SYSTEM_VERSION
version_info = HomePageService.load_version_introduction(current_version)
# 3. If neither the database nor the JSON file has it, return basic info
if not version_info:
version_info = {
"introduction": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []},
"introduction_en": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []}
}
return success(
data={
"version": current_version,

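The three-tier fallback above (database, then settings plus the bundled JSON, then an empty skeleton) is the whole of the new version-lookup logic. A minimal sketch of the same chain, with db_lookup and file_lookup as hypothetical stand-ins for the repository and service calls:

from typing import Callable, Optional, Tuple

def resolve_version(
    db_lookup: Callable[[], Tuple[Optional[str], Optional[dict]]],
    file_lookup: Callable[[str], Optional[dict]],
    env_version: str,
) -> Tuple[str, dict]:
    version, info = None, None
    try:
        version, info = db_lookup()  # 1. latest published version from the DB
    except Exception:
        pass  # DB unavailable; fall through
    if not version:
        version = env_version        # 2. settings.SYSTEM_VERSION
        info = file_lookup(version)  # ...plus the bundled JSON introduction
    if not info:
        info = {"introduction": {}, "introduction_en": {}}  # 3. empty skeleton
    return version, info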

@@ -352,6 +352,7 @@ async def delete_knowledge(
# 2. Soft-delete knowledge base
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
db_knowledge.status = 2
db_knowledge.updated_at = datetime.datetime.now()
db.commit()
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
return success(msg="The knowledge base has been successfully deleted")


@@ -1,3 +1,5 @@
import asyncio
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
@@ -47,64 +49,64 @@ def get_workspace_total_end_users(
@router.get("/end_users", response_model=ApiResponse)
async def get_workspace_end_users(
workspace_id: Optional[uuid.UUID] = Query(None, description="Workspace ID (optional; defaults to the current user's workspace)"),
keyword: Optional[str] = Query(None, description="Search keyword (fuzzy-matches both other_name and id)"),
page: int = Query(1, ge=1, description="Page number, starting from 1"),
pagesize: int = Query(10, ge=1, description="Items per page"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Get the workspace's host list (high-performance optimized version v2)
Optimization strategy:
1. Batch-query end_users (one query instead of a loop)
2. Concurrently query every user's memory count (Neo4j)
3. RAG mode uses a batch query (a single SQL statement)
4. Return only the necessary fields to reduce data transfer
5. Add a short-lived cache to reduce repeated queries
6. Run the config query and the memory-count query concurrently
Response format:
{
"end_user": {"id": "uuid", "other_name": "name"},
"memory_num": {"total": count},
"memory_config": {"memory_config_id": "id", "memory_config_name": "name"}
}
Get the workspace's host list (paginated, with fuzzy search)
Returns the hosts under the workspace, with pagination and fuzzy search.
The keyword parameter fuzzy-matches both the other_name and id fields.
Args:
workspace_id: workspace ID (optional; defaults to the current user's workspace)
keyword: search keyword (optional; fuzzy-matches both other_name and id)
page: page number, starting from 1 (default 1)
pagesize: items per page (default 10)
db: database session
current_user: the current user
Returns:
ApiResponse: contains the host list and pagination info
"""
import asyncio
import json
from app.aioRedis import aio_redis_get, aio_redis_set
workspace_id = current_user.current_workspace_id
# Try to read from the cache (30-second TTL)
cache_key = f"end_users:workspace:{workspace_id}"
try:
cached_data = await aio_redis_get(cache_key)
if cached_data:
api_logger.info(f"Host list served from cache: workspace_id={workspace_id}")
return success(data=json.loads(cached_data), msg="Host list fetched successfully")
except Exception as e:
api_logger.warning(f"Redis cache read failed: {str(e)}")
# If workspace_id is not provided, use the current user's workspace
if workspace_id is None:
workspace_id = current_user.current_workspace_id
# Get the current workspace type
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
api_logger.info(f"User {current_user.username} requested the host list for workspace {workspace_id}")
# Fetch end_users (already optimized to a batch query)
end_users = memory_dashboard_service.get_workspace_end_users(
api_logger.info(f"User {current_user.username} requested the host list for workspace {workspace_id}, type: {current_workspace_type}")
# Fetch the paginated end_users
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
db=db,
workspace_id=workspace_id,
current_user=current_user
current_user=current_user,
page=page,
pagesize=pagesize,
keyword=keyword
)
end_users = end_users_result.get("items", [])
total = end_users_result.get("total", 0)
if not end_users:
api_logger.info("No hosts under this workspace")
# Cache the empty result to avoid repeated queries
try:
await aio_redis_set(cache_key, json.dumps([]), expire=30)
except Exception as e:
api_logger.warning(f"Redis cache write failed: {str(e)}")
return success(data=[], msg="Host list fetched successfully")
api_logger.info(f"No hosts under the workspace, or the current page is empty: total={total}, page={page}")
return success(data={
"items": [],
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"hasnext": (page * pagesize) < total
}
}, msg="Host list fetched successfully")
end_user_ids = [str(user.id) for user in end_users]
# Run two independent query tasks concurrently
async def get_memory_configs():
"""Fetch memory configs (runs the synchronous query in a thread pool)."""
@@ -116,7 +118,7 @@ async def get_workspace_end_users(
except Exception as e:
api_logger.error(f"Batch fetch of memory configs failed: {str(e)}")
return {}
async def get_memory_nums():
"""Fetch memory counts."""
if current_workspace_type == "rag":
@@ -130,26 +132,18 @@ async def get_workspace_end_users(
except Exception as e:
api_logger.error(f"Batch fetch of RAG chunk counts failed: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j":
# Neo4j mode: concurrent queries (with a concurrency limit)
# Use a semaphore to cap concurrency and avoid overwhelming Neo4j when there are many users
MAX_CONCURRENT_QUERIES = 10
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
async def get_neo4j_memory_num(end_user_id: str):
async with semaphore:
try:
return await memory_storage_service.search_all(end_user_id)
except Exception as e:
api_logger.error(f"Failed to fetch the Neo4j memory count for user {end_user_id}: {str(e)}")
return {"total": 0}
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
# Neo4j mode: batch query (simplified version that returns only total)
try:
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
return {uid: {"total": count} for uid, count in batch_result.items()}
except Exception as e:
api_logger.error(f"Batch fetch of Neo4j memory counts failed: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
return {uid: {"total": 0} for uid in end_user_ids}
# Trigger on-demand initialization: asynchronously generate data for users with no record in implicit_emotions_storage
try:
from app.celery_app import celery_app as _celery_app
@@ -170,13 +164,13 @@ async def get_workspace_end_users(
get_memory_configs(),
get_memory_nums()
)
# Build the result (optimization: list comprehension)
result = []
# Build the result items list
items = []
for end_user in end_users:
user_id = str(end_user.id)
config_info = memory_configs_map.get(user_id, {})
result.append({
items.append({
'end_user': {
'id': user_id,
'other_name': end_user.other_name
@@ -187,12 +181,6 @@ async def get_workspace_end_users(
"memory_config_name": config_info.get("memory_config_name")
}
})
# Write to the cache (expires in 30 seconds)
try:
await aio_redis_set(cache_key, json.dumps(result), expire=30)
except Exception as e:
api_logger.warning(f"Redis cache write failed: {str(e)}")
# Trigger the community clustering backfill task (async; does not block the response)
try:
@@ -202,7 +190,18 @@ async def get_workspace_end_users(
except Exception as e:
api_logger.warning(f"Failed to trigger the community clustering backfill task (does not affect the main flow): {str(e)}")
api_logger.info(f"Fetched {len(end_users)} host records")
# Build the paginated response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"hasnext": (page * pagesize) < total
}
}
api_logger.info(f"Fetched {len(end_users)} host records, total {total}")
return success(data=result, msg="Host list fetched successfully")
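Both the empty and non-empty branches above return the same pagination envelope, and hasnext is simply whether the current window ends before total. A small helper sketch (the function name is illustrative, not from the diff) showing the same computation:

def page_envelope(items: list, page: int, pagesize: int, total: int) -> dict:
    # hasnext is True while rows beyond the current window remain,
    # e.g. page=2, pagesize=10, total=25 -> 20 < 25 -> True
    return {
        "items": items,
        "page": {
            "page": page,
            "pagesize": pagesize,
            "total": total,
            "hasnext": (page * pagesize) < total,
        },
    }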
@@ -592,7 +591,7 @@ async def dashboard_data(
"total_api_call": None
}
# 1. Get the total memory count (total_memory)
# 1. Get the total memory count (total_memory) -- neo4j-specific logic: query the neo4j storage nodes
try:
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
db=db,
@@ -601,49 +600,33 @@ async def dashboard_data(
end_user_id=end_user_id
)
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
# total_app: count all apps under the current workspace
# including the workspace's own apps + apps shared into this workspace
from app.services import app_service as _app_svc
_, total_app = _app_svc.AppService(db).list_apps(
workspace_id=workspace_id, include_shared=True, pagesize=1
)
neo4j_data["total_app"] = total_app
api_logger.info(f"Fetched the total memory count: {neo4j_data['total_memory']}, app count: {neo4j_data['total_app']}")
api_logger.info(f"Fetched the total memory count: {neo4j_data['total_memory']}")
except Exception as e:
api_logger.warning(f"Failed to fetch the total memory count: {str(e)}")
# 2. Get knowledge-base type stats (total_knowledge)
try:
from app.services.memory_agent_service import MemoryAgentService
memory_agent_service = MemoryAgentService()
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,
only_active=True,
current_workspace_id=workspace_id,
db=db
)
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
api_logger.info(f"Fetched knowledge-base type stats, total: {neo4j_data['total_knowledge']}")
except Exception as e:
api_logger.warning(f"Failed to fetch knowledge-base type stats: {str(e)}")
# 2. Get shared stats (total_app, total_knowledge, total_api_call)
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
neo4j_data.update(common_stats)
api_logger.info(f"Fetched shared stats: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
# 3. Get API call stats (total_api_call)
# Compute the day-over-day change vs. yesterday
try:
# Use AppStatisticsService to get real API call statistics
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
db=db,
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
storage_type=storage_type,
today_data=neo4j_data
)
# Sum up the total call count
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
neo4j_data["total_api_call"] = total_api_calls
api_logger.info(f"Fetched API call stats: {neo4j_data['total_api_call']}")
neo4j_data.update(changes)
except Exception as e:
api_logger.error(f"Failed to fetch API call stats: {str(e)}")
neo4j_data["total_api_call"] = 0
api_logger.warning(f"Failed to compute the neo4j day-over-day change: {str(e)}")
neo4j_data.update({
"total_memory_change": None,
"total_app_change": None,
"total_knowledge_change": None,
"total_api_call_change": None,
})
result["neo4j_data"] = neo4j_data
api_logger.info("Fetched neo4j_data successfully")
@@ -656,44 +639,37 @@ async def dashboard_data(
"total_api_call": None
}
# Fetch RAG-related data
# 1. Get the total memory count (total_memory) -- rag-specific logic: query chunk_num from the document table
try:
# total_memory: count only chunks from user knowledge bases (permission_id='Memory')
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
# total_app: count all apps under the current workspace
# including the workspace's own apps + apps shared into this workspace
from app.services import app_service as _app_svc
_, total_app = _app_svc.AppService(db).list_apps(
workspace_id=workspace_id, include_shared=True, pagesize=1
)
rag_data["total_app"] = total_app
# total_knowledge: use total_kb (the total knowledge-base count)
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
rag_data["total_knowledge"] = total_kb
# total_api_call: use AppStatisticsService to get real API call statistics
try:
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
# Sum up the total call count
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
rag_data["total_api_call"] = total_api_calls
api_logger.info(f"Fetched RAG-mode API call stats: {rag_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"Failed to fetch RAG-mode API call stats; using the default: {str(e)}")
rag_data["total_api_call"] = 0
api_logger.info(f"Fetched RAG-related data: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
api_logger.info(f"Fetched the RAG total memory count: {total_chunk}")
except Exception as e:
api_logger.warning(f"Failed to fetch RAG-related data: {str(e)}")
api_logger.warning(f"Failed to fetch the RAG total memory count: {str(e)}")
# 2. Get shared stats (total_app, total_knowledge, total_api_call)
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
rag_data.update(common_stats)
api_logger.info(f"Fetched shared stats: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
# Compute the day-over-day change vs. yesterday
try:
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
db=db,
workspace_id=workspace_id,
storage_type=storage_type,
today_data=rag_data
)
rag_data.update(changes)
except Exception as e:
api_logger.warning(f"Failed to compute the RAG day-over-day change: {str(e)}")
rag_data.update({
"total_memory_change": None,
"total_app_change": None,
"total_knowledge_change": None,
"total_api_call_change": None,
})
result["rag_data"] = rag_data
api_logger.info("Fetched rag_data successfully")


@@ -26,7 +26,7 @@ from app.services.memory_storage_service import (
analytics_hot_memory_tags,
analytics_recent_activity_stats,
kb_type_distribution,
search_all,
search_all_batch,
search_chunk,
search_detials,
search_dialogue,
@@ -409,7 +409,10 @@ async def search_all_num(
) -> dict:
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
try:
result = await search_all(end_user_id)
if not end_user_id:
return success(data={"total": 0}, msg="Query successful")
batch_result = await search_all_batch([end_user_id])
result = {"total": batch_result.get(end_user_id, 0)}
return success(data=result, msg="Query successful")
except Exception as e:
api_logger.error(f"Search all failed: {str(e)}")
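The single-user count endpoint is now a thin wrapper over the batch API: it wraps the one id in a list and reads the count back out of the returned map. A hedged sketch of the pattern, assuming (as the hunk suggests) that search_all_batch returns an {end_user_id: count} mapping:

async def count_for_one(end_user_id: str) -> dict:
    # Degenerate batch call: a one-element list in, a map keyed by id out.
    if not end_user_id:
        return {"total": 0}
    batch = await search_all_batch([end_user_id])  # e.g. {"u-123": 42}
    return {"total": batch.get(end_user_id, 0)}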


@@ -163,6 +163,7 @@ def _get_ontology_service(
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
support_thinking="thinking" in (api_key_config.capability or []),
max_retries=3,
timeout=60.0
)


@@ -453,31 +453,10 @@ async def chat(
# Streaming response
agent_config = agent_config_4_app_release(release)
if payload.stream:
# async def event_generator():
# async for event in service.chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # use the already-created conversation ID
# user_id=str(new_end_user.id), # convert to a string
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
agent_config.model_parameters["deep_thinking"] = False
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
if payload.stream:
async def event_generator():
async for event in app_chat_service.agnet_chat_stream(
message=payload.message,
@@ -503,20 +482,6 @@ async def chat(
"X-Accel-Buffering": "no"
}
)
# Non-streaming response
# result = await service.chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # use the already-created conversation ID
# user_id=str(new_end_user.id), # convert to a string
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
result = await app_chat_service.agnet_chat(
message=payload.message,
conversation_id=conversation.id, # use the already-created conversation ID
@@ -575,48 +540,6 @@ async def chat(
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
# Multi-agent streaming response
# if payload.stream:
# async def event_generator():
# async for event in service.multi_agent_chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # use the already-created conversation ID
# user_id=str(new_end_user.id), # convert to a string
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
# # Multi-agent non-streaming response
# result = await service.multi_agent_chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # use the already-created conversation ID
# user_id=str(new_end_user.id), # convert to a string
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW:
config = workflow_config_4_app_release(release)
if not config.id:
@@ -714,7 +637,8 @@ async def config_query(
"app_type": release.app.type,
"variables": release.config.get("variables"),
"memory": release.config.get("memory", {}).get("enabled"),
"features": release.config.get("features")
"features": release.config.get("features"),
"model_parameters": release.config.get("model_parameters")
}
elif release.app.type == AppType.MULTI_AGENT:
content = {


@@ -4,7 +4,7 @@
Authentication: API Key
"""
from fastapi import APIRouter
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller
# Create the V1 API router
service_router = APIRouter()
@@ -16,5 +16,6 @@ service_router.include_router(rag_api_document_controller.router)
service_router.include_router(rag_api_file_controller.router)
service_router.include_router(rag_api_chunk_controller.router)
service_router.include_router(memory_api_controller.router)
service_router.include_router(end_user_api_controller.router)
__all__ = ["service_router"]


@@ -144,6 +144,11 @@ async def chat(
# print(app.current_release.default_model_config_id)
agent_config = agent_config_4_app_release(app.current_release)
# print(agent_config.default_model_config_id)
# Thinking switch: enabled only when the agent config has deep_thinking AND the request sets thinking=True
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
agent_config.model_parameters["deep_thinking"] = False
# Streaming response
if payload.stream:
async def event_generator():
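The gate above mirrors the one added in the share-chat controller: deep thinking stays on only when the agent config enables it and the request also asks for it. A sketch of the effective truth table:

def effective_deep_thinking(config_flag: bool, request_thinking: bool) -> bool:
    # Only on/on stays on; every other combination is forced off.
    return config_flag and request_thinking

assert effective_deep_thinking(True, True) is True
assert effective_deep_thinking(True, False) is False
assert effective_deep_thinking(False, True) is False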


@@ -0,0 +1,92 @@
"""End User service API - authenticated via API key"""
import uuid
from fastapi import APIRouter, Body, Depends, Request
from sqlalchemy.orm import Session
from app.core.api_key_auth import require_api_key
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
from app.services.memory_config_service import MemoryConfigService
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
logger = get_business_logger()
@router.post("/create")
@require_api_key(scopes=["memory"])
async def create_end_user(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(..., description="Request body"),
):
"""
Create or retrieve an end user for the workspace.
Creates a new end user and connects it to a memory configuration.
If an end user with the same other_id already exists in the workspace,
returns the existing one.
Optionally accepts a memory_config_id to connect the end user to a specific
memory configuration. If not provided, falls back to the workspace default config.
"""
body = await request.json()
payload = CreateEndUserRequest(**body)
workspace_id = api_key_auth.workspace_id
logger.info("Create end user request - other_id: %s, workspace_id: %s", payload.other_id, workspace_id)
# Resolve memory_config_id: explicit > workspace default
memory_config_id = None
config_service = MemoryConfigService(db)
if payload.memory_config_id:
try:
memory_config_id = uuid.UUID(payload.memory_config_id)
except ValueError:
raise BusinessException(
f"Invalid memory_config_id format: {payload.memory_config_id}",
BizCode.INVALID_PARAMETER
)
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
if not config:
raise BusinessException(
f"Memory config not found: {payload.memory_config_id}",
BizCode.MEMORY_CONFIG_NOT_FOUND
)
memory_config_id = config.config_id
else:
default_config = config_service.get_workspace_default_config(workspace_id)
if default_config:
memory_config_id = default_config.config_id
logger.info(f"Using workspace default memory config: {memory_config_id}")
else:
logger.warning(f"No default memory config found for workspace: {workspace_id}")
end_user_repo = EndUserRepository(db)
end_user = end_user_repo.get_or_create_end_user_with_config(
app_id=api_key_auth.resource_id,
workspace_id=workspace_id,
other_id=payload.other_id,
memory_config_id=memory_config_id,
)
logger.info(f"End user ready: {end_user.id}")
result = {
"id": str(end_user.id),
"other_id": end_user.other_id or "",
"other_name": end_user.other_name or "",
"workspace_id": str(end_user.workspace_id),
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
}
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
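A hedged usage sketch for the new endpoint. The path and body fields follow the router and schema above; the base URL, auth header shape, and key value are placeholder assumptions:

import requests

resp = requests.post(
    "https://example.com/v1/end_user/create",  # assumed service prefix
    headers={"Authorization": "Bearer <api-key-with-memory-scope>"},  # assumed header shape
    json={
        "other_id": "external-user-42",  # idempotent per workspace
        # "memory_config_id": "...",     # optional; omit to use the workspace default
    },
)
print(resp.json())  # expected shape: {"data": {"id": ..., "memory_config_id": ...}, ...}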


@@ -6,6 +6,8 @@ from app.core.response_utils import success
from app.db import get_db
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_api_schema import (
CreateEndUserRequest,
CreateEndUserResponse,
ListConfigsResponse,
MemoryReadRequest,
MemoryReadResponse,
@@ -113,3 +115,31 @@ async def list_memory_configs(
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
@router.post("/end_users")
@require_api_key(scopes=["memory"])
async def create_end_user(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Create an end user.
Creates a new end user for the authorized workspace.
If an end user with the same other_id already exists, returns the existing one.
"""
body = await request.json()
payload = CreateEndUserRequest(**body)
logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)
result = memory_api_service.create_end_user(
workspace_id=api_key_auth.workspace_id,
other_id=payload.other_id,
)
logger.info(f"End user ready: {result['id']}")
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")


@@ -11,17 +11,14 @@ LangChain Agent wrapper
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType, ModelProvider
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langgraph.errors import GraphRecursionError
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
logger = get_business_logger()
@@ -41,7 +38,10 @@ class LangChainAgent:
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # maximum iterations (None means auto-computed)
max_tool_consecutive_calls: int = 3 # max consecutive calls of a single tool
max_tool_consecutive_calls: int = 3, # max consecutive calls of a single tool
deep_thinking: bool = False, # whether to enable deep-thinking mode
thinking_budget_tokens: Optional[int] = None, # token budget for deep thinking
capability: Optional[List[str]] = None # list of model capabilities, used to verify deep-thinking support
):
"""Initialize the LangChain agent
@@ -64,6 +64,7 @@ class LangChainAgent:
self.streaming = streaming
self.is_omni = is_omni
self.max_tool_consecutive_calls = max_tool_consecutive_calls
self.deep_thinking = deep_thinking and ("thinking" in (capability or []))
# Tool-call counter: tracks each tool's consecutive call count
self.tool_call_counter: Dict[str, int] = {}
@@ -86,6 +87,13 @@ class LangChainAgent:
f"auto_calculated={max_iterations is None}"
)
# Verify actual deep-thinking support against the capability list
actual_deep_thinking = self.deep_thinking
if deep_thinking and not actual_deep_thinking:
logger.warning(
f"Model {model_name} does not support deep thinking ('thinking' missing from capability); deep_thinking has been disabled automatically"
)
# Create RedBearLLM (multi-provider support)
model_config = RedBearModelConfig(
model_name=model_name,
@@ -93,10 +101,13 @@ class LangChainAgent:
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
deep_thinking=actual_deep_thinking,
thinking_budget_tokens=thinking_budget_tokens if actual_deep_thinking else None,
support_thinking="thinking" in (capability or []),
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
"streaming": streaming # control streaming via this parameter
"streaming": streaming
}
)
@@ -226,10 +237,9 @@ class LangChainAgent:
Returns:
List[BaseMessage]: the message list
"""
messages = []
messages: list = [SystemMessage(content=self.system_prompt)]
# Add the system prompt
messages.append(SystemMessage(content=self.system_prompt))
# Append history messages
if history:
@@ -254,6 +264,33 @@ class LangChainAgent:
return messages
@staticmethod
def _extract_tokens_from_message(msg) -> int:
"""Extract total_tokens from an AIMessage or similar object (handles multiple provider formats).
Supported formats:
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
- response_metadata.usage.total_tokens (some providers)
- usage_metadata.total_tokens (newer LangChain)
"""
total = 0
# 1. response_metadata
response_meta = getattr(msg, "response_metadata", None)
if response_meta and isinstance(response_meta, dict):
# Try the token_usage path
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
if isinstance(token_usage, dict):
total = token_usage.get("total_tokens", 0)
# 2. usage_metadata (attribute on newer LangChain AIMessage)
if not total:
usage_meta = getattr(msg, "usage_metadata", None)
if usage_meta:
if isinstance(usage_meta, dict):
total = usage_meta.get("total_tokens", 0)
else:
total = getattr(usage_meta, "total_tokens", 0)
return total or 0
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Build multimodal message content
@@ -288,17 +325,23 @@ class LangChainAgent:
return content_parts
@staticmethod
def _extract_reasoning_content(msg) -> str:
"""Extract deep-thinking content (reasoning_content) from an AIMessage.
All providers pass it uniformly via additional_kwargs.reasoning_content:
- DeepSeek-R1 / QwQ: native field
- Volcano (Doubao-thinking): injected by VolcanoChatOpenAI from delta.reasoning_content
"""
additional = getattr(msg, "additional_kwargs", None) or {}
return additional.get("reasoning_content") or additional.get("reasoning", "")
async def chat(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None, # newly added parameter
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # new: multimodal files
files: Optional[List[Dict[str, Any]]] = None
) -> Dict[str, Any]:
"""Run a conversation
@@ -306,31 +349,12 @@ class LangChainAgent:
message: the user message
history: list of history messages [{"role": "user/assistant", "content": "..."}]
context: context info (e.g., knowledge-base retrieval results)
files: multimodal files
Returns:
Dict: a dict containing the content and metadata
"""
message_chat = message
start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
logger.info(f'Write type: {storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'Write type: {storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try:
# Prepare the message list (multimodal supported)
messages = self._prepare_messages(message, history, context, files)
@@ -354,7 +378,7 @@ class LangChainAgent:
{"messages": messages},
config={"recursion_limit": self.max_iterations}
)
except RecursionError as e:
except (RecursionError, GraphRecursionError) as e:
logger.warning(
f"Agent hit the max-iteration limit ({self.max_iterations}); a tool-call loop may exist",
extra={"error": str(e)}
@@ -377,6 +401,7 @@ class LangChainAgent:
logger.debug(f"Output message count: {len(output_messages)}")
total_tokens = 0
reasoning_content = ""
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
logger.debug(f"Found AI message, content type: {type(msg.content)}")
@@ -411,16 +436,13 @@ class LangChainAgent:
else:
content = str(msg.content)
logger.debug(f"Converted to string: {content[:100]}...")
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
total_tokens = self._extract_tokens_from_message(msg)
reasoning_content = self._extract_reasoning_content(msg) if self.deep_thinking else ""
break
logger.info(f"Final extracted content length: {len(content)}")
elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
response = {
"content": content,
"model": self.model_name,
@@ -431,6 +453,8 @@ class LangChainAgent:
"total_tokens": total_tokens
}
}
if reasoning_content:
response["reasoning_content"] = reasoning_content
logger.debug(
"Agent call completed",
@@ -451,22 +475,20 @@ class LangChainAgent:
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # new: multimodal files
) -> AsyncGenerator[str, None]:
files: Optional[List[Dict[str, Any]]] = None
) -> AsyncGenerator[str | int | dict[str, str], None]:
"""Run a streaming conversation
Args:
message: the user message
history: list of history messages
context: context info
files: multimodal files
Yields:
str: a chunk of message content
int: the token usage total
Dict: deep-thinking content {"type": "reasoning", "content": "..."}
"""
logger.info("=" * 80)
logger.info(" chat_stream started")
@@ -474,23 +496,6 @@ class LangChainAgent:
logger.info(f" Has tools: {bool(self.tools)}")
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80)
message_chat = message
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# Note: do not write the user message here; write it together after the AI reply
try:
# Prepare the message list (multimodal supported)
messages = self._prepare_messages(message, history, context, files)
@@ -500,17 +505,19 @@ class LangChainAgent:
)
chunk_count = 0
yielded_content = False
# Stream output uniformly via the agent's astream_events
logger.debug("Streaming output via Agent astream_events")
full_content = ''
full_reasoning = ''
try:
last_event = {}
async for event in self.agent.astream_events(
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
):
last_event = event
chunk_count += 1
kind = event.get("event")
@@ -519,12 +526,18 @@ class LangChainAgent:
# LLM streaming output
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"):
# Extract deep-thinking content (only when deep thinking is enabled)
if self.deep_thinking:
reasoning_chunk = self._extract_reasoning_content(chunk)
if reasoning_chunk:
full_reasoning += reasoning_chunk
yield {"type": "reasoning", "content": reasoning_chunk}
# Handle multimodal responses (content may be a string or a list)
chunk_content = chunk.content
if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content
yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list):
# Multimodal response: extract the text parts
for item in chunk_content:
@@ -535,29 +548,32 @@ class LangChainAgent:
if text:
full_content += text
yield text
yielded_content = True
# OpenAI format: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
elif isinstance(item, str):
full_content += item
yield item
yielded_content = True
elif kind == "on_llm_stream":
# Another kind of LLM streaming event
chunk = event.get("data", {}).get("chunk")
if chunk:
if hasattr(chunk, "content"):
# Extract deep-thinking content (only when deep thinking is enabled)
if self.deep_thinking:
reasoning_chunk = self._extract_reasoning_content(chunk)
if reasoning_chunk:
full_reasoning += reasoning_chunk
yield {"type": "reasoning", "content": reasoning_chunk}
chunk_content = chunk.content
if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content
yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list):
# Multimodal response: extract the text parts
for item in chunk_content:
@@ -568,22 +584,18 @@ class LangChainAgent:
if text:
full_content += text
yield text
yielded_content = True
# OpenAI format: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
elif isinstance(item, str):
full_content += item
yield item
yielded_content = True
elif isinstance(chunk, str):
full_content += chunk
yield chunk
yielded_content = True
# Log tool calls (optional)
elif kind == "on_tool_start":
@@ -593,19 +605,20 @@ class LangChainAgent:
logger.debug(f"Agent streaming finished, {chunk_count} events in total")
# Tally token usage
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get(
"total_tokens",
0
) if response_meta else 0
yield total_tokens
stream_total_tokens = self._extract_tokens_from_message(msg)
logger.info(f"Streaming token stats: total_tokens={stream_total_tokens}")
yield stream_total_tokens
break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
except GraphRecursionError:
logger.warning(
f"Agent hit the max-iteration limit ({self.max_iterations}); the model may not support proper tool-call stop decisions"
)
if not full_content:
yield "Sorry, I ran into a problem while handling your request (the maximum number of processing steps was reached). Please simplify the question or switch models and try again."
except Exception as e:
logger.error(f"Agent astream_events failed: {str(e)}", exc_info=True)
raise
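After these changes chat_stream yields three shapes: str content chunks, {"type": "reasoning", ...} dicts for deep-thinking deltas, and a final int token total. A consumer sketch under those assumptions:

async def consume(agent, message: str) -> None:
    content_parts, reasoning_parts, total_tokens = [], [], 0
    async for item in agent.chat_stream(message):
        if isinstance(item, dict):       # deep-thinking delta
            reasoning_parts.append(item.get("content", ""))
        elif isinstance(item, int):      # final token usage count
            total_tokens = item
        elif isinstance(item, str):      # visible answer chunk
            content_parts.append(item)
    print("".join(content_parts), total_tokens)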


@@ -19,6 +19,7 @@ class BizCode(IntEnum):
TENANT_NOT_FOUND = 3002
WORKSPACE_NO_ACCESS = 3003
WORKSPACE_INVITE_NOT_FOUND = 3004
WORKSPACE_ACCESS_DENIED = 3005
# API key management (3xxx)
API_KEY_NOT_FOUND = 3007
API_KEY_DUPLICATE_NAME = 3008
@@ -113,6 +114,8 @@ HTTP_MAPPING = {
BizCode.FORBIDDEN: 403,
BizCode.TENANT_NOT_FOUND: 400,
BizCode.WORKSPACE_NO_ACCESS: 403,
BizCode.WORKSPACE_INVITE_NOT_FOUND: 400,
BizCode.WORKSPACE_ACCESS_DENIED: 403,
BizCode.NOT_FOUND: 400,
BizCode.USER_NOT_FOUND: 200,
BizCode.WORKSPACE_NOT_FOUND: 400,
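With the new entry, raising BusinessException with WORKSPACE_ACCESS_DENIED should surface as an HTTP 403 through this table. A sketch of the lookup, assuming unmapped codes default to 400 (the diff only shows explicit entries):

def to_http_status(code: BizCode) -> int:
    # Assumed default for codes without an explicit mapping.
    return HTTP_MAPPING.get(code, 400)

assert to_http_status(BizCode.WORKSPACE_ACCESS_DENIED) == 403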


@@ -0,0 +1,408 @@
"""
Perceptual Memory Retrieval Node & Service
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
with BM25+embedding fusion reranking.
Also provides the perceptual_retrieve_node for use as a LangGraph node.
"""
import asyncio
import math
from typing import List, Dict, Any, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.repositories.neo4j.graph_search import (
search_perceptual,
search_perceptual_by_embedding,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = get_agent_logger(__name__)
class PerceptualSearchService:
"""
Perceptual memory retrieval service.
Wraps the complete pipeline of keyword fulltext retrieval + vector semantic retrieval + BM25/embedding fusion ranking.
Callers only need to provide query/keywords, end_user_id, and memory_config to get
a formatted, ranked list of perceptual memories plus the concatenated text.
Usage:
service = PerceptualSearchService(end_user_id=..., memory_config=...)
results = await service.search(query="...", keywords=[...], limit=10)
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
"""
DEFAULT_ALPHA = 0.6
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
def __init__(
self,
end_user_id: str,
memory_config: Any,
alpha: float = DEFAULT_ALPHA,
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
):
self.end_user_id = end_user_id
self.memory_config = memory_config
self.alpha = alpha
self.content_score_threshold = content_score_threshold
async def search(
self,
query: str,
keywords: Optional[List[str]] = None,
limit: int = 10,
) -> Dict[str, Any]:
"""
Run perceptual memory retrieval (keyword + vector in parallel) and return fusion-ranked results.
For results hit by embedding but not by keyword, backfill the fulltext index to get BM25 scores,
so every result carries scores on both the BM25 and embedding dimensions.
Args:
query: the original user query (used for vector retrieval and BM25 backfill)
keywords: keyword list (used for fulltext retrieval); [query] is used when None
limit: maximum number of results to return
Returns:
{
"memories": [formatted memory dicts, ...],
"content": "concatenated plain-text summary",
"keyword_raw": int,
"embedding_raw": int,
}
"""
if keywords is None:
keywords = [query] if query else []
connector = Neo4jConnector()
try:
kw_task = self._keyword_search(connector, keywords, limit)
emb_task = self._embedding_search(connector, query, limit)
kw_results, emb_results = await asyncio.gather(
kw_task, emb_task, return_exceptions=True
)
if isinstance(kw_results, Exception):
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
kw_results = []
if isinstance(emb_results, Exception):
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
emb_results = []
# BM25 backfill: find ids hit by embedding but missed by keyword search,
# then re-query the fulltext index with the original query to get their BM25 scores
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
if emb_only_ids and query:
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
# Inject the backfilled BM25 scores into the embedding results
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
for r in emb_results:
rid = r.get("id", "")
if rid in backfill_map:
r["bm25_backfill_score"] = backfill_map[rid]
logger.info(
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
f"{len(backfill_map)} got BM25 scores"
)
reranked = self._rerank(kw_results, emb_results, limit)
memories = []
content_parts = []
for record in reranked:
fmt = self._format_result(record)
fmt["score"] = round(record.get("content_score", 0), 4)
memories.append(fmt)
content_parts.append(self._build_content_text(fmt))
logger.info(
f"[PerceptualSearch] {len(memories)} results after rerank "
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
)
return {
"memories": memories,
"content": "\n\n".join(content_parts),
"keyword_raw": len(kw_results),
"embedding_raw": len(emb_results),
}
finally:
await connector.close()
async def _bm25_backfill(
self,
connector: Neo4jConnector,
query: str,
target_ids: set,
limit: int,
) -> List[dict]:
"""
Backfill fulltext-index BM25 scores for the given id set.
Query the fulltext index with the original query and keep only results whose id is in target_ids.
"""
escaped = escape_lucene_query(query)
if not escaped.strip():
return []
try:
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id,
limit=limit * 5, # over-fetch to improve the hit rate
)
all_hits = r.get("perceptuals", [])
return [h for h in all_hits if h.get("id") in target_ids]
except Exception as e:
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
return []
async def _keyword_search(
self,
connector: Neo4jConnector,
keywords: List[str],
limit: int,
) -> List[dict]:
"""Run a fulltext search for each keyword concurrently; dedupe and return the top N raw results by descending score."""
seen_ids: set = set()
all_results: List[dict] = []
async def _one(kw: str):
escaped = escape_lucene_query(kw)
if not escaped.strip():
return []
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id, limit=limit,
)
return r.get("perceptuals", [])
tasks = [_one(kw) for kw in keywords[:10]]
batch = await asyncio.gather(*tasks, return_exceptions=True)
for result in batch:
if isinstance(result, Exception):
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
continue
for rec in result:
rid = rec.get("id", "")
if rid and rid not in seen_ids:
seen_ids.add(rid)
all_results.append(rec)
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
return all_results[:limit]
async def _embedding_search(
self,
connector: Neo4jConnector,
query_text: str,
limit: int,
) -> List[dict]:
"""Vector semantic search; returns raw results (no threshold filtering)."""
try:
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
with get_db_context() as db:
cfg = MemoryConfigService(db).get_embedder_config(
str(self.memory_config.embedding_model_id)
)
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
r = await search_perceptual_by_embedding(
connector=connector, embedder_client=client,
query_text=query_text, end_user_id=self.end_user_id,
limit=limit,
)
return r.get("perceptuals", [])
except Exception as e:
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
return []
def _rerank(
self,
keyword_results: List[dict],
embedding_results: List[dict],
limit: int,
) -> List[dict]:
"""BM25 + embedding fusion ranking.
Entries in the embedding results that carry a bm25_backfill_score are merged
with the keyword results before normalization, so all BM25 scores share one scale.
"""
# Merge the backfilled BM25 scores into keyword_results for unified normalization
emb_backfill_items = []
for item in embedding_results:
backfill_score = item.get("bm25_backfill_score")
if backfill_score is not None and item.get("id"):
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
# Normalize all BM25 scores together after merging
all_bm25_items = keyword_results + emb_backfill_items
all_bm25_items = self._normalize_scores(all_bm25_items)
# Build an id -> normalized BM25 score map
bm25_norm_map: Dict[str, float] = {}
for item in all_bm25_items:
item_id = item.get("id", "")
if item_id:
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
# Normalize the embedding scores
embedding_results = self._normalize_scores(embedding_results)
# Merge
combined: Dict[str, dict] = {}
for item in keyword_results:
item_id = item.get("id", "")
if not item_id:
continue
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = 0.0
for item in embedding_results:
item_id = item.get("id", "")
if not item_id:
continue
if item_id in combined:
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
for item in combined.values():
bm25 = float(item.get("bm25_score", 0) or 0)
emb = float(item.get("embedding_score", 0) or 0)
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
results = list(combined.values())
before = len(results)
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
results.sort(key=lambda x: x["content_score"], reverse=True)
results = results[:limit]
logger.info(
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
)
return results
@staticmethod
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
"""Z-score + sigmoid normalization."""
if not items:
return items
scores = [float(it.get(field, 0) or 0) for it in items]
if len(scores) <= 1:
for it in items:
it[f"normalized_{field}"] = 1.0
return items
mean = sum(scores) / len(scores)
var = sum((s - mean) ** 2 for s in scores) / len(scores)
std = math.sqrt(var)
if std == 0:
for it in items:
it[f"normalized_{field}"] = 1.0
else:
for it, s in zip(items, scores):
z = (s - mean) / std
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
return items
@staticmethod
def _format_result(record: dict) -> dict:
return {
"id": record.get("id", ""),
"perceptual_type": record.get("perceptual_type", ""),
"file_name": record.get("file_name", ""),
"file_path": record.get("file_path", ""),
"summary": record.get("summary", ""),
"topic": record.get("topic", ""),
"domain": record.get("domain", ""),
"keywords": record.get("keywords", []),
"created_at": str(record.get("created_at", "")),
"file_type": record.get("file_type", ""),
"score": record.get("score", 0),
}
@staticmethod
def _build_content_text(formatted: dict) -> str:
parts = []
if formatted["summary"]:
parts.append(formatted["summary"])
if formatted["topic"]:
parts.append(f"[Topic: {formatted['topic']}]")
if formatted["keywords"]:
kw_list = formatted["keywords"]
if isinstance(kw_list, list):
parts.append(f"[Keywords: {', '.join(kw_list)}]")
if formatted["file_name"]:
parts.append(f"[File: {formatted['file_name']}]")
return " ".join(parts)
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
"""Extract search keywords from problem extension results."""
keywords = []
context = problem_extension.get("context", {})
if isinstance(context, dict):
for original_q, extended_qs in context.items():
keywords.append(original_q)
if isinstance(extended_qs, list):
keywords.extend(extended_qs)
return keywords
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
"""
LangGraph node: perceptual memory retrieval.
Uses PerceptualSearchService to run keyword + embedding search with
BM25 fusion reranking, then writes results to state['perceptual_data'].
"""
end_user_id = state.get("end_user_id", "")
problem_extension = state.get("problem_extension", {})
original_query = state.get("data", "")
memory_config = state.get("memory_config", None)
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
keywords = _extract_keywords_from_problems(problem_extension)
if not keywords:
keywords = [original_query] if original_query else []
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
search_result = await service.search(
query=original_query,
keywords=keywords,
limit=10,
)
result = {
"memories": search_result["memories"],
"content": search_result["content"],
"_intermediate": {
"type": "perceptual_retrieve",
"title": "感知记忆检索",
"data": search_result["memories"],
"query": original_query,
"result_count": len(search_result["memories"]),
},
}
return {"perceptual_data": result}
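The scoring math above reduces to two steps: each score list is z-scored and squashed through a sigmoid, then fused as content_score = alpha * bm25 + (1 - alpha) * emb, with alpha defaulting to 0.6 and a 0.5 threshold. A standalone sketch mirroring _normalize_scores and the fusion step:

import math

def zscore_sigmoid(scores: list[float]) -> list[float]:
    # Degenerate cases collapse to 1.0, as in _normalize_scores above.
    if len(scores) <= 1:
        return [1.0] * len(scores)
    mean = sum(scores) / len(scores)
    std = math.sqrt(sum((s - mean) ** 2 for s in scores) / len(scores))
    if std == 0:
        return [1.0] * len(scores)
    return [1 / (1 + math.exp(-(s - mean) / std)) for s in scores]

def fuse(bm25_norm: float, emb_norm: float, alpha: float = 0.6) -> float:
    # Results under the 0.5 content_score threshold are dropped by _rerank.
    return alpha * bm25_norm + (1 - alpha) * emb_norm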


@@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
print(time.time() - start)
result = {
"context": aggregated_dict,
"original": data,


@@ -1,7 +1,11 @@
import asyncio
import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
PerceptualSearchService,
)
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
@@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState:
try:
if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
async def _perceptual_search():
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
return await service.search(query=data, limit=5)
hybrid_task = SearchService().execute_hybrid_search(
**search_params,
memory_config=memory_config,
expand_communities=False, # path "2" only needs the community summary text; do not expand to Statement
expand_communities=False,
)
perceptual_task = _perceptual_search()
gather_results = await asyncio.gather(
hybrid_task, perceptual_task, return_exceptions=True
)
hybrid_result = gather_results[0]
perceptual_results = gather_results[1]
# Handle hybrid search exceptions
if isinstance(hybrid_result, Exception):
raise hybrid_result
retrieve_info, question, raw_results = hybrid_result
# Handle the perceptual memory results
if isinstance(perceptual_results, Exception):
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
perceptual_results = []
# Append the perceptual memory content to retrieve_info
if perceptual_results and isinstance(perceptual_results, dict):
perceptual_content = perceptual_results.get("content", "")
if perceptual_content:
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
count = len(perceptual_results.get("memories", []))
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
# Debug: print the number of community retrieval results
if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {})
@@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
"error": str(e)
}
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
duration = end - start
log_time('retrieval', duration)
return {"summary": summary}
@@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
aimessages = await summary_llm(
state,
history,
retrieve_info_str,
'direct_summary_prompt.jinja2',
'retrieve_summary', RetrieveSummaryResponse,
"1"
)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
@@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
retrieve_info_str += i + '\n'
history = await summary_history(state)
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = {
"query": query,
"history": history,
@@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = {
"query": query,
"history": history,


@@ -15,7 +15,10 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
Problem_Extension,
)
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve,
retrieve_nodes,
)
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
perceptual_retrieve_node,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary,
@@ -48,13 +51,14 @@ async def make_read_graph():
"""
try:
# Build workflow graph
workflow = StateGraph(ReadState)
workflow = StateGraph(ReadState)
workflow.add_node("content_input", content_input_node)
workflow.add_node("Split_The_Problem", Split_The_Problem)
workflow.add_node("Problem_Extension", Problem_Extension)
workflow.add_node("Input_Summary", Input_Summary)
# workflow.add_node("Retrieve", retrieve_nodes)
workflow.add_node("Retrieve", retrieve)
workflow.add_node("Retrieve", retrieve_nodes)
# workflow.add_node("Retrieve", retrieve)
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
workflow.add_node("Verify", Verify)
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary)
@@ -65,14 +69,15 @@ async def make_read_graph():
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve")
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
# workflow.add_edge("Retrieve", END)
# Compile workflow
@@ -80,7 +85,5 @@ async def make_read_graph():
yield graph
except Exception as e:
print(f"创建工作流失败: {e}")
logger.error(f"创建工作流失败: {e}")
raise
finally:
print("工作流创建完成")

View File

@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
"""
Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
ai_message: AI's response message content
user_rag_memory_id: RAG memory identifier for storage location
"""
# RAG mode: combine messages into string format (maintain original logic)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(
storage_type,
end_user_id,
@@ -118,7 +98,7 @@ async def write(
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
async def term_memory_save(end_user_id, strategy_type, scope):
"""
Save long-term memory data to database
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
to long-term memory storage.
Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
end_user_id: User identifier for memory association
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
scope: Scope/window size for memory processing
"""
with get_db_context() as db_session:
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
if not result:
logger.warning(f"No write data found for user {end_user_id}")
return
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data) == scope:
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
logger.info(f'写入短长期:')
"""Window-based dialogue processing"""
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
"""
Process dialogue based on window size and write to Neo4j
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage
"""
scope = scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
if is_end_user_has_history:
end_user_visit_count, redis_messages = is_end_user_has_history
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
return
end_user_visit_count += 1
if end_user_visit_count < scope:
redis_messages.extend(langchain_messages)
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
else:
logger.info('写入长期记忆NEO4J')
formatted_messages = redis_messages
redis_messages.extend(langchain_messages)
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
await write(
AgentMemory_Long_Term.STORAGE_NEO4J,
end_user_id,
"",
"",
None,
end_user_id,
config_id,
formatted_messages
write_message_task.delay(
end_user_id, # end_user_id: User ID
redis_messages, # message: JSON string format message list
config_id, # config_id: Configuration ID string
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""Time-based memory processing"""
count_store.update_sessions_count(end_user_id, 0, [])
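Under the rewritten rule, the first visit seeds the counter with count=1, later visits buffer messages, and the visit that reaches scope flushes the whole window to the Celery task and resets the count to 0. A pure-Python model of the non-first-visit step (a sketch, not the production code):
def window_step(count, buffer, incoming, scope=6):
    # Returns (new_count, new_buffer, flushed) following the rule above.
    count += 1
    if count < scope:
        return count, buffer + incoming, None   # keep accumulating
    return 0, [], buffer + incoming             # flush the full window, then reset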
async def memory_long_term_storage(end_user_id, memory_config, time):
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
return result_dict
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
return {
"is_same_event": False,

View File

@@ -1,49 +1,25 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
from app.db import get_db_context
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import write_rag
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@asynccontextmanager
async def make_write_graph():
"""
Create a write graph workflow for memory operations.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()
yield graph
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
async def long_term_storage(
long_term_type: str,
langchain_messages: list,
memory_config_id: str,
end_user_id: str,
scope: int = 6
):
"""
Handle long-term memory storage with different strategies
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store
memory_config: Memory configuration identifier
memory_config_id: Memory configuration identifier
end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6)
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
if langchain_messages is None:
langchain_messages = []
write_store.save_session_write(end_user_id, langchain_messages)
# Acquire a database session
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config,  # changed to integer
config_id=memory_config_id,  # changed to integer
service_name="MemoryAgentService"
)
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
# Dialogue window with 6 rounds of conversation
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""Time-based strategy"""
# Time-based strategy
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""Strategy 3: Aggregate judgment"""
# Aggregate judgment
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
async def write_long_term(
storage_type: str,
end_user_id: str,
messages: list[dict],
user_rag_memory_id: str,
actual_config_id: str
):
"""
Write long-term memory with different storage types
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
Args:
storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier
message_chat: User message content
aimessages: AI response messages
messages: message list
user_rag_memory_id: RAG memory identifier
actual_config_id: Actual configuration ID
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
message_content = []
for message in messages:
message_content.append(f'{message.get("role")}:{message.get("content")}')
messages_string = "\n".join(message_content)
await write_rag(end_user_id, messages_string, user_rag_memory_id)
else:
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "好耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # group ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())
await long_term_storage(long_term_type=CHUNK,
langchain_messages=messages,
memory_config_id=actual_config_id,
end_user_id=end_user_id,
scope=SCOPE)
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
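The commented-out __main__ example was removed above; for reference, an illustrative invocation of the new signature (IDs are placeholders):
import asyncio
async def demo():
    messages = [
        {"role": "user", "content": "今天周五去爬山"},
        {"role": "assistant", "content": "好耶"},
    ]
    await write_long_term(
        storage_type=AgentMemory_Long_Term.STORAGE_NEO4J,  # any non-RAG type takes the chunk path
        end_user_id="demo-end-user",        # placeholder
        messages=messages,
        user_rag_memory_id="",
        actual_config_id="demo-config-id",  # placeholder
    )
if __name__ == "__main__":
    asyncio.run(demo())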

View File

@@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__)
# Fields to filter out of expanded results (includes Neo4j DateTime, which is not JSON-serializable)
@@ -31,10 +30,10 @@ def _clean_expand_fields(obj):
async def expand_communities_to_statements(
community_results: List[dict],
end_user_id: str,
existing_content: str = "",
limit: int = 10,
community_results: List[dict],
end_user_id: str,
existing_content: str = "",
limit: int = 10,
) -> Tuple[List[dict], List[str]]:
"""
Community-expansion helper: given the list of matched communities, fetch their associated Statements.
@@ -76,17 +75,18 @@ async def expand_communities_to_statements(
if s.get("statement") and s["statement"] not in existing_lines
]
cleaned = _clean_expand_fields(expanded_stmts)
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
logger.info(
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
return cleaned, new_texts
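An illustrative call, inside an async context; the hit dict mirrors the community fields the code checks for, and the IDs are placeholders:
hits = [{"id": "c1", "name": "旅行计划", "member_count": 12}]   # example community hit
stmts, new_texts = await expand_communities_to_statements(
    community_results=hits,
    end_user_id="demo-user",   # placeholder
    existing_content="",       # statements already shown, used to avoid duplicates
    limit=10,
)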
class SearchService:
"""Service for executing hybrid search and processing results."""
def __init__(self):
"""Initialize the search service."""
logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
"""
Extract only meaningful content from search results, dropping all metadata.
@@ -107,19 +107,19 @@ class SearchService:
"""
if not isinstance(result, dict):
return str(result)
content_parts = []
# Statements: extract statement field
if 'statement' in result and result['statement']:
content_parts.append(result['statement'])
# Community nodes: they have member_count or core_entities fields, or node_type says so explicitly
# Prefix with "[主题:{name}]" so the LLM knows this is a topic-level summary
is_community = (
node_type == "community"
or 'member_count' in result
or 'core_entities' in result
node_type == "community"
or 'member_count' in result
or 'core_entities' in result
)
if is_community:
name = result.get('name', '')
@@ -130,16 +130,16 @@ class SearchService:
elif 'content' in result and result['content']:
# Summaries / Chunks
content_parts.append(result['content'])
# Entities: extract name and fact_summary (commented out in original)
# if 'name' in result and result['name']:
# content_parts.append(result['name'])
# if result.get('fact_summary'):
# content_parts.append(result['fact_summary'])
# Return concatenated content or empty string
return '\n'.join(content_parts) if content_parts else ""
def clean_query(self, query: str) -> str:
"""
Clean and escape query text for Lucene.
@@ -155,33 +155,33 @@ class SearchService:
Cleaned and escaped query string
"""
q = str(query).strip()
# Remove wrapping quotes
if (q.startswith("'") and q.endswith("'")) or (
q.startswith('"') and q.endswith('"')
q.startswith('"') and q.endswith('"')
):
q = q[1:-1]
# Remove newlines and carriage returns
q = q.replace('\r', ' ').replace('\n', ' ').strip()
# Apply Lucene escaping
q = escape_lucene_query(q)
return q
async def execute_hybrid_search(
self,
end_user_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config = None,
expand_communities: bool = True,
self,
end_user_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config=None,
expand_communities: bool = True,
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search and return clean content.
@@ -205,10 +205,10 @@ class SearchService:
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries", "communities"]
# Clean query
cleaned_query = self.clean_query(question)
try:
# Execute search
answer = await run_hybrid_search(
@@ -221,18 +221,18 @@ class SearchService:
memory_config=memory_config,
rerank_alpha=rerank_alpha
)
# Extract results based on search type and include parameter
# Prioritize summaries as they contain synthesized contextual information
answer_list = []
# For hybrid search, use reranked_results
if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in reranked_results:
category_results = reranked_results[category]
@@ -242,7 +242,7 @@ class SearchService:
# For keyword or embedding search, results are directly in answer dict
# Apply same priority order
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in answer:
category_results = answer[category]
@@ -261,7 +261,7 @@ class SearchService:
end_user_id=end_user_id,
)
answer_list.extend(cleaned_stmts)
# Extract clean content from all results; pass node_type per item to distinguish communities
content_list = []
for ans in answer_list:
@@ -269,19 +269,18 @@ class SearchService:
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines
clean_content = '\n'.join([c for c in content_list if c])
# Log first 200 chars
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
# Return raw results if requested
if return_raw_results:
return clean_content, cleaned_query, answer
else:
return clean_content, cleaned_query, None
except Exception as e:
logger.error(
f"Search failed for query '{question}' in group '{end_user_id}': {e}",

View File

@@ -1,4 +1,3 @@
import os
from collections import defaultdict
from pathlib import Path
from typing import Annotated, TypedDict
@@ -52,6 +51,7 @@ class ReadState(TypedDict):
embedding_id: str
memory_config: object  # new field for passing the memory config object
retrieve: dict
perceptual_data: dict
RetrieveSummary: dict
InputSummary: dict
verify: dict

View File

@@ -3,8 +3,9 @@ import uuid
from app.core.config import settings
from typing import List, Dict, Any, Optional, Union
from app.core.logging_config import get_logger
from app.core.memory.agent.utils.redis_base import (
serialize_messages,
serialize_messages,
deserialize_messages,
fix_encoding,
format_session_data,
@@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import (
get_current_timestamp
)
logger = get_logger(__name__)
class RedisWriteStore:
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
Initialize the Redis connection
@@ -66,10 +67,10 @@ class RedisWriteStore:
})
result = pipe.execute()
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"[save_session_write] 保存会话失败: {e}")
logger.error(f"[save_session_write] 保存会话失败: {e}")
raise e
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
@@ -99,7 +100,7 @@ class RedisWriteStore:
for key, data in zip(keys, all_data):
if not data:
continue
# Read from write-type entries and match on the sessionid field
if data.get('sessionid') == userid:
# Extract session_id from the key: session:write:{session_id}
@@ -108,16 +109,16 @@ class RedisWriteStore:
"sessionid": session_id,
"messages": fix_encoding(data.get('messages', ''))
})
if not results:
return False
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_session_by_userid] 查询失败: {e}")
logger.error(f"[get_session_by_userid] 查询失败: {e}")
return False
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
"""
Fetch all write-type session data by end_user_id
@@ -144,7 +145,7 @@ class RedisWriteStore:
# Query only write-type keys
keys = self.r.keys('session:write:*')
if not keys:
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
return False
# Fetch data in bulk
@@ -158,12 +159,12 @@ class RedisWriteStore:
for key, data in zip(keys, all_data):
if not data:
continue
# Read from write-type entries and match on the sessionid field
if data.get('sessionid') == end_user_id:
# Extract session_id from the key: session:write:{session_id}
session_id = key.split(':')[-1]
# Build the full session info
session_info = {
"session_id": session_id,
@@ -173,23 +174,21 @@ class RedisWriteStore:
"starttime": data.get('starttime', '')
}
results.append(session_info)
if not results:
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
return False
# Sort by time (newest first)
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
import traceback
traceback.print_exc()
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
return False
def find_user_recent_sessions(self, userid: str,
def find_user_recent_sessions(self, userid: str,
minutes: int = 5) -> List[Dict[str, str]]:
"""
Query session data written by save_session_write for the given userid within the last N minutes
@@ -203,11 +202,11 @@ class RedisWriteStore:
"""
import time
start_time = time.time()
# Query only write-type keys
keys = self.r.keys('session:write:*')
if not keys:
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# Fetch data in bulk
@@ -221,7 +220,7 @@ class RedisWriteStore:
for data in all_data:
if not data:
continue
# Read from write-type entries and match on the sessionid field
if data.get('sessionid') == userid and data.get('starttime'):
# write-type entries have no aimessages, so Answer is empty
@@ -230,15 +229,14 @@ class RedisWriteStore:
"Answer": "",
"starttime": data.get('starttime', '')
})
# Filter by time range
filtered_items = filter_by_time_range(matched_items, minutes)
# Sort and strip the time field
result_items = sort_and_limit_results(filtered_items, limit=None)
print(result_items)
result_items = sort_and_limit_results(filtered_items)
elapsed_time = time.time() - start_time
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
@@ -258,7 +256,7 @@ class RedisWriteStore:
class RedisCountStore:
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
Initialize the Redis connection
@@ -278,7 +276,7 @@ class RedisCountStore:
decode_responses=True,
encoding='utf-8'
)
self.uudi = session_id
self.uuid = session_id
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
"""
@@ -295,26 +293,26 @@ class RedisCountStore:
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count")
index_key = f'session:count:index:{end_user_id}'  # index key
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"id": self.uuid,
"end_user_id": end_user_id,
"count": int(count),
"messages": serialize_messages(messages),
"starttime": get_current_timestamp()
})
pipe.expire(key, 30 * 24 * 60 * 60)  # expire after 30 days
# Create an index: end_user_id -> session_id mapping
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
return session_id
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
"""
Look up visit-count statistics by end_user_id
@@ -327,7 +325,7 @@ class RedisCountStore:
try:
# Use the index key for a fast lookup
index_key = f'session:count:index:{end_user_id}'
# Check the index key's type to avoid WRONGTYPE errors
try:
key_type = self.r.type(index_key)
@@ -335,35 +333,40 @@ class RedisCountStore:
self.r.delete(index_key)
return False
except Exception as type_error:
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
return False
# Fetch the data directly
key = generate_session_key(session_id, key_type="count")
data = self.r.hgetall(key)
if not data:
# The index exists but the data is gone; clean up the index
self.r.delete(index_key)
return False
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
messages: list[dict] = deserialize_messages(messages_str)
return int(count), messages
return False
except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}")
logger.error(f"[get_sessions_count] 查询失败: {e}")
return False
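Because the method now returns either a (count, messages) tuple or False, callers must check before unpacking. A sketch, assuming a RedisCountStore instance named count_store:
result = count_store.get_sessions_count(end_user_id)
if result is False:
    count_store.save_sessions_count(end_user_id, 1, messages)  # first visit: seed the counter
else:
    count, buffered = result  # safe to unpack only after the False check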
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool:
def update_sessions_count(
self,
end_user_id: str,
new_count: int,
messages: Any
) -> bool:
"""
Update visit-count statistics by end_user_id (optimized: uses the index)
@@ -378,39 +381,39 @@ class RedisCountStore:
try:
# Use the index key for a fast lookup
index_key = f'session:count:index:{end_user_id}'
# Check the index key's type to avoid WRONGTYPE errors
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
# Wrong index key type: delete it and return False
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key)
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as type_error:
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
# Update the data directly
key = generate_session_key(session_id, key_type="count")
messages_str = serialize_messages(messages)
pipe = self.r.pipeline()
pipe.hset(key, 'count', int(new_count))
pipe.hset(key, 'count', str(new_count))
pipe.hset(key, 'messages', messages_str)
result = pipe.execute()
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}")
logger.debug(f"[update_sessions_count] 更新失败: {e}")
return False
def delete_all_count_sessions(self) -> int:
@@ -428,7 +431,7 @@ class RedisCountStore:
class RedisSessionStore:
"""Redis 会话存储类,用于管理会话数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
Initialize the Redis connection
@@ -451,9 +454,9 @@ class RedisSessionStore:
self.uudi = session_id
# ==================== Write operations ====================
def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str:
def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str:
"""
Write one session record and return its session_id
@@ -483,14 +486,14 @@ class RedisSessionStore:
})
result = pipe.execute()
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"[save_session] 保存会话失败: {e}")
logger.error(f"[save_session] 保存会话失败: {e}")
raise e
# ==================== Read operations ====================
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""
Read one session record
@@ -520,8 +523,8 @@ class RedisSessionStore:
sessions[sid] = self.get_session(sid)
return sessions
def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]:
def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]:
"""
Query session data by sessionid, apply_id, and end_user_id, returning the 6 most recent entries
@@ -535,10 +538,10 @@ class RedisSessionStore:
"""
import time
start_time = time.time()
keys = self.r.keys('session:*')
if not keys:
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
@@ -556,21 +559,21 @@ class RedisSessionStore:
continue
if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
data.get('end_user_id') == end_user_id):
# Support fuzzy or exact matching on sessionid
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append(format_session_data(data, include_time=True))
# Sort, cap the count, and strip the time field
result_items = sort_and_limit_results(matched_items, limit=6)
elapsed_time = time.time() - start_time
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
# ==================== Update operations ====================
def update_session(self, session_id: str, field: str, value: Any) -> bool:
"""
Update a single field
@@ -591,7 +594,7 @@ class RedisSessionStore:
return bool(results[0])
# ==================== Delete operations ====================
def delete_session(self, session_id: str) -> int:
"""
Delete a single session
@@ -632,7 +635,7 @@ class RedisSessionStore:
keys = self.r.keys('session:*')
if not keys:
print("[delete_duplicate_sessions] 没有会话数据")
logger.debug("[delete_duplicate_sessions] 没有会话数据")
return 0
# 批量获取所有数据
@@ -678,7 +681,7 @@ class RedisSessionStore:
deleted_count += len(batch)
elapsed_time = time.time() - start_time
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
return deleted_count

View File

@@ -152,6 +152,24 @@ async def write(
# Step 3: Save all data to Neo4j database
step_start = time.time()
# Before writing to Neo4j: clean cross-contaminated aliases between user and AI-assistant entities
# Query existing AI-assistant aliases from Neo4j, merge them with this round's assistant aliases,
# and ensure user entities' aliases never contain the AI assistant's names
try:
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
clean_cross_role_aliases,
fetch_neo4j_assistant_aliases,
)
neo4j_assistant_aliases = set()
if all_entity_nodes:
_eu_id = all_entity_nodes[0].end_user_id
if _eu_id:
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, _eu_id)
clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases)
logger.info(f"Neo4j 写入前别名清洗完成AI助手别名排除集大小: {len(neo4j_assistant_aliases)}")
except Exception as e:
logger.warning(f"Neo4j 写入前别名清洗失败(不影响主流程): {e}")
# Deadlock retry mechanism
max_retries = 3
retry_delay = 1  # seconds
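The retry loop itself is not shown in this hunk; a generic sketch of the deadlock-retry pattern the constants suggest, with save_fn as an assumed coroutine standing in for the actual Neo4j save:
import asyncio
async def save_with_retry(save_fn, max_retries=3, retry_delay=1):
    # Retry transient deadlocks with linearly increasing backoff.
    for attempt in range(1, max_retries + 1):
        try:
            return await save_fn()
        except Exception:  # narrow this to deadlock errors in real code
            if attempt == max_retries:
                raise
            await asyncio.sleep(retry_delay * attempt)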

View File

@@ -56,7 +56,7 @@ class LLMClient(ABC):
self.max_retries = self.config.max_retries
self.timeout = self.config.timeout
logger.info(
logger.debug(
f"初始化 LLM 客户端: provider={self.provider}, "
f"model={self.model_name}, max_retries={self.max_retries}"
)

View File

@@ -43,6 +43,7 @@ load_dotenv()
logger = get_memory_logger(__name__)
def _parse_datetime(value: Any) -> Optional[datetime]:
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
if value is None:
@@ -75,7 +76,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
if score_field == "activation_value" and score is None:
scores.append(None)  # keep None; handled specially later
continue
if score is not None and isinstance(score, (int, float)):
scores.append(float(score))
else:
@@ -83,10 +84,10 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
if not scores:
return results
# Filter out None values; normalize only the valid scores
valid_scores = [s for s in scores if s is not None]
if not valid_scores:
# All scores are None; skip normalization
for item in results:
@@ -94,7 +95,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
item[f"normalized_{score_field}"] = None
return results
if len(valid_scores) == 1: # Single valid score, set to 1.0
if len(valid_scores) == 1: # Single valid score, set to 1.0
for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value":
if score is None:
@@ -132,7 +133,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
return results
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Remove duplicate items from search results based on content.
@@ -150,52 +150,53 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
seen_ids = set()
seen_content = set()
deduplicated = []
for item in items:
# Try multiple ID fields to identify unique items
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# Extract content from various possible fields
content = (
item.get("text") or
item.get("content") or
item.get("statement") or
item.get("name") or
""
item.get("text") or
item.get("content") or
item.get("statement") or
item.get("name") or
""
)
# Normalize content for comparison (strip whitespace and lowercase)
normalized_content = str(content).strip().lower() if content else ""
# Check if we've seen this ID or content before
is_duplicate = False
if item_id and item_id in seen_ids:
is_duplicate = True
elif normalized_content and normalized_content in seen_content:
# Only check content duplication if content is not empty
is_duplicate = True
if not is_duplicate:
# Mark as seen
if item_id:
seen_ids.add(item_id)
if normalized_content: # Only track non-empty content
seen_content.add(normalized_content)
deduplicated.append(item)
return deduplicated
def rerank_with_activation(
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8,
now: datetime | None = None,
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8,
now: datetime | None = None,
content_score_threshold: float = 0.5,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Two-stage ranking: filter by content relevance first, then sort by activation value.
@@ -222,6 +223,8 @@ def rerank_with_activation(
forgetting_config: forgetting-engine configuration (currently unused)
activation_boost_factor: coefficient for activation's effect on memory strength (default: 0.8)
now: current time (used for forgetting computations)
content_score_threshold: minimum content-relevance threshold (based on the normalized content_score);
results below this threshold are filtered out. Default 0.5.
Returns:
Reranked results with scoring metadata, sorted by final_score
@@ -229,26 +232,26 @@ def rerank_with_activation(
# Validate the weight range
if not (0 <= alpha <= 1):
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
# Initialize the forgetting engine (if needed)
engine = None
if forgetting_config:
engine = ForgettingEngine(forgetting_config)
now_dt = now or datetime.now()
reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, [])
# Step 1: normalize scores
keyword_items = normalize_scores(keyword_items, "score")
embedding_items = normalize_scores(embedding_items, "score")
# Step 2: merge results by ID (dedupe)
combined_items: Dict[str, Dict[str, Any]] = {}
# Add keyword results
for item in keyword_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
@@ -257,7 +260,7 @@ def rerank_with_activation(
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
combined_items[item_id]["embedding_score"] = 0 # 默认值
# 添加或更新向量嵌入结果
for item in embedding_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
@@ -271,18 +274,18 @@ def rerank_with_activation(
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0 # 默认值
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# Step 3: normalize activation scores
# Build the activation-value list for all items
items_list = list(combined_items.values())
items_list = normalize_scores(items_list, "activation_value")
# Write normalized activation scores back into combined_items
for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
# Step 4: compute base and final scores
for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0)
@@ -290,45 +293,45 @@ def rerank_with_activation(
# normalized_activation_value of None means the node has no activation value; preserve the None semantics
raw_act_norm = item.get("normalized_activation_value")
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
# Stage 1 considers only content relevance (BM25 + embedding)
# alpha weights BM25; (1 - alpha) weights the embedding score
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score  # stage 1 uses the content score
# Store the activation score for stage 2 (None means no activation value; excluded from activation sorting)
item["activation_score"] = act_norm  # may be None
item["content_score"] = content_score
item["base_score"] = base_score
# Step 5: apply the forgetting curve (optional)
if engine:
# Compute memory strength adjusted by activation
importance = float(item.get("importance_score", 0.5) or 0.5)
# Get activation_value
activation_val = item.get("activation_value")
# Apply the forgetting curve only to nodes that have an activation value
if activation_val is not None and isinstance(activation_val, (int, float)):
activation_val = float(activation_val)
# Memory strength: importance_score × (1 + activation_value × boost_factor)
memory_strength = importance * (1 + activation_val * activation_boost_factor)
# Compute elapsed time in days
dt = _parse_datetime(item.get("created_at"))
if dt is None:
time_elapsed_days = 0.0
else:
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
# Get the forgetting weight
forgetting_weight = engine.calculate_weight(
time_elapsed=time_elapsed_days,
memory_strength=memory_strength
)
# Apply it to the base score
item["forgetting_weight"] = forgetting_weight
item["final_score"] = base_score * forgetting_weight
@@ -338,7 +341,7 @@ def rerank_with_activation(
else:
# No forgetting curve
item["final_score"] = base_score
# Step 6: two-stage sorting and limiting
# Stage 1: sort by content relevance (base_score) and take the top K
first_stage_limit = limit * 3  # configurable; take 3x candidates
@@ -347,11 +350,11 @@ def rerank_with_activation(
key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序
reverse=True
)[:first_stage_limit]
# Stage 2: split nodes with and without activation values
items_with_activation = []
items_without_activation = []
for item in first_stage_sorted:
activation_score = item.get("activation_score")
# Check for a valid activation value (i.e. not None)
@@ -359,14 +362,14 @@ def rerank_with_activation(
items_with_activation.append(item)
else:
items_without_activation.append(item)
# Sort the nodes that have activation values by activation first
sorted_with_activation = sorted(
items_with_activation,
key=lambda x: float(x.get("activation_score", 0) or 0),
reverse=True
)
# If nodes with activation values fall short of the limit, backfill with nodes without them
if len(sorted_with_activation) < limit:
needed = limit - len(sorted_with_activation)
@@ -374,7 +377,7 @@ def rerank_with_activation(
sorted_items = sorted_with_activation + items_without_activation[:needed]
else:
sorted_items = sorted_with_activation[:limit]
# Two-stage sorting done; update final_score to reflect the actual ranking basis
# Stage 1: 按 content_score 筛选候选(已完成)
# Stage 2: 按 activation_score 排序(已完成)
@@ -390,16 +393,29 @@ def rerank_with_activation(
else:
# No activation value: use the content-relevance score
item["final_score"] = item.get("base_score", 0)
# Filter low-relevance items below the threshold, then run a final dedupe to remove duplicates
if content_score_threshold > 0:
before_count = len(sorted_items)
sorted_items = [
item for item in sorted_items
if float(item.get("content_score", 0) or 0) >= content_score_threshold
]
filtered_count = before_count - len(sorted_items)
if filtered_count > 0:
logger.info(
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
f"items below content_score_threshold={content_score_threshold}"
)
sorted_items = _deduplicate_results(sorted_items)
reranked[category] = sorted_items
return reranked
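A worked example of the scoring above (all numbers are illustrative):
alpha = 0.4
bm25_norm, emb_norm = 0.9, 0.5
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm  # 0.36 + 0.30 = 0.66
# With the forgetting engine enabled:
importance, activation_value, boost = 0.5, 0.8, 0.8
memory_strength = importance * (1 + activation_value * boost)  # 0.5 * 1.64 = 0.82
# final_score = content_score * forgetting_weight(time_elapsed_days, memory_strength)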
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
log_file: str = None):
"""Log search query information using the logger.
Args:
@@ -412,7 +428,7 @@ def log_search_query(query_text: str, search_type: str, end_user_id: str | None,
"""
# Ensure the query text is plain and clean before logging
cleaned_query = extract_plain_query(query_text)
# Log using the standard logger
logger.info(
f"Search query: query='{cleaned_query}', type={search_type}, "
@@ -439,8 +455,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
def apply_reranker_placeholder(
results: Dict[str, List[Dict[str, Any]]],
query_text: str,
results: Dict[str, List[Dict[str, Any]]],
query_text: str,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Placeholder for a cross-encoder reranker.
@@ -483,7 +499,7 @@ def apply_reranker_placeholder(
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Apply LLM-based reranking to search results.
# Args:
# results: Search results organized by category
# query_text: Original search query
@@ -491,7 +507,7 @@ def apply_reranker_placeholder(
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
# top_k: Maximum number of items to rerank per category
# batch_size: Number of items to process concurrently
# Returns:
# Reranked results with final_score and reranker_model fields
# """
@@ -501,18 +517,18 @@ def apply_reranker_placeholder(
# # except Exception as e:
# # logger.debug(f"Failed to load reranker config: {e}")
# # rc = {}
# # Check if reranking is enabled
# enabled = rc.get("enabled", False)
# if not enabled:
# logger.debug("LLM reranking is disabled in configuration")
# return results
# # Load configuration parameters with defaults
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
# top_k = top_k if top_k is not None else rc.get("top_k", 20)
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
# # Initialize reranker client if not provided
# if reranker_client is None:
# try:
@@ -520,10 +536,10 @@ def apply_reranker_placeholder(
# except Exception as e:
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
# return results
# # Get model name for metadata
# model_name = getattr(reranker_client, 'model_name', 'unknown')
# # Process each category
# reranked_results = {}
# for category in ["statements", "chunks", "entities", "summaries"]:
@@ -531,38 +547,38 @@ def apply_reranker_placeholder(
# if not items:
# reranked_results[category] = []
# continue
# # Select top K items by combined_score for reranking
# sorted_items = sorted(
# items,
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
# reverse=True
# )
# top_items = sorted_items[:top_k]
# remaining_items = sorted_items[top_k:]
# # Extract text content from each item
# def extract_text(item: Dict[str, Any]) -> str:
# """Extract text content from a result item."""
# # Try different text fields based on category
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
# return str(text).strip()
# # Batch items for concurrent processing
# batches = []
# for i in range(0, len(top_items), batch_size):
# batch = top_items[i:i + batch_size]
# batches.append(batch)
# # Process batches concurrently
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# """Process a batch of items with LLM relevance scoring."""
# scored_batch = []
# for item in batch:
# item_text = extract_text(item)
# # Skip items with no text
# if not item_text:
# item_copy = item.copy()
@@ -572,7 +588,7 @@ def apply_reranker_placeholder(
# item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy)
# continue
# # Create relevance scoring prompt
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
@@ -585,15 +601,15 @@ def apply_reranker_placeholder(
# - 1.0 means perfectly relevant
# Relevance score:"""
# # Send request to LLM
# try:
# messages = [{"role": "user", "content": prompt}]
# response = await reranker_client.chat(messages)
# # Parse LLM response to extract relevance score
# response_text = str(response.content if hasattr(response, 'content') else response).strip()
# # Try to extract a float from the response
# try:
# # Remove any non-numeric characters except decimal point
@@ -608,11 +624,11 @@ def apply_reranker_placeholder(
# except (ValueError, AttributeError) as e:
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
# llm_score = None
# # Calculate final score
# item_copy = item.copy()
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
# if llm_score is not None:
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
# item_copy["llm_relevance_score"] = llm_score
@@ -620,7 +636,7 @@ def apply_reranker_placeholder(
# # Use combined_score as fallback
# final_score = combined_score
# item_copy["llm_relevance_score"] = combined_score
# item_copy["final_score"] = final_score
# item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy)
@@ -632,14 +648,14 @@ def apply_reranker_placeholder(
# item_copy["llm_relevance_score"] = combined_score
# item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy)
# return scored_batch
# # Process all batches concurrently
# try:
# batch_tasks = [process_batch(batch) for batch in batches]
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
# # Merge batch results
# scored_items = []
# for result in batch_results:
@@ -647,7 +663,7 @@ def apply_reranker_placeholder(
# logger.warning(f"Batch processing failed: {result}")
# continue
# scored_items.extend(result)
# # Add remaining items (not in top K) with their combined_score as final_score
# for item in remaining_items:
# item_copy = item.copy()
@@ -655,11 +671,11 @@ def apply_reranker_placeholder(
# item_copy["final_score"] = combined_score
# item_copy["reranker_model"] = model_name
# scored_items.append(item_copy)
# # Sort all items by final_score in descending order
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
# reranked_results[category] = scored_items
# except Exception as e:
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
# # Return original items with combined_score as final_score
@@ -668,22 +684,22 @@ def apply_reranker_placeholder(
# item["final_score"] = combined_score
# item["reranker_model"] = model_name
# reranked_results[category] = items
# return reranked_results
async def run_hybrid_search(
query_text: str,
search_type: str,
end_user_id: str | None,
limit: int,
include: List[str],
output_path: str | None,
memory_config: "MemoryConfig",
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
query_text: str,
search_type: str,
end_user_id: str | None,
limit: int,
include: List[str],
output_path: str | None,
memory_config: "MemoryConfig",
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
):
"""
@@ -699,7 +715,7 @@ async def run_hybrid_search(
# Clean and normalize the incoming query before use/logging
query_text = extract_plain_query(query_text)
# Validate query is not empty after cleaning
if not query_text or not query_text.strip():
logger.warning("Empty query after cleaning, returning empty results")
@@ -716,7 +732,7 @@ async def run_hybrid_search(
"error": "Empty query"
}
}
# Log the search query
log_search_query(query_text, search_type, end_user_id, limit, include)
@@ -747,7 +763,7 @@ async def run_hybrid_search(
# Embedding-based search
logger.info("[PERF] Starting embedding search...")
embedding_start = time.time()
# Load the embedder config from the database (by ID) and build a RedBearModelConfig
config_load_start = time.time()
try:
@@ -758,8 +774,7 @@ async def run_hybrid_search(
model_name=embedder_config_dict["model_name"],
provider=embedder_config_dict["provider"],
api_key=embedder_config_dict["api_key"],
base_url=embedder_config_dict["base_url"],
type="llm"
base_url=embedder_config_dict["base_url"]
)
config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
@@ -769,7 +784,7 @@ async def run_hybrid_search(
embedder = OpenAIEmbedderClient(model_config=rb_config)
embedder_init_time = time.time() - embedder_init_start
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
embedding_task = asyncio.create_task(
search_graph_by_embedding(
connector=connector,
@@ -789,7 +804,7 @@ async def run_hybrid_search(
if keyword_task:
keyword_results = await keyword_task
keyword_latency = time.time() - keyword_start
keyword_latency = time.time() - search_start_time
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword":
@@ -799,7 +814,7 @@ async def run_hybrid_search(
if embedding_task:
embedding_results = await embedding_task
embedding_latency = time.time() - embedding_start
embedding_latency = time.time() - search_start_time
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding":
@@ -811,7 +826,8 @@ async def run_hybrid_search(
if search_type == "hybrid":
results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"search_query": query_text,
"search_timestamp": datetime.now().isoformat()
}
@@ -819,7 +835,7 @@ async def run_hybrid_search(
# Apply two-stage reranking with ACTR activation calculation
rerank_start = time.time()
logger.info("[PERF] Using two-stage reranking with ACTR activation")
# Load the forgetting-engine configuration
config_start = time.time()
try:
@@ -830,7 +846,7 @@ async def run_hybrid_search(
forgetting_cfg = ForgettingEngineConfig()
config_time = time.time() - config_start
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
# Always use activation-based reranking (two stages: retrieval + ACT-R computation)
rerank_compute_start = time.time()
reranked_results = rerank_with_activation(
@@ -843,14 +859,14 @@ async def run_hybrid_search(
)
rerank_compute_time = time.time() - rerank_compute_start
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
rerank_latency = time.time() - rerank_start
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
# Optional: apply reranker placeholder if enabled via config
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
# Apply LLM reranking if enabled
llm_rerank_applied = False
# if use_llm_rerank:
@@ -863,11 +879,12 @@ async def run_hybrid_search(
# logger.info("LLM reranking applied successfully")
# except Exception as e:
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
results["reranked_results"] = reranked_results
results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
"search_query": query_text,
"search_timestamp": datetime.now().isoformat(),
@@ -880,13 +897,13 @@ async def run_hybrid_search(
# Calculate total latency
total_latency = time.time() - search_start_time
latency_metrics["total_latency"] = round(total_latency, 4)
# Add latency metrics to results
if "combined_summary" in results:
results["combined_summary"]["latency_metrics"] = latency_metrics
else:
results["latency_metrics"] = latency_metrics
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
@@ -909,8 +926,10 @@ async def run_hybrid_search(
# Log search completion with result count
if search_type == "hybrid":
result_counts = {
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
embedding_results.items()}
}
else:
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
@@ -928,12 +947,12 @@ async def run_hybrid_search(
async def search_by_temporal(
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
):
"""
Temporal search across Statements.
@@ -969,13 +988,13 @@ async def search_by_temporal(
async def search_by_keyword_temporal(
query_text: str,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
query_text: str,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
):
"""
Temporal keyword search across Statements.
@@ -1012,9 +1031,9 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id(
chunk_id: str,
end_user_id: Optional[str] = "test",
limit: int = 1,
chunk_id: str,
end_user_id: Optional[str] = "test",
limit: int = 1,
):
"""
Search for Chunks by chunk_id.
@@ -1027,4 +1046,3 @@ async def search_chunk_by_chunk_id(
limit=limit
)
return {"chunks": chunks}

View File

@@ -4,6 +4,7 @@
import asyncio
import difflib  # string-similarity utilities
import importlib
import logging
import os
import re
from datetime import datetime
@@ -16,6 +17,8 @@ from app.core.memory.models.graph_models import (
)
from app.core.memory.models.variate_config import DedupConfig
logger = logging.getLogger(__name__)
# Module-level helpers for unifying entity types
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
@@ -198,6 +201,161 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
except Exception:
pass
# Placeholder-name sets for the user and the AI assistant, used for name normalization
_USER_PLACEHOLDER_NAMES = {"用户", "我", "user", "i"}
_ASSISTANT_PLACEHOLDER_NAMES = {"ai助手", "助手", "人工智能助手", "智能助手", "智能体", "ai assistant", "assistant"}
# 标准化后的规范名称和类型
_CANONICAL_USER_NAME = "用户"
_CANONICAL_USER_TYPE = "用户"
_CANONICAL_ASSISTANT_NAME = "AI助手"
_CANONICAL_ASSISTANT_TYPE = "Agent"
# 用户和AI助手的所有可能名称,用于判断实体是否为特殊角色实体
_ALL_USER_NAMES = _USER_PLACEHOLDER_NAMES
_ALL_ASSISTANT_NAMES = _ASSISTANT_PLACEHOLDER_NAMES
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
"""判断实体是否为用户实体name 或 entity_type 匹配)"""
name = (getattr(ent, "name", "") or "").strip().lower()
etype = (getattr(ent, "entity_type", "") or "").strip()
return name in _ALL_USER_NAMES or etype == _CANONICAL_USER_TYPE
def _is_assistant_entity(ent: ExtractedEntityNode) -> bool:
"""判断实体是否为AI助手实体name 或 entity_type 匹配)"""
name = (getattr(ent, "name", "") or "").strip().lower()
etype = (getattr(ent, "entity_type", "") or "").strip()
return name in _ALL_ASSISTANT_NAMES or etype == _CANONICAL_ASSISTANT_TYPE
def _would_merge_cross_role(a: ExtractedEntityNode, b: ExtractedEntityNode) -> bool:
"""判断两个实体的合并是否会跨越用户/AI助手角色边界。
用户实体和AI助手实体永远不应该被合并在一起。
如果一方是用户实体、另一方是AI助手实体,返回 True(阻止合并)。
"""
return (
(_is_user_entity(a) and _is_assistant_entity(b))
or (_is_assistant_entity(a) and _is_user_entity(b))
)
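
A minimal sketch of the guard's behavior, assuming the fields read above suffice to construct a node (the real `ExtractedEntityNode` model may require more fields):

```python
from app.core.memory.models.graph_models import ExtractedEntityNode

# Hypothetical construction — only the fields the guard reads are shown.
user_ent = ExtractedEntityNode(id="e1", name="用户", entity_type="用户", aliases=["陈思远"])
ai_ent = ExtractedEntityNode(id="e2", name="AI助手", entity_type="Agent", aliases=["远仔"])

assert _would_merge_cross_role(user_ent, ai_ent) is True     # user vs assistant: blocked
assert _would_merge_cross_role(user_ent, user_ent) is False  # same role: allowed
```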
def _normalize_special_entity_names(
entity_nodes: List[ExtractedEntityNode],
) -> None:
"""标准化用户和AI助手实体的名称和类型。
多轮对话中,LLM 对同一角色可能使用不同的名称变体(如"用户"/"我"/"User"、
"AI助手"/"助手"/"Assistant"),导致精确匹配无法合并。
此函数在去重前将这些变体统一为规范名称,并强制绑定 entity_type,确保:
- name="用户" 的实体 entity_type 一定为 "用户"
- name="AI助手" 的实体 entity_type 一定为 "Agent"
Args:
entity_nodes: 实体节点列表(原地修改)
"""
for ent in entity_nodes:
name = (getattr(ent, "name", "") or "").strip()
name_lower = name.lower()
if name_lower in _USER_PLACEHOLDER_NAMES:
ent.name = _CANONICAL_USER_NAME
ent.entity_type = _CANONICAL_USER_TYPE
elif name_lower in _ASSISTANT_PLACEHOLDER_NAMES:
ent.name = _CANONICAL_ASSISTANT_NAME
ent.entity_type = _CANONICAL_ASSISTANT_TYPE
# 第二步:清洗用户/AI助手之间的别名交叉污染(复用 clean_cross_role_aliases)
clean_cross_role_aliases(entity_nodes)
async def fetch_neo4j_assistant_aliases(neo4j_connector, end_user_id: str) -> set:
"""从 Neo4j 查询 AI 助手实体的所有别名(小写归一化)。
这是助手别名查询的唯一入口,供 write_tools 和 extraction_orchestrator 共用,
避免多处维护相同的 Cypher 和名称列表。
Args:
neo4j_connector: Neo4j 连接器实例(需提供 execute_query 方法)
end_user_id: 终端用户 ID
Returns:
小写归一化后的助手别名集合
"""
# 查询名称列表:规范名称 + 常见变体(与 _normalize_special_entity_names 标准化后一致)
query_names = [_CANONICAL_ASSISTANT_NAME, *_ASSISTANT_PLACEHOLDER_NAMES]
# 去重保序
query_names = list(dict.fromkeys(query_names))
cypher = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND e.name IN $names
RETURN e.aliases AS aliases
"""
try:
result = await neo4j_connector.execute_query(
cypher, end_user_id=end_user_id, names=query_names
)
assistant_aliases: set = set()
for record in (result or []):
for alias in (record.get("aliases") or []):
assistant_aliases.add(alias.strip().lower())
if assistant_aliases:
logger.debug(f"Neo4j 中 AI 助手别名: {assistant_aliases}")
return assistant_aliases
except Exception as e:
logger.warning(f"查询 Neo4j AI 助手别名失败: {e}")
return set()
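
A usage sketch with a stub connector standing in for the real Neo4j client, showing the strip/lower normalization of the returned aliases:

```python
import asyncio

class FakeConnector:
    """Stub exposing the same execute_query surface the helper expects."""
    async def execute_query(self, cypher, **params):
        return [{"aliases": ["远仔", " Jarvis "]}]

aliases = asyncio.run(fetch_neo4j_assistant_aliases(FakeConnector(), end_user_id="u-1"))
print(aliases)  # {'远仔', 'jarvis'} — stripped and lower-cased
```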
def clean_cross_role_aliases(
entity_nodes: List[ExtractedEntityNode],
external_assistant_aliases: set = None,
) -> None:
"""清洗用户实体和AI助手实体之间的别名交叉污染。
在 Neo4j 写入前调用,确保:
- 用户实体的 aliases 不包含 AI 助手的别名
- AI 助手实体的 aliases 不包含用户的别名
Args:
entity_nodes: 实体节点列表(原地修改)
external_assistant_aliases: 外部传入的 AI 助手别名集合(如从 Neo4j 查询),
与本轮实体中的 AI 助手别名合并使用
"""
# 收集本轮 AI 助手实体的所有别名
assistant_aliases = set(external_assistant_aliases or set())
user_aliases = set()
for ent in entity_nodes:
if _is_assistant_entity(ent):
for alias in (getattr(ent, "aliases", []) or []):
assistant_aliases.add(alias.strip().lower())
elif _is_user_entity(ent):
for alias in (getattr(ent, "aliases", []) or []):
user_aliases.add(alias.strip().lower())
# 从用户实体的 aliases 中移除 AI 助手别名
if assistant_aliases:
for ent in entity_nodes:
if _is_user_entity(ent):
original = getattr(ent, "aliases", []) or []
cleaned = [a for a in original if a.strip().lower() not in assistant_aliases]
if len(cleaned) < len(original):
ent.aliases = cleaned
# 从 AI 助手实体的 aliases 中移除用户别名
if user_aliases:
for ent in entity_nodes:
if _is_assistant_entity(ent):
original = getattr(ent, "aliases", []) or []
cleaned = [a for a in original if a.strip().lower() not in user_aliases]
if len(cleaned) < len(original):
ent.aliases = cleaned
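
A worked example of the in-place cleaning (node construction is hypothetical, as above): an assistant alias that leaked into the user entity is removed, while legitimate user aliases survive:

```python
from app.core.memory.models.graph_models import ExtractedEntityNode

user_ent = ExtractedEntityNode(id="e1", name="用户", entity_type="用户",
                               aliases=["陈思远", "远仔"])  # "远仔" leaked in
ai_ent = ExtractedEntityNode(id="e2", name="AI助手", entity_type="Agent",
                             aliases=["远仔"])

clean_cross_role_aliases([user_ent, ai_ent])
print(user_ent.aliases)  # ['陈思远'] — the assistant alias was stripped in place
```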
def accurate_match(
entity_nodes: List[ExtractedEntityNode]
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
@@ -261,6 +419,10 @@ def accurate_match(
canonical = alias_index.get((ent_uid, ent_name))
# 确保不是自身
if canonical is not None and canonical.id != ent.id:
# 保护:禁止跨角色合并,用户实体和AI助手实体不能互相合并
if _would_merge_cross_role(canonical, ent):
i += 1
continue
_merge_attribute(canonical, ent)
id_redirect[ent.id] = canonical.id
for k, v in list(id_redirect.items()):
@@ -704,6 +866,11 @@ def fuzzy_match(
# 条件A(快速通道):alias_match_merge = True
# 条件B(标准通道):s_name ≥ tn AND s_type ≥ type_threshold AND overall ≥ tover
if alias_match_merge or (s_name >= tn and s_type >= type_threshold and overall >= tover):
# 保护:禁止跨角色合并,用户实体和AI助手实体不能互相合并
if _would_merge_cross_role(a, b):
j += 1
continue
# ========== 第六步:执行实体合并 ==========
# 6.1 合并别名
@@ -813,6 +980,12 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
b = entity_by_id.get(losing_id)
if not a or not b: # 若不存在 a 或 b(可能已在精确或模糊阶段合并),之前阶段合并过的此处不再处理,但出于审计目的会记录
continue
# 保护:禁止跨角色合并,用户实体和AI助手实体不能互相合并
if _would_merge_cross_role(a, b):
llm_records.append(
f"[LLM阻断] 跨角色合并被阻止: {a.id} ({a.name}) 与 {b.id} ({b.name})"
)
continue
_merge_attribute(a, b)
# ID 重定向
try:
@@ -934,6 +1107,9 @@ async def deduplicate_entities_and_edges(
返回:去重后的实体、语句→实体边、实体↔实体边。
"""
local_llm_records: List[str] = [] # 作为"审计日志"的本地收集器,初始化保留,用于之后对 LLM 决策的追溯
# 0) 标准化用户和AI助手实体名称,确保多轮对话中的变体名称统一
_normalize_special_entity_names(entity_nodes)
# 1) 精确匹配
deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes)
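
A sketch of why step 0 runs before step 1: once the placeholder variants are normalized, the exact-match pass sees them as same-name duplicates (construction is hypothetical; `accurate_match` additionally keys on uid, elided here):

```python
from app.core.memory.models.graph_models import ExtractedEntityNode

ents = [
    ExtractedEntityNode(id="e1", name="我", entity_type="Person", aliases=["vv"]),
    ExtractedEntityNode(id="e2", name="User", entity_type="Person", aliases=[]),
]
_normalize_special_entity_names(ents)
print([(e.name, e.entity_type) for e in ents])
# [('用户', '用户'), ('用户', '用户')] — now mergeable by the exact-match pass
deduped, id_redirect, merge_map = accurate_match(ents)
```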

View File

@@ -15,6 +15,7 @@ from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
clean_cross_role_aliases,
)
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
second_layer_dedup_and_merge_with_neo4j,
@@ -100,6 +101,10 @@ async def dedup_layers_and_merge_and_return(
except Exception as e:
print(f"Second-layer dedup failed: {e}")
# 第二层去重后,清洗用户/AI助手之间的别名交叉污染
# 第二层从 Neo4j 合并了旧实体,可能带入历史脏数据
clean_cross_role_aliases(fused_entity_nodes)
return (
dialogue_nodes,
chunk_nodes,

View File

@@ -44,6 +44,10 @@ from app.core.memory.models.variate_config import (
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return,
)
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
_USER_PLACEHOLDER_NAMES,
fetch_neo4j_assistant_aliases,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation,
generate_entity_embeddings_from_triplets,
@@ -1341,14 +1345,20 @@ class ExtractionOrchestrator:
dialog_data_list: List[DialogData]
) -> None:
"""
从 Neo4j 读取用户实体的最终 aliases,同步到 end_user 和 end_user_info 表
将本轮提取的用户别名同步到 end_user 和 end_user_info 表
注意:
1. other_name 使用本次对话提取的第一个别名(保持时间顺序)
2. aliases 从 Neo4j 读取(保持完整性)
注意:此方法在 Neo4j 写入之前调用,因此不能依赖 Neo4j 作为别名的权威数据源。
改为直接使用内存中去重后的 entity_nodes 的 aliases,与 PgSQL 已有的 aliases 合并。
策略:
1. 从内存中的 entity_nodes 提取本轮用户别名(current_aliases)
2. 从去重后的 entity_nodes 中提取完整别名(含 Neo4j 二层去重合并的历史别名)
3. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases)
4. 合并 db_aliases + deduped_aliases + current_aliases,去重保序
5. 写回 PgSQL
Args:
entity_nodes: 实体节点列表
entity_nodes: 去重后的实体节点列表(内存中,含二层去重合并结果)
dialog_data_list: 对话数据列表
"""
try:
@@ -1361,23 +1371,40 @@ class ExtractionOrchestrator:
logger.warning("end_user_id 为空,跳过用户别名同步")
return
# 1. 提取本对话的用户别名(保持 LLM 提取的原始顺序,不排序)
current_aliases = self._extract_current_aliases(entity_nodes)
# 1. 提取本对话的用户别名(保持 LLM 提取的原始顺序,不排序)
current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list)
# 2. 从 Neo4j 获取完整 aliases(权威数据源)
neo4j_aliases = await self._fetch_neo4j_user_aliases(end_user_id)
# 1.5 从去重后的 entity_nodes 中提取完整别名
# 二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 中,
# 这里提取出来确保 PgSQL 与 Neo4j 的别名保持同步
deduped_aliases = self._extract_deduped_entity_aliases(entity_nodes)
if not neo4j_aliases:
# Neo4j 中没有别名,使用本次对话提取的别名
neo4j_aliases = current_aliases
if not neo4j_aliases:
logger.debug(f"aliases 为空,跳过同步: end_user_id={end_user_id}")
return
# 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源
# (防止 LLM 未提取出 AI 助手实体时AI 别名泄漏到用户别名中)
neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id)
if neo4j_assistant_aliases:
before_count = len(current_aliases)
current_aliases = [
a for a in current_aliases
if a.strip().lower() not in neo4j_assistant_aliases
]
if len(current_aliases) < before_count:
logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名")
# 同样过滤 deduped_aliases
deduped_aliases = [
a for a in deduped_aliases
if a.strip().lower() not in neo4j_assistant_aliases
]
logger.info(f"本次对话提取的 aliases: {current_aliases}")
logger.info(f"Neo4j 中的完整 aliases: {neo4j_aliases}")
if not current_aliases and not deduped_aliases:
logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}")
return
# 3. 同步到数据库
logger.info(f"本轮对话提取的 aliases: {current_aliases}")
if deduped_aliases:
logger.info(f"去重后实体的完整 aliases含历史: {deduped_aliases}")
# 2. 同步到数据库
end_user_uuid = uuid.UUID(end_user_id)
with get_db_context() as db:
# 更新 end_user 表
@@ -1386,7 +1413,38 @@ class ExtractionOrchestrator:
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
return
new_name = self._resolve_other_name(end_user.other_name, current_aliases, neo4j_aliases)
# 3. 从 PgSQL 读取已有 aliases 并与本轮合并
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
db_aliases = (info.aliases if info and info.aliases else [])
# 过滤掉占位名称
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
# 合并:已有 + 去重后完整别名 + 本轮新增,去重保序
merged_aliases = list(db_aliases)
seen_lower = {a.strip().lower() for a in merged_aliases}
# 先合并去重后实体的完整别名(含 Neo4j 历史别名)
for alias in deduped_aliases:
if alias.strip().lower() not in seen_lower:
merged_aliases.append(alias)
seen_lower.add(alias.strip().lower())
# 再合并本轮新提取的别名
for alias in current_aliases:
if alias.strip().lower() not in seen_lower:
merged_aliases.append(alias)
seen_lower.add(alias.strip().lower())
# 最终过滤:从合并结果中排除 AI 助手别名(清理历史脏数据)
if neo4j_assistant_aliases:
merged_aliases = [
a for a in merged_aliases
if a.strip().lower() not in neo4j_assistant_aliases
]
logger.info(f"PgSQL 已有 aliases: {db_aliases}")
logger.info(f"合并后 aliases: {merged_aliases}")
# 更新 end_user 表 other_name
new_name = self._resolve_other_name(end_user.other_name, current_aliases, merged_aliases)
if new_name is not None:
end_user.other_name = new_name
logger.info(f"更新 end_user 表 other_name → {new_name}")
@@ -1394,26 +1452,27 @@ class ExtractionOrchestrator:
logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}")
# 更新或创建 end_user_info 记录
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
if info:
new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases)
new_name_info = self._resolve_other_name(info.other_name, current_aliases, merged_aliases)
if new_name_info is not None:
info.other_name = new_name_info
logger.info(f"更新 end_user_info 表 other_name → {new_name_info}")
if info.aliases != neo4j_aliases:
info.aliases = neo4j_aliases
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
if info.aliases != merged_aliases:
info.aliases = merged_aliases
logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}")
else:
first_alias = current_aliases[0].strip() if current_aliases else ""
first_alias = current_aliases[0].strip() if current_aliases else (
deduped_aliases[0].strip() if deduped_aliases else ""
)
# 确保 first_alias 不是占位名称
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
db.add(EndUserInfo(
end_user_id=end_user_uuid,
other_name=first_alias,
aliases=neo4j_aliases,
aliases=merged_aliases,
meta_data={}
))
logger.info(f"创建 end_user_info 记录other_name={first_alias}, aliases={neo4j_aliases}")
logger.info(f"创建 end_user_info 记录other_name={first_alias}, aliases={merged_aliases}")
db.commit()
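
The merge policy of this hunk — PgSQL aliases first, then the deduped history, then this round's new aliases, deduplicated case-insensitively while preserving order — condensed into a standalone sketch (the helper name `merge_aliases` is ours, not from the codebase):

```python
def merge_aliases(db_aliases, deduped_aliases, current_aliases):
    """Order-preserving, case-insensitive union: DB → history → new."""
    merged, seen = [], set()
    for alias in [*db_aliases, *deduped_aliases, *current_aliases]:
        key = alias.strip().lower()
        if key not in seen:
            merged.append(alias)
            seen.add(key)
    return merged

print(merge_aliases(["陈思远"], ["陈思远", "老陈"], ["小陈"]))
# ['陈思远', '老陈', '小陈']
```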
@@ -1423,49 +1482,81 @@ class ExtractionOrchestrator:
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'}
# 复用 deduped_and_disamb 模块级常量,避免重复维护
USER_PLACEHOLDER_NAMES = _USER_PLACEHOLDER_NAMES
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
"""实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode], dialog_data_list=None) -> List[str]:
"""用户发言的原始实体中提取本轮新增别名(绕过去重污染
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户""我""User""I")。
第一个别名将被用作 other_name。
策略:
仅从 dialog_data_list 中找到 speaker="user" 的 statement
从这些 statement 的 triplet_extraction_info 中提取用户实体的 aliases。
这样拿到的是 LLM 对用户原话的提取结果,不受去重合并的影响。
注意:不再使用去重后 entity_nodes 作为兜底,因为二层去重会将 Neo4j 历史别名
合并进来,导致历史别名被误认为"本轮提取"。历史别名的同步由
_extract_deduped_entity_aliases 负责。
Args:
entity_nodes: 实体节点列表
entity_nodes: 去重后的实体节点列表(未使用,保留参数兼容性)
dialog_data_list: 对话数据列表
Returns:
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称)
别名列表(保持原始顺序,已过滤)
"""
if not dialog_data_list:
return []
all_user_aliases = []
seen_lower = set()
for dialog in dialog_data_list:
for chunk in dialog.chunks:
speaker = getattr(chunk, 'speaker', None)
for statement in chunk.statements:
stmt_speaker = getattr(statement, 'speaker', None) or speaker
if stmt_speaker != "user":
continue
triplet_info = getattr(statement, 'triplet_extraction_info', None)
if not triplet_info:
continue
for entity in (triplet_info.entities or []):
ent_name = getattr(entity, 'name', '').strip()
if ent_name.lower() in self.USER_PLACEHOLDER_NAMES:
for alias in (getattr(entity, 'aliases', []) or []):
a = alias.strip()
if a and a.lower() not in self.USER_PLACEHOLDER_NAMES and a.lower() not in seen_lower:
all_user_aliases.append(a)
seen_lower.add(a.lower())
if all_user_aliases:
logger.debug(f"从用户原始发言提取到别名: {all_user_aliases}")
return all_user_aliases
def _extract_deduped_entity_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
"""从去重后的用户实体中提取完整别名列表。
二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 的用户实体中,
因此这里提取到的别名包含了历史积累的所有别名,可用于同步到 PgSQL。
Args:
entity_nodes: 去重后的实体节点列表(含二层去重合并结果)
Returns:
别名列表(已过滤占位名称,去重保序)
"""
for entity in entity_nodes:
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
if getattr(entity, 'name', '').strip().lower() in self.USER_PLACEHOLDER_NAMES:
aliases = getattr(entity, 'aliases', []) or []
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
return filtered
filtered = [
a for a in aliases
if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES
]
if filtered:
return filtered
return []
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
cypher = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I']
RETURN e.aliases AS aliases
LIMIT 1
"""
result = await Neo4jConnector().execute_query(cypher, end_user_id=end_user_id)
if not result:
logger.debug(f"Neo4j 中未找到用户实体: end_user_id={end_user_id}")
return []
aliases = result[0].get('aliases') or []
if not aliases:
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
return []
# 过滤掉占位名称,防止历史脏数据传播
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
return filtered
async def _fetch_neo4j_assistant_aliases(self, end_user_id: str) -> set:
"""从 Neo4j 查询 AI 助手实体的所有别名(用于从用户别名中排除)"""
return await fetch_neo4j_assistant_aliases(self.connector, end_user_id)
def _resolve_other_name(
self,
@@ -1484,16 +1575,16 @@ class ExtractionOrchestrator:
注意:返回值不允许是占位名称("用户""我""User""I")
"""
# 当前值为空或为占位名称时,需要更新
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
if not current or not current.strip() or current.strip().lower() in self.USER_PLACEHOLDER_NAMES:
candidate = current_aliases[0].strip() if current_aliases else None
# 确保候选值不是占位名称
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
return None
return candidate
if current not in neo4j_aliases:
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
# 确保候选值不是占位名称
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
return None
return candidate
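
The resolution rules restated as a pure function — a sketch reconstructed from the visible hunk (the real method is an instance method, and the unchanged tail of its body is not shown in this diff):

```python
PLACEHOLDERS = {"用户", "我", "user", "i"}  # mirrors USER_PLACEHOLDER_NAMES

def resolve_other_name(current, current_aliases, merged_aliases):
    # Empty or placeholder current name → replace with the first current alias.
    if not current or not current.strip() or current.strip().lower() in PLACEHOLDERS:
        candidate = current_aliases[0].strip() if current_aliases else None
        return None if candidate and candidate.lower() in PLACEHOLDERS else candidate
    # Current name no longer in the merged aliases → fall back to the first alias.
    if current not in merged_aliases:
        candidate = merged_aliases[0].strip() if merged_aliases else None
        return None if candidate and candidate.lower() in PLACEHOLDERS else candidate
    return None  # keep the existing name

print(resolve_other_name("用户", ["陈思远"], ["陈思远"]))  # 陈思远 (placeholder replaced)
print(resolve_other_name("陈思远", [], ["陈思远"]))        # None (name unchanged)
```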

View File

@@ -61,6 +61,7 @@ class TripletExtractor:
predicate_instructions=PREDICATE_DEFINITIONS,
language=self._get_language(),
ontology_types=self.ontology_types,
speaker=getattr(statement, 'speaker', None),
)
# Create messages for LLM

View File

@@ -1,6 +1,6 @@
import os
from jinja2 import Environment, FileSystemLoader
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering
# Setup Jinja2 environment
@@ -205,6 +205,7 @@ async def render_triplet_extraction_prompt(
predicate_instructions: dict = None,
language: str = "zh",
ontology_types: "OntologyTypeList | None" = None,
speaker: str = None,
) -> str:
"""
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
@@ -216,6 +217,7 @@ async def render_triplet_extraction_prompt(
predicate_instructions: Optional predicate instructions
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification
speaker: Speaker role ("user" or "assistant") for the current statement
Returns:
Rendered prompt content as string
@@ -223,7 +225,7 @@ async def render_triplet_extraction_prompt(
template = prompt_env.get_template("extract_triplet.jinja2")
# 准备本体类型数据
ontology_type_section = ""
ontology_type_section = None
ontology_type_names = []
type_hierarchy_hints = []
if ontology_types and ontology_types.types:
@@ -240,6 +242,7 @@ async def render_triplet_extraction_prompt(
ontology_types=ontology_type_section,
ontology_type_names=ontology_type_names,
type_hierarchy_hints=type_hierarchy_hints,
speaker=speaker,
)
# 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('triplet extraction', rendered_prompt)
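
A hedged sketch of the new `speaker` parameter in action — the `chunk_content`/`statement` keyword names are inferred from the template variables, and the import path is illustrative only:

```python
import asyncio
# from app.core.memory.utils.prompt_rendering import render_triplet_extraction_prompt  # path assumed

prompt = asyncio.run(render_triplet_extraction_prompt(
    chunk_content="...",
    statement="好的,VV,今天想干点啥",
    language="zh",
    speaker="assistant",  # new parameter: enables the assistant-perspective warning block
))
assert "AI助手的回复" in prompt
```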

View File

@@ -23,6 +23,16 @@ Extract entities and knowledge triplets from the given statement.
===Inputs===
**Chunk Content:** "{{ chunk_content }}"
**Statement:** "{{ statement }}"
{% if speaker %}
**Speaker:** {{ speaker }}
{% if speaker == "assistant" %}
{% if language == "zh" %}
⚠️ 当前陈述句来自 **AI助手的回复**。AI助手在回复中用来称呼用户的名字是**用户的别名**,不是 AI 助手的别名。但只能提取原文中逐字出现的名字,严禁推测或创造原文中不存在的别名变体。
{% else %}
⚠️ This statement is from the **AI assistant's reply**. Names the AI uses to address the user are **user's aliases**, NOT the AI assistant's aliases. But only extract names that appear VERBATIM in the text — never infer or fabricate alias variants.
{% endif %}
{% endif %}
{% endif %}
{% if ontology_types %}
===Ontology Type Guidance===
@@ -87,7 +97,17 @@ Extract entities and knowledge triplets from the given statement.
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name
* "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name
- 空值:如果没有别名,使用 `[]`
- 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字
- **🚨🚨🚨 严禁幻觉:只提取对话原文中逐字出现的别名,绝对不能推测、衍生或创造任何未在原文中出现的名字。例如,看到"陈思远"不能自行添加"思远大人""远哥""小远"等变体。如果原文没有这些字,就不能出现在 aliases 中。**
- **🚨 归属区分:必须严格区分名称的归属对象。默认情况下,用户提到的名字归属用户实体。只有出现明确的第二人称命名表达(如"叫你""给你取名")时,才将名字归属 AI/助手实体。**
- **🚨 说话人视角:当 speaker 为 assistant 时AI 助手用来称呼用户的名字是用户的别名,必须归入用户实体的 aliases绝对不能归入 AI 助手实体。但同样只能提取原文中逐字出现的称呼,不能推测。**
* "我叫陈思远我给AI取名为远仔" → 用户 aliases=["陈思远"]AI助手 aliases=["远仔"]
* "我叫vv" → 用户 aliases=["vv"]没有给AI取名的表达名字归用户
* [speaker=assistant] "好的VV" → 用户 aliases=["VV"]AI 在称呼用户,原文中出现了"VV"
* [speaker=assistant] "我叫陈仔" → AI助手 aliases=["陈仔"]AI 在自我介绍,这是 AI 的别名)
* ❌ 错误:将"远仔"放入用户的 aliases"远仔"是给AI取的名字不是用户的名字
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases
* ❌ 错误AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
{% else %}
- Include: nicknames, full names, abbreviations, alternative names
- Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST**
@@ -96,7 +116,17 @@ Extract entities and knowledge triplets from the given statement.
* "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name)
* "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
- Empty: If no aliases, use `[]`
- Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names
- **🚨🚨🚨 NO HALLUCINATION: Only extract aliases that appear VERBATIM in the original text. NEVER infer, derive, or fabricate names not present in the text. For example, seeing "John Smith" does NOT allow adding "Johnny", "Smithy", "Mr. Smith" unless those exact strings appear in the conversation.**
- **🚨 Ownership distinction: By default, all names mentioned by the user belong to the user entity. Only assign a name to the AI/assistant entity when an explicit second-person naming expression (e.g., "I'll call you", "your name is") is present.**
- **🚨 Speaker perspective: When speaker is "assistant", names the AI uses to address the user are the USER's aliases and MUST go into the user entity's aliases, NEVER into the AI assistant entity's aliases. But only extract names that appear verbatim in the text, never infer.**
* "I'm Alex, I'll call you Buddy" → User aliases=["Alex"], AI assistant aliases=["Buddy"]
* "I'm vv" → User aliases=["vv"] (no AI-naming expression, name belongs to user)
* [speaker=assistant] "Sure thing, VV" → User aliases=["VV"] (AI addressing the user, "VV" appears in text)
* [speaker=assistant] "I'm Jarvis" → AI assistant aliases=["Jarvis"] (AI self-introduction, this is AI's alias)
* ❌ Wrong: putting "Buddy" in user's aliases ("Buddy" is a name for the AI, not the user)
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
{% endif %}
@@ -122,7 +152,60 @@ Extract entities and knowledge triplets from the given statement.
4. **ALIASES ORDER:**
4. **AI/ASSISTANT ENTITY SPECIAL HANDLING:**
{% if language == "zh" %}
- **🚨 默认规则:如果对话中没有出现明确指向 AI/助手的命名表达,则所有名字都归属于用户实体。不要猜测或推断某个名字是给 AI 取的。**
- 只有当用户**明确**对 AI/助手进行命名时,才创建 AI/助手实体并将对应名字放入其 aliases
- AI/助手实体的 name 字段:使用 "AI助手"
- 用户给 AI 取的名字:放入 AI/助手实体的 aliases
- **🚨 禁止将用户给 AI 取的名字放入用户实体的 aliases 中**
- **必须出现以下明确的命名表达才能判定为给 AI 取名:**「给你取名」「叫你」「称呼你为」「给AI取名」「你的名字是」「以后叫你」「你就叫」「你不叫X了」「你现在叫」等**第二人称(你)或明确指向 AI 的命名句式**
- **🚨 "你不叫X了"/"你不叫X你叫Y" 句式X 和 Y 都是 AI 的名字(旧名和新名),绝对不是用户的名字。因为句子主语是"你"AI。**
- **以下情况名字归属用户,不是给 AI 取名:**「我叫」「我的名字是」「叫我」「我是」「大家叫我」「我的英文名是」「我的昵称是」等**第一人称(我)的自我介绍句式**
- **🚨 speaker=assistant 时的特殊规则:**
* AI 用来称呼用户的名字 → 归入**用户**实体的 aliases(但必须是原文中逐字出现的称呼,不能推测)
* AI 自称的名字(如"我叫陈仔""我是你的助手")→ 归入**AI助手**实体的 aliases
* 判断依据:AI 说"你叫X"或用 X 称呼用户 → X 是用户别名;AI 说"我叫X"或"我是X" → X 是 AI 别名
- 示例:
* "我叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
* "我的英文名叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
* "我叫陈思远我给AI取名为远仔" → 用户实体: name="用户", aliases=["陈思远"]AI实体: name="AI助手", aliases=["远仔"]
* "叫你小助,我自己叫老王" → 用户实体: name="用户", aliases=["老王"]AI实体: name="AI助手", aliases=["小助"]
* "你不叫远仔了,你现在叫陈仔" → AI实体: name="AI助手", aliases=["陈仔"]"远仔"是AI旧名"陈仔"是AI新名都归AI。不要把"远仔"或"陈仔"放入用户的aliases
* [speaker=assistant] "好的VV今天想干点啥" → 用户实体: name="用户", aliases=["VV"]AI 在称呼用户,原文中出现了"VV"
* [speaker=assistant] "你叫陈思远,我叫陈仔" → 用户实体: name="用户", aliases=["陈思远"]AI实体: name="AI助手", aliases=["陈仔"]
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases没有任何给 AI 取名的表达)
* ❌ 错误AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
* ❌ 错误aliases=["陈思远", "远仔"]"远仔"是给AI取的名字不是用户的名字
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
{% else %}
- **🚨 Default rule: If there is NO explicit AI/assistant naming expression in the conversation, ALL names belong to the user entity. Do NOT guess or infer that a name is for the AI.**
- Only create an AI/assistant entity when the user **explicitly** names the AI/assistant
- AI/assistant entity name field: use "AI Assistant"
- Names the user gives to the AI: put in the AI/assistant entity's aliases
- **🚨 NEVER put names given to the AI into the user entity's aliases**
- **An AI-naming expression MUST be present to assign a name to the AI:** "I'll call you", "your name is", "I name you", "let me call you", "you'll be called", "you're not called X anymore", "your new name is", etc. — **second-person ("you") or explicit AI-directed naming patterns**
- **🚨 "You're not called X anymore" / "You're not X, you're Y" pattern: BOTH X and Y are AI's names (old and new). They are NOT user's names. The subject is "you" (the AI).**
- **These patterns mean the name belongs to the USER, NOT the AI:** "I'm", "my name is", "call me", "I am", "people call me", "my English name is", "my nickname is", etc. — **first-person ("I"/"me") self-introduction patterns**
- **🚨 Special rules when speaker=assistant:**
* Names the AI uses to address the user → belong to the **user** entity's aliases (but only extract names that appear verbatim in the text, never infer)
* Names the AI uses for itself (e.g., "I'm Jarvis", "I am your assistant") → belong to the **AI assistant** entity's aliases
* Rule: AI says "you are X" or calls user X → X is user's alias; AI says "I'm X" or "I am X" → X is AI's alias
- Examples:
* "I'm vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
* "My English name is vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
* "I'm Alex, I'll call you Buddy" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Buddy"]
* "Call yourself Jarvis, my name is Tony" → User entity: name="User", aliases=["Tony"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
* "You're not called Jarvis anymore, your new name is Friday" → AI entity: name="AI Assistant", aliases=["Friday"] (both "Jarvis" and "Friday" are AI names, NOT user names)
* [speaker=assistant] "Sure thing, VV" → User entity: name="User", aliases=["VV"] (AI addressing the user, "VV" appears in text)
* [speaker=assistant] "You're Alex, and I'm Jarvis" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases (no AI-naming expression exists)
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
* ❌ Wrong: aliases=["Alex", "Buddy"] ("Buddy" is a name for the AI, not the user)
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
{% endif %}
5. **ALIASES ORDER:**
{% if language == "zh" %}
- 顺序优先级:按出现顺序,先出现的在前
{% else %}
@@ -202,8 +285,19 @@ Output:
{"entity_idx": 0, "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", "example": "", "aliases": ["Camera Tripod"], "is_explicit_memory": false}
]
}
**Example 4 (User vs AI alias distinction - English output):** "I'm Alex, and I'll call you Buddy"
Output:
{
"triplets": [
{"subject_name": "User", "subject_id": 0, "predicate": "NAMED", "object_name": "AI Assistant", "object_id": 1, "value": "Buddy"}
],
"entities": [
{"entity_idx": 0, "name": "User", "type": "Person", "description": "The user", "example": "", "aliases": ["Alex"], "is_explicit_memory": false},
{"entity_idx": 1, "name": "AI Assistant", "type": "Person", "description": "The user's AI assistant", "example": "", "aliases": ["Buddy"], "is_explicit_memory": false}
]
}
{% else %}
**Example 1 (English input → Chinese output):** "I plan to travel to Paris next week and visit the Louvre."
Output:
{
"triplets": [
@@ -258,6 +352,39 @@ Output:
]
}
**Example 6 (用户与AI别名区分 - Chinese):** "我称呼自己为陈思远,我给AI取名为远仔"
Output:
{
"triplets": [
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "远仔"}
],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远"], "is_explicit_memory": false},
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["远仔"], "is_explicit_memory": false}
]
}
**Example 7 (纯用户自我介绍,无AI命名 - Chinese):** "我叫vv"
Output:
{
"triplets": [],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["vv"], "is_explicit_memory": false}
]
}
**Example 8 (给AI改名 - Chinese):** "你不叫远仔了,你现在叫陈仔"
Output:
{
"triplets": [
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "陈仔"}
],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "is_explicit_memory": false},
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["陈仔"], "is_explicit_memory": false}
]
}
{% endif %}
===End of Examples===

View File

@@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.models.models_model import ModelProvider, ModelType
from app.core.models.volcano_chat import VolcanoChatOpenAI
T = TypeVar("T")
@@ -25,6 +26,9 @@ class RedBearModelConfig(BaseModel):
api_key: str
base_url: Optional[str] = None
is_omni: bool = False # 是否为 Omni 模型
deep_thinking: bool = False # 是否启用深度思考模式
thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算
support_thinking: bool = False # 模型是否支持 enable_thinking 参数capability 含 thinking
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用可通过环境变量 LLM_TIMEOUT 配置
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
# 最大重试次数 - 默认2次以避免过长等待可通过环境变量 LLM_MAX_RETRIES 配置
@@ -44,7 +48,7 @@ class RedBearModelFactory:
# 打印供应商信息用于调试
from app.core.logging_config import get_business_logger
logger = get_business_logger()
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}, deep_thinking: {config.deep_thinking}")
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
@@ -58,7 +62,7 @@ class RedBearModelFactory:
write=60.0,
pool=10.0,
)
return {
params: Dict[str, Any] = {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
@@ -66,6 +70,23 @@ class RedBearModelFactory:
"max_retries": config.max_retries,
**config.extra_params
}
# 流式模式下启用 stream_usage 以获取 token 统计
is_streaming = bool(config.extra_params.get("streaming"))
if is_streaming:
params["stream_usage"] = True
# 只有支持 thinking 的模型才传 enable_thinking
if config.support_thinking:
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
if is_streaming:
model_kwargs["enable_thinking"] = config.deep_thinking
if config.deep_thinking:
model_kwargs["incremental_output"] = True
if config.thinking_budget_tokens:
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
else:
model_kwargs["enable_thinking"] = False
params["model_kwargs"] = model_kwargs
return params
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
# 使用 httpx.Timeout 对象来设置详细的超时配置
@@ -78,7 +99,7 @@ class RedBearModelFactory:
write=60.0, # 写入超时60秒
pool=10.0, # 连接池超时10秒
)
return {
params: Dict[str, Any] = {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
@@ -86,16 +107,49 @@ class RedBearModelFactory:
"max_retries": config.max_retries,
**config.extra_params
}
# 流式模式下启用 stream_usage 以获取 token 统计
if config.extra_params.get("streaming"):
params["stream_usage"] = True
# 深度思考模式
is_streaming = bool(config.extra_params.get("streaming"))
if is_streaming and not config.is_omni:
if provider == ModelProvider.VOLCANO:
# 火山引擎深度思考仅流式调用支持,非流式时不传 thinking 参数
thinking_config: Dict[str, Any] = {
"type": "enabled" if config.deep_thinking else "disabled"
}
if config.deep_thinking and config.thinking_budget_tokens:
thinking_config["budget_tokens"] = config.thinking_budget_tokens
params["extra_body"] = {"thinking": thinking_config}
else:
# 始终显式传递 enable_thinking,不支持该参数的模型(如 DeepSeek-R1)会直接忽略
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
model_kwargs["enable_thinking"] = config.deep_thinking
if config.deep_thinking and config.thinking_budget_tokens:
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
params["model_kwargs"] = model_kwargs
return params
elif provider == ModelProvider.DASHSCOPE:
# DashScope (通义千问) 使用自己的参数格式
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
# 只支持: model, dashscope_api_key, max_retries, client
return {
params = {
"model": config.model_name,
"dashscope_api_key": config.api_key,
"max_retries": config.max_retries,
**config.extra_params
}
# 只有支持 thinking 的模型才传 enable_thinking
if config.support_thinking:
is_streaming = bool(config.extra_params.get("streaming"))
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
if is_streaming:
model_kwargs["enable_thinking"] = config.deep_thinking
if config.deep_thinking:
model_kwargs["incremental_output"] = True
if config.thinking_budget_tokens:
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
else:
model_kwargs["enable_thinking"] = False
params["model_kwargs"] = model_kwargs
return params
elif provider == ModelProvider.BEDROCK:
# Bedrock 使用 AWS 凭证
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
@@ -134,6 +188,13 @@ class RedBearModelFactory:
elif "region_name" not in params:
params["region_name"] = "us-east-1" # 默认区域
# 深度思考模式(Claude 3.7 Sonnet 等支持思考的模型)
# 通过 additional_model_request_fields 传递 thinking 块;关闭时不传(Bedrock 无 disabled 选项)
if config.deep_thinking:
budget = config.thinking_budget_tokens or 10000
params["additional_model_request_fields"] = {
"thinking": {"type": "enabled", "budget_tokens": budget}
}
return params
else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
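
A sketch of the parameter assembly for the two thinking paths above (`RedBearModelConfig` field names are inferred from this diff; model names and API keys are placeholders):

```python
# Hypothetical configs — illustrative values only.
volcano_cfg = RedBearModelConfig(
    provider=ModelProvider.VOLCANO, model_name="doubao-pro", api_key="sk-test",
    deep_thinking=True, thinking_budget_tokens=2048,
    extra_params={"streaming": True},
)
params = RedBearModelFactory.get_model_params(volcano_cfg)
print(params["extra_body"])  # {'thinking': {'type': 'enabled', 'budget_tokens': 2048}}

qwen_cfg = RedBearModelConfig(
    provider=ModelProvider.DASHSCOPE, model_name="qwen3-max", api_key="sk-test",
    support_thinking=True, deep_thinking=False, extra_params={"streaming": True},
)
print(RedBearModelFactory.get_model_params(qwen_cfg)["model_kwargs"])
# {'enable_thinking': False} — thinking explicitly disabled on a thinking-capable model
```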
@@ -160,7 +221,9 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
if provider == ModelProvider.VOLCANO:
return VolcanoChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
if type == ModelType.LLM:
return OpenAI
elif type == ModelType.CHAT:

View File

@@ -1,5 +1,5 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Union
from langchain_core.embeddings import Embeddings
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
@@ -22,11 +22,40 @@ class RedBearEmbeddings(Embeddings):
self._model = self._create_model(config)
self._client = None
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
@staticmethod
def _create_model(config: RedBearModelConfig) -> Embeddings:
"""根据配置创建 LangChain 模型"""
embedding_class = get_provider_embedding_class(config.provider)
model_params = RedBearModelFactory.get_model_params(config)
return embedding_class(**model_params)
provider = config.provider.lower()
# Embedding models only need connection params, never LLM-specific ones
# (e.g. enable_thinking, model_kwargs) — build params directly.
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
import httpx
params = {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
"timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
"max_retries": config.max_retries,
"check_embedding_ctx_length": False,
"encoding_format": "float"
}
elif provider == ModelProvider.DASHSCOPE:
params = {
"model": config.model_name,
"dashscope_api_key": config.api_key,
"max_retries": config.max_retries,
}
elif provider == ModelProvider.OLLAMA:
params = {
"model": config.model_name,
"base_url": config.base_url,
}
elif provider == ModelProvider.BEDROCK:
params = RedBearModelFactory.get_model_params(config)
else:
params = RedBearModelFactory.get_model_params(config)
return embedding_class(**params)
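
Since `_create_model` is now a staticmethod, the connection-only parameter construction can be exercised directly — a hypothetical example (constructor fields inferred from this diff):

```python
cfg = RedBearModelConfig(
    provider=ModelProvider.DASHSCOPE,
    model_name="text-embedding-v4",
    api_key="sk-test",
)
model = RedBearEmbeddings._create_model(cfg)
# → DashScopeEmbeddings(model="text-embedding-v4", dashscope_api_key="sk-test",
#   max_retries=<config default>) — no thinking/model_kwargs parameters leak in.
```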
def _create_volcano_client(self, config: RedBearModelConfig):
"""创建火山引擎客户端"""

View File

@@ -11,6 +11,7 @@ models:
tags:
- 大语言模型
logo: bedrock
- name: amazon nova
type: llm
provider: bedrock
@@ -27,6 +28,7 @@ models:
- stream-tool-call
- vision
logo: bedrock
- name: anthropic claude
type: llm
provider: bedrock
@@ -35,6 +37,7 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -44,6 +47,7 @@ models:
- stream-tool-call
- document
logo: bedrock
- name: cohere
type: llm
provider: bedrock
@@ -58,6 +62,7 @@ models:
- tool-call
- stream-tool-call
logo: bedrock
- name: deepseek
type: llm
provider: bedrock
@@ -66,6 +71,7 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -74,6 +80,7 @@ models:
- tool-call
- stream-tool-call
logo: bedrock
- name: meta
type: llm
provider: bedrock
@@ -87,6 +94,7 @@ models:
- agent-thought
- tool-call
logo: bedrock
- name: mistral
type: llm
provider: bedrock
@@ -100,6 +108,7 @@ models:
- agent-thought
- tool-call
logo: bedrock
- name: openai
type: llm
provider: bedrock
@@ -114,6 +123,7 @@ models:
- tool-call
- stream-tool-call
logo: bedrock
- name: qwen
type: llm
provider: bedrock
@@ -128,6 +138,7 @@ models:
- tool-call
- stream-tool-call
logo: bedrock
- name: amazon.rerank-v1:0
type: rerank
provider: bedrock
@@ -139,6 +150,7 @@ models:
tags:
- 重排序模型
logo: bedrock
- name: cohere.rerank-v3-5:0
type: rerank
provider: bedrock
@@ -150,6 +162,7 @@ models:
tags:
- 重排序模型
logo: bedrock
- name: amazon.nova-2-multimodal-embeddings-v1:0
type: embedding
provider: bedrock
@@ -163,6 +176,7 @@ models:
- 文本嵌入模型
- vision
logo: bedrock
- name: amazon.titan-embed-text-v1
type: embedding
provider: bedrock
@@ -174,6 +188,7 @@ models:
tags:
- 文本嵌入模型
logo: bedrock
- name: amazon.titan-embed-text-v2:0
type: embedding
provider: bedrock
@@ -185,6 +200,7 @@ models:
tags:
- 文本嵌入模型
logo: bedrock
- name: cohere.embed-english-v3
type: embedding
provider: bedrock
@@ -196,6 +212,7 @@ models:
tags:
- 文本嵌入模型
logo: bedrock
- name: cohere.embed-multilingual-v3
type: embedding
provider: bedrock

View File

@@ -6,36 +6,42 @@ models:
description: DeepSeek-R1-Distill-Qwen-14B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-r1-distill-qwen-32b
type: llm
provider: dashscope
description: DeepSeek-R1-Distill-Qwen-32B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-r1
type: llm
provider: dashscope
description: DeepSeek-R1大语言模型支持智能体思考131072超大上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3.1
type: llm
provider: dashscope
@@ -48,6 +54,7 @@ models:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3.2-exp
type: llm
provider: dashscope
@@ -60,6 +67,7 @@ models:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3.2
type: llm
provider: dashscope
@@ -72,6 +80,7 @@ models:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3
type: llm
provider: dashscope
@@ -84,6 +93,7 @@ models:
- 大语言模型
- agent-thought
logo: dashscope
- name: farui-plus
type: llm
provider: dashscope
@@ -98,6 +108,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: glm-4.7
type: llm
provider: dashscope
@@ -112,6 +123,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qvq-max-latest
type: llm
provider: dashscope
@@ -119,7 +131,8 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -127,6 +140,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qvq-max
type: llm
provider: dashscope
@@ -134,7 +148,8 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -142,6 +157,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-coder-turbo-0919
type: llm
provider: dashscope
@@ -155,13 +171,15 @@ models:
- 代码模型
- agent-thought
logo: dashscope
- name: qwen-max-latest
type: llm
provider: dashscope
description: qwen-max-latest大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -169,6 +187,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-max-longcontext
type: llm
provider: dashscope
@@ -183,13 +202,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-max
type: llm
provider: dashscope
description: qwen-max大语言模型支持多工具调用、智能体思考、流式工具调用32768上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -197,6 +218,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-mt-plus
type: llm
provider: dashscope
@@ -210,6 +232,7 @@ models:
- 翻译模型
- agent-thought
logo: dashscope
- name: qwen-mt-turbo
type: llm
provider: dashscope
@@ -223,6 +246,7 @@ models:
- 翻译模型
- agent-thought
logo: dashscope
- name: qwen-plus-0112
type: llm
provider: dashscope
@@ -237,6 +261,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0125
type: llm
provider: dashscope
@@ -251,6 +276,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0723
type: llm
provider: dashscope
@@ -265,6 +291,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0806
type: llm
provider: dashscope
@@ -279,6 +306,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0919
type: llm
provider: dashscope
@@ -293,6 +321,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-1125
type: llm
provider: dashscope
@@ -307,6 +336,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-1127
type: llm
provider: dashscope
@@ -321,6 +351,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-1220
type: llm
provider: dashscope
@@ -335,6 +366,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-vl-max
type: chat
provider: dashscope
@@ -342,8 +374,8 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
is_omni: false
tags:
- 大语言模型
@@ -352,6 +384,7 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-0809
type: chat
provider: dashscope
@@ -359,8 +392,8 @@ models:
is_deprecated: true
is_official: true
capability:
- vision
- video
- vision
- video
is_omni: false
tags:
- 大语言模型
@@ -369,6 +402,7 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-2025-01-02
type: chat
provider: dashscope
@@ -376,8 +410,8 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
is_omni: false
tags:
- 大语言模型
@@ -386,6 +420,7 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-2025-01-25
type: chat
provider: dashscope
@@ -393,8 +428,8 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
is_omni: false
tags:
- 大语言模型
@@ -403,6 +438,7 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-latest
type: chat
provider: dashscope
@@ -410,8 +446,8 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
is_omni: false
tags:
- 大语言模型
@@ -420,6 +456,7 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus
type: chat
provider: dashscope
@@ -427,8 +464,8 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
is_omni: false
tags:
- 大语言模型
@@ -437,6 +474,7 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen2.5-0.5b-instruct
type: llm
provider: dashscope
@@ -451,13 +489,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-14b
type: llm
provider: dashscope
description: qwen3-14b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -465,13 +505,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-235b-a22b-instruct-2507
type: llm
provider: dashscope
description: qwen3-235b-a22b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -479,13 +521,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-235b-a22b-thinking-2507
type: llm
provider: dashscope
description: qwen3-235b-a22b-thinking-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -493,13 +537,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-235b-a22b
type: llm
provider: dashscope
description: qwen3-235b-a22b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -507,13 +553,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-30b-a3b-instruct-2507
type: llm
provider: dashscope
description: qwen3-30b-a3b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -521,13 +569,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-30b-a3b
type: llm
provider: dashscope
description: qwen3-30b-a3b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -535,13 +585,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-32b
type: llm
provider: dashscope
description: qwen3-32b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -549,13 +601,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-4b
type: llm
provider: dashscope
description: qwen3-4b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -563,13 +617,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-8b
type: llm
provider: dashscope
description: qwen3-8b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -577,65 +633,75 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-coder-30b-a3b-instruct
type: llm
provider: dashscope
description: qwen3-coder-30b-a3b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-coder-480b-a35b-instruct
type: llm
provider: dashscope
description: qwen3-coder-480b-a35b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-coder-plus-2025-09-23
type: llm
provider: dashscope
description: qwen3-coder-plus-2025-09-23大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-coder-plus
type: llm
provider: dashscope
description: qwen3-coder-plus大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-max-2025-09-23
type: llm
provider: dashscope
description: qwen3-max-2025-09-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -644,13 +710,15 @@ models:
- stream-tool-call
- 联网搜索
logo: dashscope
- name: qwen3-max-2026-01-23
type: llm
provider: dashscope
description: qwen3-max-2026-01-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -659,13 +727,15 @@ models:
- stream-tool-call
- 联网搜索
logo: dashscope
- name: qwen3-max-preview
type: llm
provider: dashscope
description: qwen3-max-preview大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -673,13 +743,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-max
type: llm
provider: dashscope
description: qwen3-max大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -688,13 +760,15 @@ models:
- stream-tool-call
- 联网搜索
logo: dashscope
- name: qwen3-next-80b-a3b-instruct
type: llm
provider: dashscope
description: qwen3-next-80b-a3b-instruct大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -702,13 +776,15 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-next-80b-a3b-thinking
type: llm
provider: dashscope
description: qwen3-next-80b-a3b-thinking大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -716,6 +792,7 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-omni-flash-2025-12-01
type: llm
provider: dashscope
@@ -723,9 +800,9 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- audio
- vision
- video
- audio
is_omni: true
tags:
- 大语言模型
@@ -735,6 +812,7 @@ models:
- video
- audio
logo: dashscope
- name: qwen3-vl-235b-a22b-instruct
type: chat
provider: dashscope
@@ -742,8 +820,9 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -754,6 +833,7 @@ models:
- vision
- video
logo: dashscope
- name: qwen3-vl-235b-a22b-thinking
type: chat
provider: dashscope
@@ -761,8 +841,9 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -773,6 +854,7 @@ models:
- vision
- video
logo: dashscope
- name: qwen3-vl-30b-a3b-instruct
type: chat
provider: dashscope
@@ -780,8 +862,9 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -792,6 +875,7 @@ models:
- vision
- video
logo: dashscope
- name: qwen3-vl-30b-a3b-thinking
type: chat
provider: dashscope
@@ -799,8 +883,9 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -811,6 +896,7 @@ models:
- vision
- video
logo: dashscope
- name: qwen3-vl-flash
type: chat
provider: dashscope
@@ -818,8 +904,9 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -830,6 +917,7 @@ models:
- vision
- video
logo: dashscope
- name: qwen3-vl-plus-2025-09-23
type: chat
provider: dashscope
@@ -837,8 +925,9 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -847,6 +936,7 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen3-vl-plus
type: chat
provider: dashscope
@@ -854,8 +944,9 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- video
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -864,45 +955,52 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwq-32b
type: llm
provider: dashscope
description: qwq-32b大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwq-plus-0305
type: llm
provider: dashscope
description: qwq-plus-0305大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwq-plus
type: llm
provider: dashscope
description: qwq-plus大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
- agent-thought
- stream-tool-call
logo: dashscope
- name: gte-rerank-v2
type: rerank
provider: dashscope
@@ -914,6 +1012,7 @@ models:
tags:
- 重排序模型
logo: dashscope
- name: gte-rerank
type: rerank
provider: dashscope
@@ -925,6 +1024,7 @@ models:
tags:
- 重排序模型
logo: dashscope
- name: multimodal-embedding-v1
type: embedding
provider: dashscope
@@ -932,13 +1032,14 @@ models:
is_deprecated: false
is_official: true
capability:
- vision
- vision
is_omni: false
tags:
- 嵌入模型
- 多模态模型
- vision
logo: dashscope
- name: text-embedding-v1
type: embedding
provider: dashscope
@@ -951,6 +1052,7 @@ models:
- 嵌入模型
- 文本嵌入
logo: dashscope
- name: text-embedding-v2
type: embedding
provider: dashscope
@@ -963,6 +1065,7 @@ models:
- 嵌入模型
- 文本嵌入
logo: dashscope
- name: text-embedding-v3
type: embedding
provider: dashscope
@@ -975,6 +1078,7 @@ models:
- 嵌入模型
- 文本嵌入
logo: dashscope
- name: text-embedding-v4
type: embedding
provider: dashscope
@@ -986,4 +1090,4 @@ models:
tags:
- 嵌入模型
- 文本嵌入
logo: dashscope
logo: dashscope
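
A hedged sketch of how a registry loader might consume the new `thinking` capability entries and feed the factory's `support_thinking` flag (the loader and file name are assumptions, not code from this commit):

```python
import yaml

with open("dashscope.yaml", encoding="utf-8") as f:  # file name assumed
    spec = yaml.safe_load(f)

for m in spec["models"]:
    if m.get("type") == "llm" and "thinking" in (m.get("capability") or []):
        print(m["name"], "→ support_thinking=True")
```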

View File

@@ -20,6 +20,7 @@ models:
- audio
- video
logo: openai
- name: gpt-3.5-turbo-0125
type: llm
provider: openai
@@ -34,6 +35,7 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-3.5-turbo-1106
type: llm
provider: openai
@@ -48,6 +50,7 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-3.5-turbo-16k
type: llm
provider: openai
@@ -62,6 +65,7 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-3.5-turbo-instruct
type: llm
provider: openai
@@ -73,6 +77,7 @@ models:
tags:
- 大语言模型
logo: openai
- name: gpt-3.5-turbo
type: llm
provider: openai
@@ -87,6 +92,7 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-0125-preview
type: llm
provider: openai
@@ -101,6 +107,7 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-1106-preview
type: llm
provider: openai
@@ -115,6 +122,7 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-turbo-2024-04-09
type: llm
provider: openai
@@ -131,6 +139,7 @@ models:
- stream-tool-call
- vision
logo: openai
- name: gpt-4-turbo-preview
type: llm
provider: openai
@@ -145,6 +154,7 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-turbo
type: llm
provider: openai
@@ -161,6 +171,7 @@ models:
- stream-tool-call
- vision
logo: openai
- name: o1-preview
type: llm
provider: openai
@@ -173,6 +184,7 @@ models:
- 大语言模型
- agent-thought
logo: openai
- name: o1
type: llm
provider: openai
@@ -181,6 +193,7 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -190,6 +203,7 @@ models:
- vision
- structured-output
logo: openai
- name: o3-2025-04-16
type: llm
provider: openai
@@ -198,6 +212,7 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -207,13 +222,15 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: o3-mini-2025-01-31
type: llm
provider: openai
description: o3-mini-2025-01-31大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -222,13 +239,15 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: o3-mini
type: llm
provider: openai
description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
is_deprecated: false
is_official: true
capability: []
capability:
- thinking
is_omni: false
tags:
- 大语言模型
@@ -237,6 +256,7 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: o3-pro-2025-06-10
type: llm
provider: openai
@@ -245,6 +265,7 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -253,6 +274,7 @@ models:
- vision
- structured-output
logo: openai
- name: o3-pro
type: llm
provider: openai
@@ -261,6 +283,7 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -269,6 +292,7 @@ models:
- vision
- structured-output
logo: openai
- name: o3
type: llm
provider: openai
@@ -277,6 +301,7 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -286,6 +311,7 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: o4-mini-2025-04-16
type: llm
provider: openai
@@ -294,6 +320,7 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -303,6 +330,7 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: o4-mini
type: llm
provider: openai
@@ -311,6 +339,7 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -320,6 +349,7 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: text-embedding-3-large
type: embedding
provider: openai
@@ -331,6 +361,7 @@ models:
tags:
- 文本向量模型
logo: openai
- name: text-embedding-3-small
type: embedding
provider: openai
@@ -342,6 +373,7 @@ models:
tags:
- 文本向量模型
logo: openai
- name: text-embedding-ada-002
type: embedding
provider: openai

View File

@@ -10,6 +10,7 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -24,6 +25,7 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -38,6 +40,7 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -52,6 +55,7 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -82,6 +86,7 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -96,6 +101,7 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -110,6 +116,7 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -124,6 +131,7 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -139,6 +147,7 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型

View File

@@ -0,0 +1,52 @@
"""
火山引擎 ChatOpenAI 扩展
ChatOpenAI 在解析流式 SSE 时只取 delta.content会丢弃 delta.reasoning_content。
此类仅重写 _convert_chunk_to_generation_chunk将 reasoning_content 补入 additional_kwargs。
"""
from __future__ import annotations
from typing import Any, Optional, Union
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
class VolcanoChatOpenAI(ChatOpenAI):
"""火山引擎 Chat 模型支持深度思考内容reasoning_content的流式和非流式透传。"""
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
result = super()._create_chat_result(response, generation_info)
# Patch reasoning_content from the non-streaming response into additional_kwargs
choices = response.choices if hasattr(response, "choices") else response.get("choices", [])
if choices:
message = choices[0].message if hasattr(choices[0], "message") else choices[0].get("message", {})
reasoning = (
getattr(message, "reasoning_content", None)
or (message.get("reasoning_content") if isinstance(message, dict) else None)
)
if reasoning and result.generations:
result.generations[0].message.additional_kwargs["reasoning_content"] = reasoning
return result
def _convert_chunk_to_generation_chunk(
self,
chunk: dict,
default_chunk_class: type,
base_generation_info: Optional[dict],
) -> Optional[ChatGenerationChunk]:
gen_chunk = super()._convert_chunk_to_generation_chunk(
chunk, default_chunk_class, base_generation_info
)
if gen_chunk is None:
return None
# Extract reasoning_content from the raw chunk
choices = chunk.get("choices") or chunk.get("chunk", {}).get("choices", [])
if choices:
delta = choices[0].get("delta") or {}
reasoning: Any = delta.get("reasoning_content")
if reasoning:
gen_chunk.message.additional_kwargs["reasoning_content"] = reasoning
return gen_chunk
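A minimal usage sketch (not part of the diff), assuming the class above is importable; the endpoint, key, and model name are hypothetical placeholders:

import asyncio

async def demo():
    # Any OpenAI-compatible endpoint that emits delta.reasoning_content in its
    # SSE stream gets the same pass-through into additional_kwargs.
    llm = VolcanoChatOpenAI(
        base_url="https://ark.example.com/v1",  # placeholder
        api_key="sk-placeholder",
        model="some-thinking-model",            # placeholder
    )
    async for chunk in llm.astream("What is 1 + 1?"):
        reasoning = chunk.additional_kwargs.get("reasoning_content")
        if reasoning:
            print(f"[reasoning] {reasoning}")
        if chunk.content:
            print(chunk.content, end="")

asyncio.run(demo())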

View File

@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
enum=["format", "convert_timezone", "timestamp_to_datetime", "now", "datetime_to_timestamp"]
),
ToolParameter(
name="input_value",

View File

@@ -99,7 +99,7 @@ class SimpleMCPClient:
# Establish the SSE connection
response = await self._session.get(self.server_url)
if response.status not in (200, 202):
if not (200 <= response.status < 300):
error_text = await response.text()
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
@@ -190,9 +190,7 @@ class SimpleMCPClient:
try:
async with self._session.post(self._endpoint_url, json=request) as response:
# MCP SSE protocol: a POST request returning 200 or 202 is normal
# 202 Accepted means the request was accepted; the result comes back asynchronously over the SSE stream
if response.status not in (200, 202):
if not (200 <= response.status < 300):
error_text = await response.text()
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
@@ -207,7 +205,7 @@ class SimpleMCPClient:
raise MCPConnectionError("endpoint URL 未初始化")
async with self._session.post(self._endpoint_url, json=notification) as response:
if response.status not in (200, 202):
if not (200 <= response.status < 300):
logger.warning(f"通知发送失败: {response.status}")
async def _initialize_modelscope_session(self):
@@ -225,7 +223,7 @@ class SimpleMCPClient:
try:
async with self._session.post(self.server_url, json=init_request) as response:
if response.status != 200:
if not (200 <= response.status < 300):
error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")

View File

@@ -32,13 +32,16 @@ from app.core.workflow.nodes.configs import (
NoteNodeConfig,
ParameterExtractorNodeConfig,
QuestionClassifierNodeConfig,
VariableAggregatorNodeConfig
VariableAggregatorNodeConfig,
ListOperatorNodeConfig,
DocExtractorNodeConfig,
)
from app.core.workflow.nodes.cycle_graph.config import (
ConditionDetail as LoopConditionDetail,
ConditionsConfig,
CycleVariable
)
from app.core.workflow.nodes.list_operator.config import FilterCondition
from app.core.workflow.nodes.enums import (
ValueInputType,
ComparisonOperator,
@@ -90,6 +93,8 @@ class DifyConverter(BaseConverter):
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
NodeType.TOOL: self.convert_tool_node_config,
NodeType.NOTES: self.convert_notes_config,
NodeType.LIST_OPERATOR: self.convert_list_operator_node_config,
NodeType.DOCUMENT_EXTRACTOR: self.convert_document_extractor_node_config,
NodeType.CYCLE_START: lambda x: {},
NodeType.BREAK: lambda x: {},
}
@@ -126,7 +131,7 @@ class DifyConverter(BaseConverter):
selector = var_selector.split('.')
if len(selector) not in [2, 3] and var_selector != "context":
raise Exception(f"invalid variable selector: {var_selector}")
if len(selector) == 3:
if len(selector) == 3 and selector[0] in ("conversation", "sys"):
selector = selector[1:]
if selector[0] == "conversation":
selector[0] = "conv"
@@ -213,7 +218,9 @@ class DifyConverter(BaseConverter):
"end with": ComparisonOperator.END_WITH,
"not contains": ComparisonOperator.NOT_CONTAINS,
"exists": ComparisonOperator.NOT_EMPTY,
"not exists": ComparisonOperator.EMPTY
"not exists": ComparisonOperator.EMPTY,
"in": ComparisonOperator.IN,
"not in": ComparisonOperator.NOT_IN,
}
return operator_map.get(operator, operator)
@@ -476,11 +483,11 @@ class DifyConverter(BaseConverter):
node_data = node["data"]
result = IterationNodeConfig.model_construct(
input=self._process_list_variable_literal(node_data["iterator_selector"]),
parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"],
parallel=node_data.get("is_parallel", False),
parallel_count=node_data.get("parallel_nums", 4),
output=self._process_list_variable_literal(node_data["output_selector"]),
output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"],
flatten=node_data.get("flatten_output", False),
).model_dump()
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
@@ -489,7 +496,23 @@ class DifyConverter(BaseConverter):
def convert_assigner_node_config(self, node: dict) -> dict:
node_data = node["data"]
assignments = []
for assignment in node_data["items"]:
# Support both formats:
# 1. New format: node_data["items"] list
# 2. Flat format: assigned_variable_selector + input_variable_selector + write_mode
if "items" in node_data:
raw_items = node_data["items"]
elif "assigned_variable_selector" in node_data and "input_variable_selector" in node_data:
raw_items = [{
"variable_selector": node_data["assigned_variable_selector"],
"value": node_data["input_variable_selector"],
"input_type": ValueInputType.VARIABLE,
"operation": node_data.get("write_mode", "over-write"),
}]
else:
raw_items = []
for assignment in raw_items:
if assignment.get("operation") is None or assignment.get("value") is None:
continue
assignments.append(
@@ -771,3 +794,46 @@ class DifyConverter(BaseConverter):
show_author=node_data.get("showAuthor", True)
).model_dump()
return result
def convert_list_operator_node_config(self, node: dict) -> dict:
"""Dify list-operator — convert variable path array to {{ }} selector format."""
node_data = node["data"]
variable_path = node_data.get("variable", [])
input_list = self._process_list_variable_literal(variable_path) or ""
filter_by = node_data.get("filter_by", {"enabled": False, "conditions": []})
# Convert each condition's comparison_operator from Dify format to native
if filter_by.get("conditions"):
converted_conditions = []
for cond in filter_by["conditions"]:
converted_conditions.append({
**cond,
"comparison_operator": self.convert_compare_operator(
cond.get("comparison_operator", "")
)
})
filter_by = {**filter_by, "conditions": converted_conditions}
result = {
"input_list": input_list,
"filter_by": filter_by,
"order_by": node_data.get("order_by", {"enabled": False, "key": "", "value": "asc"}),
"limit": node_data.get("limit", {"enabled": False, "size": -1}),
"extract_by": node_data.get("extract_by", {"enabled": False, "serial": "1"}),
}
self.config_validate(node["id"], node["data"]["title"], ListOperatorNodeConfig, result)
return result
def convert_document_extractor_node_config(self, node: dict) -> dict:
"""Convert Dify document-extractor node to MemoryBear DocExtractorNodeConfig.
Dify document-extractor data fields:
variable_selector: list[str] - file variable path
"""
node_data = node["data"]
file_selector = self._process_list_variable_literal(
node_data.get("variable_selector", [])
) or ""
result = DocExtractorNodeConfig.model_construct(
file_selector=file_selector,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], DocExtractorNodeConfig, result)
return result
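A sketch of the Dify node shape convert_list_operator_node_config consumes; field values are illustrative, not taken from a real Dify export:

dify_node = {
    "id": "1728000000000",
    "data": {
        "title": "Filter images",
        "variable": ["sys", "files"],
        "filter_by": {
            "enabled": True,
            "conditions": [
                {"key": "type", "comparison_operator": "in", "value": ["image"]},
            ],
        },
    },
}
# _process_list_variable_literal(["sys", "files"]) should yield "{{ sys.files }}",
# and convert_compare_operator maps "in" to ComparisonOperator.IN before the
# result is validated against ListOperatorNodeConfig.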

View File

@@ -45,6 +45,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"question-classifier": NodeType.QUESTION_CLASSIFIER,
"variable-aggregator": NodeType.VAR_AGGREGATOR,
"tool": NodeType.TOOL,
"list-operator": NodeType.LIST_OPERATOR,
"document-extractor": NodeType.DOCUMENT_EXTRACTOR,
"": NodeType.NOTES
}

View File

@@ -22,6 +22,8 @@ from app.core.workflow.nodes.configs import (
MemoryReadNodeConfig,
MemoryWriteNodeConfig,
NoteNodeConfig,
ListOperatorNodeConfig,
DocExtractorNodeConfig,
)
from app.core.workflow.nodes.enums import NodeType
@@ -51,6 +53,8 @@ class MemoryBearConverter(BaseConverter):
NodeType.MEMORY_READ: MemoryReadNodeConfig,
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
NodeType.NOTES: NoteNodeConfig,
NodeType.LIST_OPERATOR: ListOperatorNodeConfig,
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNodeConfig,
}
@staticmethod

View File

@@ -59,6 +59,9 @@ class WorkflowResultBuilder:
conversation_vars = variable_pool.get_all_conversation_vars()
sys_vars = variable_pool.get_all_system_vars()
# Aggregate citations from all knowledge nodes
citations = self.aggregate_citations(node_outputs)
return {
"status": "completed" if success else "failed",
"output": final_output,
@@ -71,9 +74,25 @@ class WorkflowResultBuilder:
"conversation_id": execution_context.conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"citations": citations,
"error": result.get("error"),
}
@staticmethod
def aggregate_citations(node_outputs: dict) -> list:
"""从所有 knowledge 节点的输出中汇总 citations去重"""
seen = set()
citations = []
for node_output in node_outputs.values():
if not isinstance(node_output, dict):
continue
for c in node_output.get("citations", []):
key = c.get("document_id")
if key and key not in seen:
seen.add(key)
citations.append(c)
return citations
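# Illustrative (field values made up): two knowledge nodes citing the same
# document collapse into a single entry, keyed by document_id:
#   aggregate_citations({
#       "kb_1": {"citations": [{"document_id": "d1", "file_name": "a.pdf"}]},
#       "kb_2": {"citations": [{"document_id": "d1"}, {"document_id": "d2"}]},
#   })
#   -> [{"document_id": "d1", "file_name": "a.pdf"}, {"document_id": "d2"}]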
@staticmethod
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
"""

View File

@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
def merge_activate_state(x, y):
return {
k: x.get(k, False) or y.get(k, False)
for k in set(x) | set(y)
}
merged = dict(x)
for k, v in y.items():
merged[k] = merged.get(k, False) or v
return merged
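The rewrite avoids materialising a set union on every reducer call; for boolean activation maps the result is unchanged, e.g.:

assert merge_activate_state(
    {"a": True}, {"a": False, "b": True}
) == {"a": True, "b": True}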
def merge_looping_state(x, y):

View File

@@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
logger = logging.getLogger(__name__)
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
class LazyVariableDict:
def __init__(self, source, literal):
self._source: dict[str, VariableStruct[Any]] = source
self._literal: bool = literal
self._cache = {}
def keys(self):
return self._source.keys()
def _resolve(self, key):
if key in self._cache:
return self._cache[key]
var_struct = self._source.get(key)
if var_struct is None:
raise KeyError(key)
value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
self._cache[key] = value
return value
def get(self, key, default=None):
try:
return self._resolve(key)
except KeyError:
return default
def __getitem__(self, key):
return self._resolve(key)
def __getattr__(self, key):
if key.startswith('_'):
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
return self._resolve(key)
def __contains__(self, key):
return key in self._source
def __iter__(self):
return iter(self._source)
def __len__(self):
return len(self._source)
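# Illustrative behaviour, assuming a populated pool: values materialise on
# first access and are cached, so a template that touches a few keys never
# pays for serialising the whole namespace up front:
#   lazy = pool.lazy_namespace("conv", literal=True)
#   lazy["user_name"]   # resolves via to_literal() and caches the result
#   lazy.user_name      # attribute access hits the same cache
#   "missing" in lazy   # membership test resolves nothing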
class VariableSelector:
"""变量选择器
@@ -117,8 +162,7 @@ class VariablePool:
@staticmethod
def transform_selector(selector):
pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if len(selector) != 2:
raise ValueError(f"Selector not valid - {selector}")
@@ -274,7 +318,7 @@ class VariablePool:
namespace: str,
key: str,
value: Any,
var_type: VariableType,
var_type: VariableType | None,
mut: bool
):
if self.has(f"{namespace}.{key}"):
@@ -303,6 +347,16 @@ class VariablePool:
"""
return self._get_variable_struct(selector) is not None
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
return LazyVariableDict(self.variables.get(namespace, {}), literal)
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
return {
ns: LazyVariableDict(vars_dict, literal)
for ns, vars_dict in self.variables.items()
if ns not in ("sys", "conv")
}
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量
@@ -439,6 +493,23 @@ class VariablePoolInitializer:
var_value = var_default
else:
var_value = DEFAULT_VALUE(var_type)
# Convert FileInput-format dicts to full FileObject dicts
if var_type == VariableType.FILE:
if not var_value:
continue
var_value = await self._resolve_file_default(var_value)
if not var_value:
continue
elif var_type == VariableType.ARRAY_FILE:
if not var_value:
var_value = []
else:
resolved = []
for item in var_value:
f = await self._resolve_file_default(item)
if f:
resolved.append(f)
var_value = resolved
await variable_pool.new(
namespace="conv",
key=var_name,
@@ -447,6 +518,17 @@ class VariablePoolInitializer:
mut=True
)
@staticmethod
async def _resolve_file_default(file_def: dict) -> dict | None:
"""Accept only already-resolved FileObject dicts (is_file=True).
FileInput-format dicts are converted at save time by WorkflowService._resolve_variables_file_defaults.
"""
if not isinstance(file_def, dict):
return None
if file_def.get("is_file"):
return file_def
return None
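# Illustrative: a dict already resolved to a FileObject passes through, while
# a raw FileInput-format default (no is_file flag) is dropped, since it should
# have been converted when the workflow config was saved:
#   await _resolve_file_default({"is_file": True, "url": "...", "name": "a.pdf"})  # -> the dict itself
#   await _resolve_file_default({"transfer_method": "remote_url", "url": "..."})   # -> None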
@staticmethod
async def _init_system_vars(
variable_pool: VariablePool,
@@ -479,5 +561,3 @@ class VariablePoolInitializer:
var_type=var_type,
mut=False
)

View File

@@ -395,7 +395,8 @@ class BaseNode(ABC):
"output": output,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": None
"error": None,
**self._extract_extra_fields(business_result),
}
final_output = {
"node_outputs": {self.node_id: node_output},
@@ -498,6 +499,13 @@ class BaseNode(ABC):
# Default implementation returns the business result directly
return business_result
def _extract_extra_fields(self, business_result: Any) -> dict:
"""Extracts extra fields to merge into node_output (e.g. citations).
Subclasses may override to inject additional metadata.
"""
return {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""Extracts token usage information from the business result.
@@ -552,9 +560,9 @@ class BaseNode(ABC):
return render_template(
template=template,
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
node_outputs=variable_pool.get_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(literal=True),
conv_vars=variable_pool.lazy_namespace("conv", literal=True),
node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
system_vars=variable_pool.lazy_namespace("sys", literal=True),
strict=strict
)
@@ -579,9 +587,9 @@ class BaseNode(ABC):
return evaluate_condition(
expression=expression,
conv_var=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars()
conv_var=variable_pool.lazy_namespace("conv"),
node_outputs=variable_pool.lazy_all_node_outputs(),
system_vars=variable_pool.lazy_namespace("sys")
)
@staticmethod

View File

@@ -13,7 +13,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import BaseNode
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
logger = logging.getLogger(__name__)
@@ -70,7 +70,8 @@ class CodeNode(BaseNode):
for output in self.typed_config.output_variables:
value = exec_result.get(output.name)
if value is None:
raise RuntimeError(f"Return value {output.name} does not exist")
result[output.name] = DEFAULT_VALUE(output.type)
continue
match output.type:
case VariableType.STRING:
if not isinstance(value, str):

View File

@@ -24,6 +24,8 @@ from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.notes.config import NoteNodeConfig
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
__all__ = [
# 基础类
@@ -49,5 +51,7 @@ __all__ = [
"MemoryReadNodeConfig",
"MemoryWriteNodeConfig",
"CodeNodeConfig",
"NoteNodeConfig"
"NoteNodeConfig",
"ListOperatorNodeConfig",
"DocExtractorNodeConfig",
]

View File

@@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.utils.expression_evaluator import evaluate_expression
logger = logging.getLogger(__name__)
@@ -85,12 +84,7 @@ class LoopRuntime:
for variable in self.typed_config.cycle_vars:
if variable.input_type == ValueInputType.VARIABLE:
value = evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
value = self.variable_pool.get_value(variable.value)
else:
value = TypeTransformer.transform(variable.value, variable.type)
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
@@ -98,12 +92,7 @@ class LoopRuntime:
**self.state
)
loopstate["node_outputs"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
variable.name: self.variable_pool.get_value(variable.value)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars

View File

@@ -14,12 +14,22 @@ logger = logging.getLogger(__name__)
def _file_object_to_file_input(f: FileObject) -> FileInput:
"""Convert workflow FileObject to multimodal FileInput."""
file_type = f.origin_file_type or ""
# Prefer mime_type for more accurate type detection
if not file_type and f.mime_type:
file_type = f.mime_type
resolved_type = FileType.trans(f.type) if isinstance(f.type, str) else f.type
if resolved_type != FileType.DOCUMENT:
raise ValueError(
f"Document extractor only supports document files, got type '{f.type}' "
f"(name={f.name or f.file_id or f.url})"
)
return FileInput(
type=FileType.DOCUMENT,
type=resolved_type,
transfer_method=TransferMethod(f.transfer_method),
url=f.url or None,
upload_file_id=f.file_id or None,
file_type=f.origin_file_type or "",
file_type=file_type,
)
@@ -81,6 +91,7 @@ class DocExtractorNode(BaseNode):
from app.services.multimodal_service import MultimodalService
svc = MultimodalService(db)
for f in files:
label = f.name or f.url or f.file_id
try:
file_input = _file_object_to_file_input(f)
# Ensure URL is populated for local files
@@ -89,11 +100,11 @@ class DocExtractorNode(BaseNode):
# Reuse cached bytes if already fetched
if f.get_content():
file_input.set_content(f.get_content())
text = await svc._extract_document_text(file_input)
text = await svc.extract_document_text(file_input)
chunks.append(text)
except Exception as e:
logger.error(
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
f"Node {self.node_id}: failed to extract file '{label}': {e}",
exc_info=True,
)
chunks.append("")

View File

@@ -24,6 +24,7 @@ class NodeType(StrEnum):
MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
UNKNOWN = "unknown"
NOTES = "notes"
@@ -45,6 +46,8 @@ class ComparisonOperator(StrEnum):
LE = "le"
GT = "gt"
GE = "ge"
IN = "in"
NOT_IN = "not_in"
class LogicOperator(StrEnum):

View File

@@ -1,19 +1,23 @@
import asyncio
import logging
import uuid
from typing import Any
from langchain_core.documents import Document
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository, knowledgeshare_repository
from app.models import knowledge_model, ModelType
from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
from app.services.model_service import ModelConfigService
@@ -24,13 +28,26 @@ class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
"output": VariableType.ARRAY_STRING
}
def _extract_output(self, business_result: Any) -> Any:
"""下游节点只拿 chunks 列表"""
if isinstance(business_result, dict) and "chunks" in business_result:
return business_result["chunks"]
return business_result
def _extract_citations(self, business_result: Any) -> list:
if isinstance(business_result, dict):
return business_result.get("citations", [])
return []
def _extract_extra_fields(self, business_result: Any) -> dict:
return {"citations": self._extract_citations(business_result)}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {
"query": self._render_template(self.typed_config.query, variable_pool),
@@ -85,46 +102,54 @@ class KnowledgeRetrievalNode(BaseNode):
unique.append(doc)
return unique
def _get_existing_kb_ids(self, db, kb_ids):
def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
"""
Resolve all accessible and valid knowledge base IDs for retrieval.
This includes:
- Private knowledge bases owned by the user
- Shared knowledge bases
- Source knowledge bases mapped via knowledge sharing relationships
Reorder the list of document chunks and return the top_k results most relevant to the query.
Args:
db: Database session.
kb_ids (list[UUID]): Knowledge base IDs from node configuration.
query: The query string
docs: List of document chunks to be reranked
top_k: Number of top-ranked documents to return
Returns:
list[UUID]: Final list of valid knowledge base IDs.
Reranked list of document chunks (sorted in descending order of relevance)
Raises:
ValueError: If the input document list is empty or top_k is invalid
"""
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
existing_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
share_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
reranker = self.get_reranker_model()
# parameter validation
if not docs:
raise ValueError("retrieval chunks be empty")
if top_k <= 0:
raise ValueError("top_k must be a positive integer")
try:
# Convert to LangChain Document object
documents = [
Document(
page_content=doc.page_content,  # DocumentChunk is expected to expose this attribute
metadata=doc.metadata or {}  # guard against metadata being None
)
for doc in docs
]
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
reranked_docs = list(reranker.compress_documents(documents, query))
# Sort in descending order based on relevance score
reranked_docs.sort(
key=lambda x: x.metadata.get("relevance_score", 0),
reverse=True
)
existing_ids.extend(items)
return existing_ids
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
result = []
for item in reranked_docs[:top_k]:
for doc in docs:
if doc.page_content == item.page_content:
doc.metadata["score"] = item.metadata["relevance_score"]
result.append(doc)
return result
except Exception as e:
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
def get_reranker_model(self) -> RedBearRerank:
"""
@@ -164,41 +189,77 @@ class KnowledgeRetrievalNode(BaseNode):
)
return reranker
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
rs = []
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
tasks = []
for child in children:
if not (child and child.chunk_num > 0 and child.status == 1):
continue
kb_config.kb_id = child.id
self.knowledge_retrieval(db, query, rs, child, kb_config)
return
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
child_kb_config = kb_config.model_copy()
child_kb_config.kb_id = child.id
tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for sub_result in result:
rs.extend(sub_result)
return rs
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold))
rs.extend(
await asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
)
case RetrieveType.SEMANTIC:
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight))
rs.extend(
await asyncio.to_thread(
vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
)
case RetrieveType.HYBRID:
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight)
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold)
rs1_task = asyncio.to_thread(
vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
rs2_task = asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs:
return
return []
if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model()
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
rs.extend(
await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
)
)
else:
rs.extend(sorted(
unique_rs,
@@ -207,6 +268,7 @@ class KnowledgeRetrievalNode(BaseNode):
)[:kb_config.top_k])
case _:
raise RuntimeError("Unknown retrieval type")
return rs
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
@@ -238,17 +300,24 @@ class KnowledgeRetrievalNode(BaseNode):
knowledge_bases = self.typed_config.knowledge_bases
rs = []
tasks = []
for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for sub_result in result:
rs.extend(sub_result)
if not rs:
return []
if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model()
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
final_rs = await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
)
else:
final_rs = sorted(
rs,
@@ -259,4 +328,20 @@ class KnowledgeRetrievalNode(BaseNode):
logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
)
return [chunk.page_content for chunk in final_rs]
citations = []
seen_doc_ids = set()
for chunk in final_rs:
meta = chunk.metadata or {}
doc_id = meta.get("document_id") or meta.get("doc_id")
if doc_id and doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
citations.append({
"document_id": str(doc_id),
"file_name": meta.get("file_name", ""),
"knowledge_id": str(meta.get("knowledge_id", kb_config.kb_id)),
"score": meta.get("score", 0.0),
})
return {
"chunks": [chunk.page_content for chunk in final_rs],
"citations": citations,
}
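The node now returns a dict rather than a bare list; a sketch of how the two extractor hooks split it (values illustrative):

business_result = {
    "chunks": ["...paragraph one...", "...paragraph two..."],
    "citations": [
        {"document_id": "d1", "file_name": "handbook.pdf", "knowledge_id": "kb1", "score": 0.82},
    ],
}
# _extract_output(business_result) -> the chunks list, so downstream nodes still
# see ARRAY_STRING output as before.
# _extract_extra_fields(business_result) -> {"citations": [...]}, merged into
# node_output by BaseNode and aggregated by WorkflowResultBuilder.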

View File

@@ -0,0 +1,3 @@
from .node import ListOperatorNode
__all__ = ["ListOperatorNode"]

View File

@@ -0,0 +1,49 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator
class FilterCondition(BaseModel):
key: str = ""
comparison_operator: ComparisonOperator = ComparisonOperator.CONTAINS
value: str | list[str] | bool = ""
class FilterBy(BaseModel):
enabled: bool = False
conditions: list[FilterCondition] = Field(default_factory=list)
class OrderByConfig(BaseModel):
enabled: bool = False
key: str = ""
value: str = "asc" # "asc" | "desc"
class Limit(BaseModel):
enabled: bool = False
size: int = -1
class ExtractConfig(BaseModel):
enabled: bool = False
serial: str = "1" # 1-based index string, e.g. "1" = first
@field_validator("serial", mode="before")
@classmethod
def coerce_serial(cls, v):
return str(v)
class ListOperatorNodeConfig(BaseNodeConfig):
"""
List Operator node config.
Operation order: filter -> extract -> order -> limit
"""
input_list: str = Field(..., description="Variable selector, e.g. {{ sys.files }} or {{ conv.uploaded_files }}")
filter_by: FilterBy = Field(default_factory=FilterBy)
order_by: OrderByConfig = Field(default_factory=OrderByConfig)
limit: Limit = Field(default_factory=Limit)
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
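A sketch of a config instance, assuming BaseNodeConfig adds no further required fields; values are illustrative:

from app.core.workflow.nodes.enums import ComparisonOperator

cfg = ListOperatorNodeConfig(
    input_list="{{ sys.files }}",
    filter_by=FilterBy(enabled=True, conditions=[
        FilterCondition(key="extension", comparison_operator=ComparisonOperator.EQ, value="pdf"),
    ]),
    limit=Limit(enabled=True, size=5),
)
# With extract_by and order_by left disabled, this keeps at most five PDF files.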

View File

@@ -0,0 +1,150 @@
import logging
from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig, FilterCondition
from app.core.workflow.variable.base_variable import VariableType
logger = logging.getLogger(__name__)
# File object fields that hold string values
_FILE_STRING_KEYS = {"type", "name", "url", "extension", "mime_type", "transfer_method", "origin_file_type", "file_id"}
_FILE_NUMBER_KEYS = {"size"}
class ListOperatorNode(BaseNode):
def __init__(self, node_config: dict, workflow_config: dict, down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ListOperatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
"result": VariableType.ANY,
"first_record": VariableType.ANY,
"last_record": VariableType.ANY,
}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = ListOperatorNodeConfig(**self.config)
cfg = self.typed_config
# Resolve input variable from path selector
items: list = self.get_variable(cfg.input_list, variable_pool)
if not isinstance(items, list):
raise TypeError(f"Variable '{cfg.input_list}' must be an array, got {type(items)}")
result = list(items)
# 1. Filter
if cfg.filter_by.enabled and cfg.filter_by.conditions:
for condition in cfg.filter_by.conditions:
result = [item for item in result if self._match_condition(item, condition, variable_pool)]
# 2. Extract (take single item by 1-based serial index)
if cfg.extract_by.enabled:
serial_str = self._resolve_value(cfg.extract_by.serial, variable_pool)
idx = int(serial_str) - 1
if idx < 0 or idx >= len(result):
raise ValueError(f"extract_by.serial={cfg.extract_by.serial} out of range (list length={len(result)})")
result = [result[idx]]
# 3. Order
if cfg.order_by.enabled:
reverse = cfg.order_by.value == "desc"
key_fn = self._make_sort_key(cfg.order_by.key)
result = sorted(result, key=key_fn, reverse=reverse)
# 4. Limit (take first N)
if cfg.limit.enabled and cfg.limit.size > 0:
result = result[:cfg.limit.size]
return {
"result": result,
"first_record": result[0] if result else None,
"last_record": result[-1] if result else None,
}
@staticmethod
def _resolve_value(value: str, variable_pool: VariablePool) -> Any:
"""If value is a {{ namespace.key }} variable selector, resolve it from the pool.
Otherwise return the raw string."""
import re
m = re.fullmatch(r"\{\{\s*(\w+\.\w+)\s*}}", value.strip())
if m:
resolved = variable_pool.get_value(value, default=value, strict=False)
return resolved
return value
@staticmethod
def _make_sort_key(key: str):
def key_fn(item):
if isinstance(item, dict):
return item.get(key) or ""
return item
return key_fn
def _match_condition(self, item: Any, cond: FilterCondition, variable_pool: VariablePool) -> bool:
op = cond.comparison_operator
value = cond.value
# Resolve value if it's a variable reference {{ namespace.key }}
if isinstance(value, str):
value = self._resolve_value(value, variable_pool)
# Resolve left value
if isinstance(item, dict):
left = item.get(cond.key)
else:
left = item # primitive array: compare element directly
# Determine if this field should be compared as a string
is_string_field = isinstance(item, dict) and cond.key in _FILE_STRING_KEYS
# Numeric operators
if op == ComparisonOperator.EQ:
if is_string_field:
return str(left) == str(value)
return self._safe_num(left) == self._safe_num(value)
if op == ComparisonOperator.NE:
if is_string_field:
return str(left) != str(value)
return self._safe_num(left) != self._safe_num(value)
if op == ComparisonOperator.LT:
return self._safe_num(left) < self._safe_num(value)
if op == ComparisonOperator.LE:
return self._safe_num(left) <= self._safe_num(value)
if op == ComparisonOperator.GT:
return self._safe_num(left) > self._safe_num(value)
if op == ComparisonOperator.GE:
return self._safe_num(left) >= self._safe_num(value)
# String / sequence operators
left_str = str(left) if left is not None else ""
if op == ComparisonOperator.CONTAINS:
return str(value) in left_str
if op == ComparisonOperator.NOT_CONTAINS:
return str(value) not in left_str
if op == ComparisonOperator.START_WITH:
return left_str.startswith(str(value))
if op == ComparisonOperator.END_WITH:
return left_str.endswith(str(value))
if op == ComparisonOperator.IN:
return left_str in (value if isinstance(value, list) else [str(value)])
if op == ComparisonOperator.NOT_IN:
return left_str not in (value if isinstance(value, list) else [str(value)])
if op == ComparisonOperator.EMPTY:
return not left
if op == ComparisonOperator.NOT_EMPTY:
return bool(left)
raise ValueError(f"Unsupported operator: {op}")
@staticmethod
def _safe_num(v) -> float:
try:
return float(v)
except (TypeError, ValueError):
return 0.0
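For intuition, the same pipeline over a plain list, expressed in ordinary Python (data illustrative; the node drives these steps from its config, with extract_by disabled here):

files = [
    {"name": "a.pdf", "extension": "pdf", "size": 120},
    {"name": "b.png", "extension": "png", "size": 80},
    {"name": "c.pdf", "extension": "pdf", "size": 40},
]
result = [f for f in files if f["extension"] == "pdf"]          # 1. filter
result = sorted(result, key=lambda f: f["size"], reverse=True)  # 3. order (desc)
result = result[:1]                                             # 4. limit
assert result[0]["name"] == "a.pdf"                             # first_record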

View File

@@ -213,9 +213,10 @@ class LLMNode(BaseNode):
messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages
else:
# Use the simple prompt format (backward compatible)
# Use the simple prompt format (backward compatible), wrapped as a standard message list for compatibility with all providers
prompt_template = self.config.get("prompt", "")
self.messages = self._render_template(prompt_template, variable_pool)
rendered = self._render_template(prompt_template, variable_pool)
self.messages = [{"role": "user", "content": rendered}]
return llm
@@ -245,7 +246,10 @@ class LLMNode(BaseNode):
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# Return AIMessage (including response metadata)
return AIMessage(content=content, response_metadata=response.response_metadata)
return AIMessage(content=content, response_metadata={
**response.response_metadata,
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
})
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
@@ -304,15 +308,16 @@ class LLMNode(BaseNode):
# Call the LLM (streaming); supports a string or a message list
last_meta_data = {}
last_usage_metadata = {}
async for chunk in llm.astream(self.messages):
# Extract the content
if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content)
else:
content = str(chunk)
if hasattr(chunk, 'response_metadata'):
if chunk.response_metadata:
last_meta_data = chunk.response_metadata
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
last_meta_data = chunk.response_metadata
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
last_usage_metadata = chunk.usage_metadata
# Only process non-empty content
if content:
@@ -335,7 +340,10 @@ class LLMNode(BaseNode):
# Build the complete AIMessage (including metadata)
final_message = AIMessage(
content=full_response,
response_metadata=last_meta_data
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
}
)
# yield the completion marker

View File

@@ -27,6 +27,7 @@ from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode
from app.core.workflow.nodes.document_extractor import DocExtractorNode
from app.core.workflow.nodes.list_operator import ListOperatorNode
logger = logging.getLogger(__name__)
@@ -51,7 +52,8 @@ WorkflowNode = Union[
MemoryReadNode,
MemoryWriteNode,
CodeNode,
DocExtractorNode
DocExtractorNode,
ListOperatorNode
]
@@ -83,7 +85,8 @@ class NodeFactory:
NodeType.MEMORY_READ: MemoryReadNode,
NodeType.MEMORY_WRITE: MemoryWriteNode,
NodeType.CODE: CodeNode,
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode,
NodeType.LIST_OPERATOR: ListOperatorNode
}
@classmethod

View File

@@ -12,7 +12,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
@@ -45,6 +45,12 @@ class ParameterExtractorNode(BaseNode):
"model_id": str(self.typed_config.model_id),
}
def _extract_output(self, business_result: Any) -> Any:
final_output = {}
for param in self.typed_config.params:
final_output[param.name] = business_result.get(param.name) or DEFAULT_VALUE(self.output_types[param.name])
return final_output
def _output_types(self) -> dict[str, VariableType]:
outputs = {}
for param in self.typed_config.params:
@@ -109,6 +115,7 @@ class ParameterExtractorNode(BaseNode):
api_key = api_config.api_key
api_base = api_config.api_base
is_omni = api_config.is_omni
capability = api_config.capability
model_type = config.type
llm = RedBearLLM(
@@ -201,7 +208,10 @@ class ParameterExtractorNode(BaseNode):
])
model_resp = await llm.ainvoke(messages)
self.response_metadata = model_resp.response_metadata
self.response_metadata = {
**model_resp.response_metadata,
"token_usage": getattr(model_resp, 'usage_metadata', None) or model_resp.response_metadata.get('token_usage')
}
model_message = self.process_model_output(model_resp.content)
result = json_repair.repair_json(model_message, return_objects=True)
logger.info(f"node: {self.node_id} get params:{result}")

View File

@@ -62,6 +62,7 @@ class QuestionClassifierNode(BaseNode):
api_key = api_config.api_key
base_url = api_config.api_base
is_omni = api_config.is_omni
capability = api_config.capability
model_type = config.type
return RedBearLLM(
@@ -135,7 +136,10 @@ class QuestionClassifierNode(BaseNode):
response = await llm.ainvoke(messages)
result = self.process_model_output(response.content)
self.response_metadata = response.response_metadata
self.response_metadata = {
**response.response_metadata,
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
}
if result in category_names:
category = result

View File

@@ -4,32 +4,33 @@ from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN
logger = logging.getLogger(__name__)
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
class ExpressionEvaluator:
"""Safe expression evaluator for workflow variables and node outputs."""
# Reserved namespaces
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@classmethod
def normalize_template(cls, template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
return _NORMALIZE_PATTERN.sub(
r'{{ node["\1"].\2 }}',
template
)
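# e.g. (node ID illustrative): normalize_template("sum: {{ 1728.output }}")
# returns 'sum: {{ node["1728"].output }}', turning numeric node references
# into valid identifier access for simpleeval.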
@classmethod
def evaluate(
cls,
expression: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
cls,
expression: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""
Safely evaluate an expression using workflow variables.
@@ -49,48 +50,47 @@ class ExpressionEvaluator:
# Remove Jinja2-style brackets if present
expression = expression.strip()
expression = cls.normalize_template(expression)
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip()
expression = VARIABLE_PATTERN.sub(r"\1", expression).strip()
# Build context for evaluation
context = {
"conv": conv_vars, # conversation variables
"node": node_outputs, # node outputs
"sys": system_vars or {}, # system variables
"conv": conv_vars, # conversation variables
"node": node_outputs, # node outputs
"sys": system_vars or {}, # system variables
}
context.update(conv_vars)
context["nodes"] = node_outputs
# context.update(conv_vars)
# context["nodes"] = node_outputs
context.update(node_outputs)
try:
# simpleeval supports safe operations:
# arithmetic, comparisons, logical ops, attribute/dict/list access
result = simple_eval(expression, names=context)
return result
except NameNotDefined as e:
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
raise ValueError(f"Undefined variable: {e}")
except InvalidExpression as e:
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
raise ValueError(f"Invalid expression syntax: {e}")
except SyntaxError as e:
logger.error(f"Syntax error in expression: {expression}, error: {e}")
raise ValueError(f"Syntax error: {e}")
except Exception as e:
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
raise ValueError(f"Expression evaluation failed: {e}")
@staticmethod
def evaluate_bool(
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""
Evaluate a boolean expression (for conditions).
@@ -108,7 +108,7 @@ class ExpressionEvaluator:
expression, conv_var, node_outputs, system_vars
)
return bool(result)
@staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]:
"""
@@ -121,7 +121,7 @@ class ExpressionEvaluator:
list[str]: List of error messages. Empty if all names are valid.
"""
errors = []
for var in variables:
var_name = var.get("name", "")
@@ -134,16 +134,16 @@ class ExpressionEvaluator:
errors.append(
f"Variable name '{var_name}' is not a valid Python identifier"
)
return errors
# Convenience functions
def evaluate_expression(
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any]
expression: str,
conv_var: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict
) -> Any:
"""Evaluate an expression (convenience function)."""
return ExpressionEvaluator.evaluate(
@@ -152,11 +152,11 @@ def evaluate_expression(
def evaluate_condition(
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
expression: str,
conv_var: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict
) -> bool:
"""Evaluate a boolean condition expression (convenience function)."""
return ExpressionEvaluator.evaluate_bool(
expression, conv_var, node_outputs, system_vars

View File

@@ -1,7 +1,10 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/3/10 13:36
import mimetypes
import os
import uuid
from typing import Any
from urllib.parse import urlparse, unquote
TRANSFORM_FILE_TYPE = {
'text/plain': 'document/text',
'text/markdown': 'document/markdown',
@@ -52,5 +55,143 @@ ALLOWED_FILE_TYPES = [
def mime_to_file_type(mime_type):
if mime_type not in ALLOWED_FILE_TYPES:
return None
return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
def build_file_object_dict_from_url(url: str, file_type: str, origin_file_type: str) -> dict[str, Any]:
"""Build a FileObject dict for a remote_url file using only URL parsing (no HTTP request).
Used as fallback when HTTP request fails.
"""
raw_path = url.split("?")[0]
name = unquote(os.path.basename(urlparse(url).path)) or None
_, ext = os.path.splitext(name or "")
extension = ext.lstrip(".").lower() if ext else None
guessed_mime = mimetypes.guess_type(url)[0]
return {
"type": file_type,
"url": url,
"transfer_method": "remote_url",
"origin_file_type": origin_file_type,
"file_id": None,
"name": name,
"size": None,
"extension": extension,
"mime_type": guessed_mime or origin_file_type,
"is_file": True,
}
async def fetch_remote_file_meta(
url: str,
file_type: str,
origin_file_type: str,
) -> dict[str, Any]:
"""Fetch remote file metadata via HEAD (fallback GET) and build a FileObject dict.
Falls back to URL-only parsing if the HTTP request fails.
"""
import httpx
name = extension = None
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.head(url, follow_redirects=True)
if resp.status_code != 200:
resp = await client.get(url, follow_redirects=True)
cl = resp.headers.get("Content-Length")
size = int(cl) if cl else None
ct = resp.headers.get("Content-Type", "").split(";")[0].strip()
mime_type = ct or origin_file_type
cd = resp.headers.get("Content-Disposition", "")
if "filename=" in cd:
name = cd.split("filename=")[-1].strip('"').strip("'")
if not name:
name = unquote(os.path.basename(urlparse(url).path)) or None
if name:
_, ext = os.path.splitext(name)
extension = ext.lstrip(".").lower() if ext else None
if not extension and mime_type:
ext = mimetypes.guess_extension(mime_type)
extension = ext.lstrip(".").lower() if ext else None
except Exception:
return build_file_object_dict_from_url(url, file_type, origin_file_type)
return build_file_object_dict_from_meta(
file_type=file_type,
transfer_method="remote_url",
origin_file_type=origin_file_type,
file_id=None,
url=url,
file_name=name,
file_size=size,
file_ext=extension,
content_type=mime_type,
)
def build_file_object_dict_from_meta(
file_type: str,
transfer_method: str,
origin_file_type: str,
file_id: str,
url: str,
file_name: str | None,
file_size: int | None,
file_ext: str | None,
content_type: str | None,
) -> dict[str, Any]:
"""Build a FileObject dict from already-fetched FileMetadata fields."""
ext = (file_ext or "").lstrip(".")
return {
"type": file_type,
"url": url,
"transfer_method": transfer_method,
"origin_file_type": content_type or origin_file_type,
"file_id": file_id,
"name": file_name,
"size": file_size,
"extension": ext.lower() if ext else None,
"mime_type": content_type,
"is_file": True,
}
def resolve_local_file_object_dict(
db,
upload_file_id: str | uuid.UUID,
file_type: str,
origin_file_type: str,
) -> dict[str, Any] | None:
"""Query FileMetadata and build a FileObject dict for a local_file.
Returns None if the file is not found or not completed.
"""
from app.models.file_metadata_model import FileMetadata
from app.core.config import settings
try:
fid = uuid.UUID(str(upload_file_id))
except ValueError:
return None
meta = db.query(FileMetadata).filter(
FileMetadata.id == fid,
FileMetadata.status == "completed"
).first()
if not meta:
return None
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{fid}"
return build_file_object_dict_from_meta(
file_type=file_type,
transfer_method="local_file",
origin_file_type=origin_file_type,
file_id=str(fid),
url=url,
file_name=meta.file_name,
file_size=meta.file_size,
file_ext=meta.file_ext,
content_type=meta.content_type,
)
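A quick sketch of the remote path; the URL is a placeholder, and a failed HEAD/GET degrades gracefully to URL-only parsing:

import asyncio

meta = asyncio.run(fetch_remote_file_meta(
    url="https://example.com/reports/q3.pdf",  # placeholder URL
    file_type="document",                      # illustrative type labels
    origin_file_type="application/pdf",
))
# Success: name, size and mime_type come from the response headers (HEAD first,
# GET as fallback). Failure: build_file_object_dict_from_url returns the same
# dict shape with size=None and mime_type guessed from the extension.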

View File

@@ -1,7 +1,8 @@
"""
模板渲染器
Template Renderer
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
Provides safe template rendering using Jinja2, supporting variable references
and expressions.
"""
import logging
@@ -10,11 +11,15 @@ from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
from app.core.workflow.engine.variable_pool import LazyVariableDict
logger = logging.getLogger(__name__)
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
class SafeUndefined(Undefined):
"""访问未定义属性不会报错,返回空字符串"""
"""Return empty string instead of raising error when accessing undefined variables"""
__slots__ = ()
def _fail_with_undefined_error(self, *args, **kwargs):
@@ -26,26 +31,22 @@ class SafeUndefined(Undefined):
class TemplateRenderer:
"""模板渲染器"""
def __init__(self, strict: bool = True):
"""初始化渲染器
"""Initialize renderer
Args:
strict: 是否使用严格模式(未定义变量会抛出异常)
strict: Whether to enable strict mode (raise error on undefined variables)
"""
self.strict = strict
self.env = Environment(
undefined=StrictUndefined if strict else SafeUndefined,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
autoescape=False # Disable auto-escaping since we handle plain text instead of HTML
)
@staticmethod
def normalize_template(template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
"""Normalize template syntax (convert numeric node reference to dict access)"""
return _NORMALIZE_PATTERN.sub(
r'{{ node["\1"].\2 }}',
template
)
@@ -53,24 +54,24 @@ class TemplateRenderer:
def render(
self,
template: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict | None = None
) -> str:
"""渲染模板
"""Render template
Args:
template: 模板字符串
conv_vars: 会话变量
node_outputs: 节点输出结果
system_vars: 系统变量
template: Template string
conv_vars: Conversation variables
node_outputs: Node outputs
system_vars: System variables
Returns:
渲染后的字符串
Rendered string
Raises:
ValueError: 模板语法错误或变量未定义
ValueError: If template syntax is invalid or variables are undefined
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.render(
@@ -80,122 +81,119 @@ class TemplateRenderer:
... {}
... )
'Hello World!'
>>> renderer.render(
... "分析结果: {{node.analyze.output}}",
... "Analysis result: {{node.analyze.output}}",
... {},
... {"analyze": {"output": "正面情绪"}},
... {"analyze": {"output": "positive sentiment"}},
... {}
... )
'分析结果: 正面情绪'
'Analysis result: positive sentiment'
"""
# 构建命名空间上下文
# Build namespace context
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": system_vars, # 系统变量:{{sys.execution_id}}
"conv": conv_vars, # Conversation variables: {{conv.user_name}}
"node": node_outputs, # Node outputs: {{node.node_1.output}}
"sys": system_vars, # System variables: {{sys.execution_id}}
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
# Allow direct access to node outputs by node ID: {{llm_qa.output}}
if node_outputs:
context.update(node_outputs)
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
if conv_vars:
context.update(conv_vars)
context["nodes"] = node_outputs or {} # 旧语法兼容
# # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
# if conv_vars:
# context.update(conv_vars)
#
# context["nodes"] = node_outputs or {} # 旧语法兼容
template = self.normalize_template(template)
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)
except TemplateSyntaxError as e:
logger.error(f"模板语法错误: {template}, 错误: {e}")
raise ValueError(f"模板语法错误: {e}")
logger.error(f"Template syntax error: {template}, error: {e}")
raise ValueError(f"Template syntax error: {e}")
except UndefinedError as e:
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
logger.error(f"Undefined variable in template: {template}, error: {e}")
raise ValueError(f"Undefined variable: {e}")
except Exception as e:
logger.error(f"模板渲染异常: {template}, 错误: {e}")
raise ValueError(f"模板渲染失败: {e}")
logger.error(f"Template rendering error: {template}, error: {e}")
raise ValueError(f"Template rendering failed: {e}")
def validate(self, template: str) -> list[str]:
"""验证模板语法
"""Validate template syntax
Args:
template: 模板字符串
template: Template string
Returns:
错误列表,如果为空则验证通过
List of errors (empty if valid)
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.validate("Hello {{var.name}}!")
[]
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
['模板语法错误: ...']
>>> renderer.validate("Hello {{var.name") # Missing closing tag
['Template syntax error: ...']
"""
errors = []
try:
self.env.from_string(template)
except TemplateSyntaxError as e:
errors.append(f"模板语法错误: {e}")
errors.append(f"Template syntax error: {e}")
except Exception as e:
errors.append(f"模板验证失败: {e}")
errors.append(f"Template validation failed: {e}")
return errors
# 全局渲染器实例(严格模式)
# Global renderer instances (strict / lenient)
_strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False)
def render_template(
template: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any],
conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict,
strict: bool = True
) -> str:
"""渲染模板(便捷函数)
"""Render template (convenience function)
Args:
strict: 严格模式
template: 模板字符串
conv_vars: 会话变量
node_outputs: 节点输出
system_vars: 系统变量
strict: Whether to use strict mode
template: Template string
conv_vars: Conversation variables
node_outputs: Node outputs
system_vars: System variables
Returns:
渲染后的字符串
Rendered string
Examples:
>>> render_template(
... "请分析: {{var.text}}",
... {"text": "这是一段文本"},
... "Analyze: {{var.text}}",
... {"text": "This is a text"},
... {},
... {}
... )
'请分析: 这是一段文本'
'Analyze: This is a text'
"""
renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:
"""验证模板语法(便捷函数)
"""Validate template syntax (convenience function)
Args:
template: 模板字符串
template: Template string
Returns:
错误列表
List of errors
"""
return _strict_renderer.validate(template)
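
The normalization step above is easy to verify in isolation. A minimal sketch, assuming only Jinja2 and the regex shown in the diff (render_demo and the sample node IDs are illustrative):

import re
from jinja2 import Environment, StrictUndefined

_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")

def render_demo(template: str, node_outputs: dict) -> str:
    # {{1.output}} is not valid Jinja2 (integer attribute access), so it is
    # rewritten to {{ node["1"].output }} before compilation
    normalized = _NORMALIZE_PATTERN.sub(r'{{ node["\1"].\2 }}', template)
    env = Environment(undefined=StrictUndefined, autoescape=False)
    return env.from_string(normalized).render(node=node_outputs)

print(render_demo("Result: {{1.output}}", {"1": {"output": "positive"}}))
# Result: positive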

View File

@@ -301,7 +301,7 @@ class WorkflowValidator:
for node in nodes:
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
f"节点 {node.get('name')} 缺少名称(发布时必须提供)"
)
# 2. Ensure every non-start/end node has a config
@@ -311,7 +311,7 @@ class WorkflowValidator:
config = node.get("config")
if not config or not isinstance(config, dict):
errors.append(
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
f"节点 {node.get('name')} 缺少配置(发布时必须提供)"
)
# 3. Validate required variables

View File

@@ -91,7 +91,7 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
case VariableType.OBJECT:
return {}
case VariableType.FILE:
return None
return {}
case VariableType.ARRAY_STRING:
return []
case VariableType.ARRAY_NUMBER:
@@ -113,6 +113,12 @@ class FileObject(BaseModel):
origin_file_type: str
file_id: str | None
# Extended file metadata
name: str | None = None
size: int | None = None
extension: str | None = None
mime_type: str | None = None
content_cache: dict = Field(default_factory=dict)
is_file: bool

View File

@@ -66,20 +66,10 @@ class FileVariable(BaseVariable):
type = 'file'
def valid_value(self, value) -> FileObject:
if isinstance(value, dict):
if not value.get("is_file"):
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
return FileObject(
**{
"type": str(value.get('type')),
"transfer_method": value.get("transfer_method"),
"url": value.get('url'),
"file_id": value.get("file_id"),
"origin_file_type": value.get("origin_file_type"),
"is_file": True
}
)
return FileObject(**value)
if isinstance(value, FileObject):
return value
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
@@ -88,7 +78,7 @@ class FileVariable(BaseVariable):
return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})'
def get_value(self) -> Any:
return self.value.model_dump()
return self.value.model_dump(exclude={"content_cache"})
async def get_content(self):
total_bytes = 0
@@ -186,6 +176,8 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
return BooleanVariable(value)
case VariableType.OBJECT:
return DictVariable(value)
case VariableType.FILE:
return FileVariable(value)
case VariableType.ARRAY_STRING:
return make_array(StringVariable, value)
case VariableType.ARRAY_NUMBER:
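
A minimal sketch of the dump behavior above, assuming pydantic v2; FileSketch is a hypothetical stand-in for FileObject:

from pydantic import BaseModel, Field

class FileSketch(BaseModel):  # hypothetical stand-in for FileObject
    type: str
    url: str | None = None
    is_file: bool = True
    content_cache: dict = Field(default_factory=dict)  # parsed content, can be large

f = FileSketch(type="document", url="http://example.com/a.pdf")
f.content_cache["text"] = "..." * 10_000
# exclude keeps the heavy cache out of serialized variable values
print(f.model_dump(exclude={"content_cache"}))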

View File

@@ -62,6 +62,7 @@ async def lifespan(app: FastAPI):
else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
await create_all_indexes()
logger.info("All neo4j indexes and constraints created successfully!")
logger.info("应用程序启动完成")

View File

@@ -81,7 +81,7 @@ class ModelConfig(BaseModel):
# Model configuration parameters
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
comment="模型能力列表(如['vision', 'audio', 'video']")
comment="模型能力列表(如['vision', 'audio', 'video', 'thinking']")
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型使用特殊API调用")
config = Column(JSON, comment="模型配置参数")
# - temperature: controls the randomness of generated text; higher values are more random and creative, lower values more deterministic and conservative.

View File

@@ -1,6 +1,6 @@
from datetime import datetime, timedelta
from datetime import datetime, time
from sqlalchemy.orm import Session
from sqlalchemy import func
from sqlalchemy import func, Table, MetaData
from uuid import UUID
from typing import Dict, Optional, Any
@@ -192,10 +192,63 @@ class HomePageRepository:
return workspaces, app_count_dict, user_count_dict
@staticmethod
def get_latest_version_introduction(db: Session) -> tuple[Optional[str], Optional[Dict[str, Any]]]:
"""
Fetch the latest published version notes from the database.
Reads the table structure via reflection, with no dependency on premium model classes.
Args:
    db: Database session
Returns:
    Tuple of (version, version-notes dict);
    (None, None) if no published version exists in the database.
"""
try:
metadata = MetaData()
version_notes = Table('version_notes', metadata, autoload_with=db.bind)
# Fetch the latest published version (newest release date first; ties broken by version, descending)
query = db.query(version_notes).filter(
version_notes.c.is_published == True
).order_by(
version_notes.c.release_date.desc(),
version_notes.c.version.desc()
)
note = query.first()
if not note:
return None, None
version_info = {
"introduction": {
"codeName": note.code_name or "",
"releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0,
"upgradePosition": note.upgrade_position or "",
"coreUpgrades": note.core_upgrades or []
},
"introduction_en": {
"codeName": note.code_name_en or note.code_name or "",
"releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0,
"upgradePosition": note.upgrade_position_en or note.upgrade_position or "",
"coreUpgrades": note.core_upgrades_en or []
}
}
return note.version, version_info
except Exception as e:
import traceback
traceback.print_exc()
return None, None
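
The releaseDate fields above are epoch milliseconds; a DATE column carries no time component, so it is combined with midnight before conversion. A minimal sketch (date_to_epoch_ms is an illustrative name):

from datetime import date, datetime, time

def date_to_epoch_ms(d: date | None) -> int:
    # combine() promotes the date to midnight (local time) before timestamp()
    return int(datetime.combine(d, time()).timestamp() * 1000) if d else 0

print(date_to_epoch_ms(date(2026, 4, 8)))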
@staticmethod
def get_version_introduction(db: Session, version: str) -> Optional[Dict[str, Any]]:
"""
Fetch version notes from the database (prefer published versions)
Fetch the specified version's notes from the database (prefer published versions)
Reads the table structure via reflection, with no dependency on premium model classes.
Args:
@@ -207,11 +260,8 @@ class HomePageRepository:
Returns None if that version is not in the database
"""
try:
from sqlalchemy import Table, MetaData
metadata = MetaData()
version_notes = Table('version_notes', metadata, autoload_with=db.engine)
version_note_items = Table('version_note_items', metadata, autoload_with=db.engine)
note = db.query(version_notes).filter(
version_notes.c.version == version,
@@ -221,31 +271,18 @@ class HomePageRepository:
if not note:
return None
items = db.query(version_note_items).filter(
version_note_items.c.note_id == note.id
).order_by(version_note_items.c.sort_order).all()
core_upgrades = []
for item in items:
title = item.title
content = item.content
if content:
core_upgrades.append(f"{title}<br>{content}")
else:
core_upgrades.append(title)
return {
"introduction": {
"codeName": "",
"releaseDate": note.release_date.isoformat() if note.release_date else "",
"upgradePosition": "",
"coreUpgrades": core_upgrades
"codeName": note.code_name or "",
"releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0,
"upgradePosition": note.upgrade_position or "",
"coreUpgrades": note.core_upgrades or []
},
"introduction_en": {
"codeName": "",
"releaseDate": note.release_date.isoformat() if note.release_date else "",
"upgradePosition": "",
"coreUpgrades": core_upgrades
"codeName": note.code_name_en or note.code_name or "",
"releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0,
"upgradePosition": note.upgrade_position_en or note.upgrade_position or "",
"coreUpgrades": note.core_upgrades_en or []
}
}
except Exception:
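
Both methods rely on the same reflection pattern. A minimal sketch, assuming a Session bound to an engine whose database already contains version_notes (latest_published_version is an illustrative helper):

from sqlalchemy import MetaData, Table
from sqlalchemy.orm import Session

def latest_published_version(db: Session) -> str | None:
    metadata = MetaData()
    # autoload_with reads the column definitions from the live database,
    # so no ORM model class for the premium table is needed
    version_notes = Table("version_notes", metadata, autoload_with=db.bind)
    row = (
        db.query(version_notes)
        .filter(version_notes.c.is_published.is_(True))
        .order_by(version_notes.c.release_date.desc(), version_notes.c.version.desc())
        .first()
    )
    return row.version if row else None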

View File

@@ -78,6 +78,15 @@ class MemoryConfigRepository:
OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count
"""
# Batch-count memories for multiple users (simplified: returns total only)
SEARCH_FOR_ALL_BATCH = """
MATCH (n) WHERE n.end_user_id IN $end_user_ids
RETURN
n.end_user_id as user_id,
count(n) as total
ORDER BY user_id
"""
# Extracted entity details within group/app/user
SEARCH_FOR_DETIALS = """
MATCH (n:ExtractedEntity)

View File

@@ -1,17 +1,17 @@
import asyncio
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def create_fulltext_indexes():
"""Create full-text indexes for keyword search with BM25 scoring."""
connector = Neo4jConnector()
try:
# Create the Statements index
await connector.execute_query("""
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
""")
# # Create the Dialogues index
# await connector.execute_query("""
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
@@ -21,27 +21,35 @@ async def create_fulltext_indexes():
await connector.execute_query("""
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
""")
# Create the Chunks index
await connector.execute_query("""
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
""")
# Create the MemorySummary index
await connector.execute_query("""
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
""")
# Create the Community index
await connector.execute_query("""
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
# Create the Perceptual (perceptual memory) index
await connector.execute_query("""
CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
finally:
await connector.close()
async def create_vector_indexes():
"""Create vector indexes for fast embedding similarity search.
@@ -50,8 +58,7 @@ async def create_vector_indexes():
"""
connector = Neo4jConnector()
try:
# Statement embedding index
await connector.execute_query("""
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
@@ -62,8 +69,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
# Chunk embedding index
await connector.execute_query("""
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
@@ -75,7 +81,6 @@ async def create_vector_indexes():
}}
""")
# Entity name embedding index
await connector.execute_query("""
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
@@ -86,8 +91,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
# Memory summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
@@ -98,7 +102,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
# Community summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
@@ -108,8 +112,8 @@ async def create_vector_indexes():
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
""")
# Dialogue embedding index (optional)
await connector.execute_query("""
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
@@ -120,15 +124,27 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
# Perceptual summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS
FOR (p:Perceptual)
ON p.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
finally:
await connector.close()
async def create_unique_constraints():
"""Create uniqueness constraints for core node identifiers.
Ensures concurrent MERGE operations remain safe and prevents duplicates.
"""
connector = Neo4jConnector()
try:
try:
# Dialogue.id unique
await connector.execute_query(
"""
@@ -136,7 +152,7 @@ async def create_unique_constraints():
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
"""
)
# Statement.id unique
await connector.execute_query(
"""
@@ -144,7 +160,7 @@ async def create_unique_constraints():
FOR (s:Statement) REQUIRE s.id IS UNIQUE
"""
)
# Chunk.id unique
await connector.execute_query(
"""
@@ -152,13 +168,13 @@ async def create_unique_constraints():
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
"""
)
finally:
await connector.close()
async def create_all_indexes():
"""Create all indexes and constraints in one go."""
await create_fulltext_indexes()
await create_vector_indexes()
await create_unique_constraints()
print("✓ All indexes and constraints created successfully!")

View File

@@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
r.created_at = edge.created_at
RETURN elementId(r) AS uuid
"""
SEARCH_PERCEPTUAL_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
WHERE p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""
PERCEPTUAL_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS p, score
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""

View File

@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH,
PERCEPTUAL_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD,
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_ENTITIES_BY_NAME,
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
SEARCH_PERCEPTUAL_BY_KEYWORD,
SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
async def _update_activation_values_batch(
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
end_user_id: Optional[str] = None,
max_retries: int = 3
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
end_user_id: Optional[str] = None,
max_retries: int = 3
) -> List[Dict[str, Any]]:
"""
Batch-update activation values for the given nodes
@@ -58,7 +60,7 @@ async def _update_activation_values_batch(
"""
if not nodes:
return []
# Deferred import to avoid a circular dependency
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
AccessHistoryManager,
@@ -66,7 +68,7 @@ async def _update_activation_values_batch(
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator,
)
# Create calculator and manager instances
actr_calculator = ACTRCalculator()
access_manager = AccessHistoryManager(
@@ -74,7 +76,7 @@ async def _update_activation_values_batch(
actr_calculator=actr_calculator,
max_retries=max_retries
)
# Extract node IDs and deduplicate while preserving the original order
seen_ids = set()
unique_node_ids = []
@@ -83,7 +85,7 @@ async def _update_activation_values_batch(
if node_id and node_id not in seen_ids:
seen_ids.add(node_id)
unique_node_ids.append(node_id)
if not unique_node_ids:
logger.warning(f"批量更新激活值没有有效的节点ID")
return nodes
@@ -95,7 +97,7 @@ async def _update_activation_values_batch(
f"批量更新激活值检测到重复节点具有有效ID的节点数量={id_nodes_count}, "
f"去重后唯一ID数量={len(unique_node_ids)}"
)
# Record accesses in batch
try:
updated_nodes = await access_manager.record_batch_access(
@@ -103,14 +105,14 @@ async def _update_activation_values_batch(
node_label=node_label,
end_user_id=end_user_id
)
logger.info(
f"批量更新激活值成功: {node_label}, "
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
)
return updated_nodes
except Exception as e:
logger.error(
f"批量更新激活值失败: {node_label}, 错误: {str(e)}"
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
async def _update_search_results_activation(
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]:
"""
Update activation values for all knowledge nodes in the search results
@@ -144,11 +146,11 @@ async def _update_search_results_activation(
'entities': 'ExtractedEntity',
'summaries': 'MemorySummary'
}
# Update all node types in parallel
update_tasks = []
update_keys = []
for key, label in knowledge_node_types.items():
if key in results and results[key]:
update_tasks.append(
@@ -160,13 +162,13 @@ async def _update_search_results_activation(
)
)
update_keys.append(key)
if not update_tasks:
return results
# Execute all updates in parallel
update_results = await asyncio.gather(*update_tasks, return_exceptions=True)
# Update the results dict while preserving the original search scores
updated_results = results.copy()
for key, update_result in zip(update_keys, update_results):
@@ -175,10 +177,10 @@ async def _update_search_results_activation(
# Preserve the original score field (BM25/embedding score)
original_nodes = results[key]
updated_nodes = update_result
# Map IDs to updated nodes for fast activation-data lookups
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
# Merge: keep every original node (including duplicates) and fill in updated activation data
merged_nodes = []
for original_node in original_nodes:
@@ -186,7 +188,7 @@ async def _update_search_results_activation(
if node_id and node_id in updated_map:
# Start from the original node and overlay the updated activation data
merged_node = original_node.copy()
# Activation-related fields to update
activation_fields = {
'activation_value',
@@ -196,35 +198,35 @@ async def _update_search_results_activation(
'importance_score',
'version',
'statement',  # content field of Statement nodes
'content'  # content field of MemorySummary nodes
}
# Update only activation-related fields; keep the node's other fields
for field in activation_fields:
if field in updated_map[node_id]:
merged_node[field] = updated_map[node_id][field]
merged_nodes.append(merged_node)
else:
# Keep the original node when no updated data is available
merged_nodes.append(original_node)
updated_results[key] = merged_nodes
else:
# Update failed: log the error but keep the original results
logger.warning(
f"更新 {key} 激活值失败: {str(update_result)}"
)
return updated_results
async def search_graph(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
@@ -249,41 +251,45 @@ async def search_graph(
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Prepare tasks for parallel execution
tasks = []
task_keys = []
if "statements" in include:
tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
if "entities" in include:
tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
if "chunks" in include:
tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
if "summaries" in include:
tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
@@ -293,15 +299,16 @@ async def search_graph(
if "communities" in include:
tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# Build results dictionary
results = {}
for key, result in zip(task_keys, task_results):
@@ -310,14 +317,14 @@ async def search_graph(
results[key] = []
else:
results[key] = result
# Deduplicate results before updating activation values
# This prevents duplicates from propagating through the pipeline
from app.core.memory.src.search import _deduplicate_results
for key in results:
if isinstance(results[key], list):
results[key] = _deduplicate_results(results[key])
# Update activation values for knowledge nodes (Statement, ExtractedEntity, MemorySummary)
# Skip activation updates if only searching summaries (optimization)
needs_activation_update = any(
@@ -331,17 +338,17 @@ async def search_graph(
results=results,
end_user_id=end_user_id
)
return results
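
The selective-include plumbing above reduces to a small reusable pattern: run only the requested queries in parallel and degrade per-key on failure. A minimal sketch (run_parallel is an illustrative name):

import asyncio
import logging

logger = logging.getLogger(__name__)

async def run_parallel(named_coros: dict) -> dict:
    keys = list(named_coros)
    # return_exceptions=True keeps one failing query from cancelling the rest
    raw = await asyncio.gather(*named_coros.values(), return_exceptions=True)
    results = {}
    for key, result in zip(keys, raw):
        if isinstance(result, Exception):
            logger.warning("query %s failed: %s", key, result)
            results[key] = []  # stable shape for callers
        else:
            results[key] = result
    return results

async def main() -> None:
    async def fake_query(n: int) -> list[int]:
        await asyncio.sleep(0)
        return list(range(n))
    print(await run_parallel({"statements": fake_query(2), "entities": fake_query(3)}))

asyncio.run(main())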
async def search_graph_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"],
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities", "summaries"],
) -> Dict[str, List[Dict[str, Any]]]:
"""
Embedding-based semantic search across Statements, Chunks, and Entities.
@@ -355,13 +362,13 @@ async def search_graph_by_embedding(
- Returns up to 'limit' per included type
"""
import time
# Get embedding for the query
embed_start = time.time()
embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]:
logger.warning(
f"search_graph_by_embedding: embedding 生成失败或为空,"
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
if "statements" in include:
tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
if "chunks" in include:
tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
if "entities" in include:
tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
if "summaries" in include:
tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
if "communities" in include:
tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -428,8 +440,8 @@ async def search_graph_by_embedding(
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = {
"statements": [],
@@ -438,7 +450,7 @@ async def search_graph_by_embedding(
"summaries": [],
"communities": [],
}
for key, result in zip(task_keys, task_results):
if isinstance(result, Exception):
logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}")
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results
async def get_dedup_candidates_for_entities(  # adapted to the new queries: retrieve candidate entities by name via the fulltext index
connector: Neo4jConnector,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
max_concurrency: int = 5,
connector: Neo4jConnector,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Batch-retrieve candidate entities for second-stage deduplication/disambiguation (adapted to the new cypher_queries)
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal(
connector: Neo4jConnector,
query_text: str,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 50,
connector: Neo4jConnector,
query_text: str,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 50,
) -> Dict[str, List[Any]]:
"""
Temporal keyword search across Statements.
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements
"""
if not query_text:
print(f"query_text不能为空")
logger.warning(f"query_text不能为空")
return {"statements": []}
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date,
limit=limit,
)
print(f"查询结果为:\n{statements}")
logger.debug(f"查询结果为:\n{statements}")
# Update activation values for Statement nodes
results = {"statements": statements}
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 10,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# Update activation values for Statement nodes
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -643,15 +653,15 @@ async def search_graph_by_temporal(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_by_dialog_id(
connector: Neo4jConnector,
dialog_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
dialog_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Dialogues.
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues
"""
if not dialog_id:
print(f"dialog_id不能为空")
logger.warning(f"dialog_id不能为空")
return {"dialogues": []}
dialogues = await connector.execute_query(
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id(
connector: Neo4jConnector,
chunk_id : str,
end_user_id: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
chunk_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id:
print(f"chunk_id不能为空")
logger.warning(f"chunk_id不能为空")
return {"chunks": []}
chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID,
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
async def search_graph_community_expand(
connector: Neo4jConnector,
community_ids: List[str],
end_user_id: str,
limit: int = 10,
connector: Neo4jConnector,
community_ids: List[str],
end_user_id: str,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Phase 3: community-expansion retrieval, a two-level search from topic to detail.
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
async def search_graph_by_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -767,16 +776,11 @@ async def search_graph_by_created_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# Update activation values for Statement nodes
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -784,16 +788,16 @@ async def search_graph_by_created_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_by_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -807,16 +811,11 @@ async def search_graph_by_valid_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# Update activation values for Statement nodes
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -824,16 +823,16 @@ async def search_graph_by_valid_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_g_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -847,16 +846,11 @@ async def search_graph_g_created_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# Update activation values for Statement nodes
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -864,16 +858,16 @@ async def search_graph_g_created_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_g_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# Update activation values for Statement nodes
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -904,16 +892,16 @@ async def search_graph_g_valid_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_l_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -927,16 +915,11 @@ async def search_graph_l_created_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# Update activation values for Statement nodes
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -944,16 +927,16 @@ async def search_graph_l_created_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_l_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -967,16 +950,11 @@ async def search_graph_l_valid_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# Update activation values for Statement nodes
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -984,5 +962,89 @@ async def search_graph_l_valid_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_perceptual(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using fulltext keyword search.
Matches against summary, topic, and domain fields via the perceptualFulltext index.
Args:
connector: Neo4j connector
q: Query text
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
try:
perceptuals = await connector.execute_query(
SEARCH_PERCEPTUAL_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual: keyword search failed: {e}")
perceptuals = []
# Deduplicate
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}
async def search_perceptual_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using embedding-based semantic search.
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
Args:
connector: Neo4j connector
embedder_client: Embedding client with async response() method
query_text: Query text to embed
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
embeddings = await embedder_client.response([query_text])
if not embeddings or not embeddings[0]:
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
return {"perceptuals": []}
embedding = embeddings[0]
try:
perceptuals = await connector.execute_query(
PERCEPTUAL_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
perceptuals = []
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}

View File

@@ -11,10 +11,28 @@ Classes:
from typing import Any, List, Dict
from neo4j import AsyncGraphDatabase, basic_auth
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration
from app.core.config import settings
def _convert_neo4j_types(value: Any) -> Any:
"""递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。"""
if isinstance(value, Neo4jDateTime):
return value.to_native().isoformat() if value.tzinfo else value.iso_format()
if isinstance(value, Neo4jDate):
return value.iso_format()
if isinstance(value, Neo4jTime):
return value.iso_format()
if isinstance(value, Neo4jDuration):
return str(value)
if isinstance(value, dict):
return {k: _convert_neo4j_types(v) for k, v in value.items()}
if isinstance(value, list):
return [_convert_neo4j_types(item) for item in value]
return value
class Neo4jConnector:
"""Neo4j数据库连接器
@@ -59,11 +77,12 @@ class Neo4jConnector:
"""
await self.driver.close()
async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
"""执行Cypher查询
Args:
query: Cypher查询语句
json_format: json格式化
**kwargs: 查询参数将作为参数传递给Cypher查询
Returns:
@@ -78,7 +97,10 @@ class Neo4jConnector:
**kwargs
)
records, summary, keys = result
return [record.data() for record in records]
if json_format:
return [_convert_neo4j_types(record.data()) for record in records]
else:
return [record.data() for record in records]
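
A quick check of the conversion helper defined above, assuming the neo4j package is installed; the DateTime value is an arbitrary sample:

import json
from neo4j.time import DateTime

record = {"created_at": DateTime(2026, 1, 1, 12, 0, 0), "tags": ["a", "b"]}
# without the conversion (i.e. json_format=False), json.dumps would raise TypeError here
print(json.dumps(_convert_neo4j_types(record)))  # created_at becomes an ISO-8601 string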
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
"""在写事务中执行操作

View File

@@ -3,9 +3,9 @@
"""
import uuid
from typing import Any, Annotated
from typing import Any, Annotated, Literal
from sqlalchemy.orm import Session
from sqlalchemy import desc
from sqlalchemy import desc, select
from fastapi import Depends
from app.models.workflow_model import (
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
Returns:
List of execution records
"""
return self.db.query(WorkflowExecution).filter(
stmt = select(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id
).order_by(
desc(WorkflowExecution.started_at)
).limit(limit).offset(offset).all()
).limit(limit).offset(offset)
return list(self.db.execute(stmt).scalars())
def get_by_conversation_id(
self,
conversation_id: uuid.UUID
conversation_id: uuid.UUID,
status: Literal["running", "completed", "failed"] = None,
limit_count: int = 50
) -> list[WorkflowExecution]:
"""根据会话 ID 获取执行记录列表
Args:
limit_count:
conversation_id: 会话 ID
status: 状态(可选)
Returns:
执行记录列表
"""
return self.db.query(WorkflowExecution).filter(
stmt = select(WorkflowExecution).filter(
WorkflowExecution.conversation_id == conversation_id
).order_by(
desc(WorkflowExecution.started_at)
).all()
)
if status:
stmt = stmt.filter(WorkflowExecution.status == status)
stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
return list(self.db.execute(stmt).scalars())
def count_by_app_id(self, app_id: uuid.UUID) -> int:
"""统计应用的执行次数
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
Returns:
List of node execution records (sorted by execution order)
"""
return self.db.query(WorkflowNodeExecution).filter(
stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id
).order_by(
WorkflowNodeExecution.execution_order
).all()
)
return list(self.db.execute(stmt).scalars())
def get_by_node_id(
self,
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
Returns:
List of node execution records
"""
return self.db.query(WorkflowNodeExecution).filter(
stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id,
WorkflowNodeExecution.node_id == node_id
).order_by(
WorkflowNodeExecution.retry_count
).all()
)
return list(self.db.execute(stmt).scalars())
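
A self-contained sketch of the Query-to-select() migration applied throughout this repository class; Item is a hypothetical model:

from datetime import datetime
from sqlalchemy import Boolean, DateTime, Integer, desc, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Item(Base):  # hypothetical model for illustration
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)

def recent_items(db: Session, limit: int = 50, offset: int = 0) -> list[Item]:
    stmt = (
        select(Item)
        .filter(Item.is_active.is_(True))
        .order_by(desc(Item.created_at))
        .limit(limit)
        .offset(offset)
    )
    # .scalars() unwraps single-entity rows so callers receive Item objects
    return list(db.execute(stmt).scalars())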
# ==================== Dependency injection functions ====================

View File

@@ -241,6 +241,8 @@ class ModelParameters(BaseModel):
presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="存在惩罚")
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
stop: Optional[List[str]] = Field(default=None, description="停止序列")
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
class VariableDefinition(BaseModel):
@@ -612,6 +614,7 @@ class AppChatRequest(BaseModel):
user_id: Optional[str] = Field(default=None, description="用户ID用于会话管理")
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
stream: bool = Field(default=False, description="是否流式返回")
thinking: bool = Field(default=False, description="是否启用深度思考需Agent配置支持")
files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)")
@@ -641,6 +644,7 @@ class CitationSource(BaseModel):
class DraftRunResponse(BaseModel):
"""试运行响应(非流式)"""
message: str = Field(..., description="AI 回复消息")
reasoning_content: Optional[str] = Field(default=None, description="深度思考内容")
conversation_id: Optional[str] = Field(default=None, description="会话ID用于多轮对话")
usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
@@ -648,6 +652,12 @@ class DraftRunResponse(BaseModel):
citations: List[CitationSource] = Field(default_factory=list, description="引用来源")
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
def model_dump(self, **kwargs):
data = super().model_dump(**kwargs)
if not data.get("reasoning_content"):
data.pop("reasoning_content", None)
return data
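
The override above keeps reasoning_content out of serialized responses when the model produced no reasoning trace. A minimal sketch of the same pattern, assuming pydantic v2 (Reply is a hypothetical reduced model):

from typing import Optional
from pydantic import BaseModel

class Reply(BaseModel):  # hypothetical reduced response model
    message: str
    reasoning_content: Optional[str] = None

    def model_dump(self, **kwargs):
        data = super().model_dump(**kwargs)
        # drop the key entirely for replies without a reasoning trace
        if not data.get("reasoning_content"):
            data.pop("reasoning_content", None)
        return data

print(Reply(message="hi").model_dump())  # {'message': 'hi'}
print(Reply(message="hi", reasoning_content="...").model_dump())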
class OpeningResponse(BaseModel):
"""应用开场白响应"""

View File

@@ -31,7 +31,8 @@ class ChatRequest(BaseModel):
stream: bool = Field(default=False, description="是否流式返回")
web_search: bool = Field(default=False, description="是否启用网络搜索")
memory: bool = Field(default=True, description="是否启用记忆功能")
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件")
thinking: bool = Field(default=False, description="是否启用深度思考需Agent配置支持")
files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)")
# ---------- Output Schemas ----------

View File

@@ -17,6 +17,7 @@ class Write_UserInput(BaseModel):
end_user_id: str
config_id: Optional[str] = None
class AgentMemory_Long_Term(ABC):
"""长期记忆配置常量"""
STORAGE_NEO4J = "neo4j"
@@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC):
STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6
TIME_SCOPE=5
class AgentMemoryDataset(ABC):
PRONOUN=['','本人','在下','自己','','鄙人','','']
NAME='用户'
TIME_SCOPE = 5
class AgentMemoryDataset(ABC):
PRONOUN = ['', '本人', '在下', '自己', '', '鄙人', '', '']
NAME = '用户'

View File

@@ -138,21 +138,13 @@ class CreateEndUserRequest(BaseModel):
"""Request schema for creating an end user.
Attributes:
workspace_id: Workspace ID (required)
other_id: External user identifier (required)
other_name: Display name for the end user
memory_config_id: Optional memory config ID. If not provided, uses workspace default.
"""
workspace_id: str = Field(..., description="Workspace ID (required)")
other_id: str = Field(..., description="External user identifier (required)")
other_name: Optional[str] = Field("", description="Display name")
@field_validator("workspace_id")
@classmethod
def validate_workspace_id(cls, v: str) -> str:
"""Validate that workspace_id is not empty."""
if not v or not v.strip():
raise ValueError("workspace_id is required and cannot be empty")
return v.strip()
memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.")
@field_validator("other_id")
@classmethod
@@ -171,11 +163,13 @@ class CreateEndUserResponse(BaseModel):
other_id: External user identifier
other_name: Display name
workspace_id: Workspace the user belongs to
memory_config_id: Connected memory config ID
"""
id: str = Field(..., description="End user UUID")
other_id: str = Field(..., description="External user identifier")
other_name: str = Field("", description="Display name")
workspace_id: str = Field(..., description="Workspace ID")
memory_config_id: Optional[str] = Field(None, description="Connected memory config ID")
class MemoryConfigItem(BaseModel):

View File

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from app.core.agent.langchain_agent import LangChainAgent
from app.core.logging_config import get_business_logger
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import AgentRunService
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService
from app.services.workflow_service import WorkflowService
from app.schemas import FileType
logger = get_business_logger()
@@ -43,18 +44,17 @@ class AppChatService:
message: str,
conversation_id: uuid.UUID,
config: AgentConfig,
user_id: Optional[str] = None,
files: list[FileInput],
user_id: str,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None,
files: Optional[List[FileInput]] = None
workspace_id: Optional[str] = None
) -> Dict[str, Any]:
"""聊天(非流式)"""
start_time = time.time()
config_id = None
# Apply the features config
features_config: dict = config.features or {}
@@ -93,7 +93,8 @@ class AppChatService:
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
user_id)
tools.extend(kb_tools)
memory_flag = False
if memory:
@@ -116,7 +117,9 @@ class AppChatService:
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
capability=api_key_obj.capability or [],
)
model_info = ModelInfo(
@@ -168,11 +171,6 @@ class AppChatService:
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag,
files=processed_files  # pass the processed files
)
@@ -209,7 +207,8 @@ class AppChatService:
"model": api_key_obj.model_name,
"usage": result.get("usage", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
"audio_url": None,
"citations": filtered_citations
"citations": filtered_citations,
"reasoning_content": result.get("reasoning_content")
}
if files:
for f in files:
@@ -229,6 +228,26 @@ class AppChatService:
# Save the messages
if audio_url:
assistant_meta["audio_url"] = audio_url
if memory_flag:
connected_config = get_end_user_connected_config(user_id, self.db)
memory_config_id: str = connected_config.get("memory_config_id")
file_list = []
for file in files:
file_dict = file.model_dump()
file_dict["upload_file_id"] = str(file_dict["upload_file_id"]) if file_dict["upload_file_id"] else None
file_list.append(file_dict)
messages = [
{"role": "user", "content": message, "files": file_list},
{"role": "assistant", "content": result["content"]}
]
if memory_config_id:
await write_long_term(
storage_type,
user_id,
messages,
user_rag_memory_id,
memory_config_id
)
self.conversation_service.add_message(
conversation_id=conversation_id,
role="user",
@@ -247,6 +266,7 @@ class AppChatService:
"conversation_id": conversation_id,
"message_id": str(message_id),
"message": result["content"],
"reasoning_content": result.get("reasoning_content"),
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
@@ -264,20 +284,19 @@ class AppChatService:
message: str,
conversation_id: uuid.UUID,
config: AgentConfig,
files: list[FileInput],
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None,
files: Optional[List[FileInput]] = None
workspace_id: Optional[str] = None
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
try:
start_time = time.time()
config_id = None
message_id = uuid.uuid4()
# Apply the features config
@@ -319,7 +338,8 @@ class AppChatService:
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
config.knowledge_retrieval, user_id)
tools.extend(kb_tools)
# 添加长期记忆工具
memory_flag = False
@@ -343,7 +363,10 @@ class AppChatService:
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True
streaming=True,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
capability=api_key_obj.capability or [],
)
model_info = ModelInfo(
@@ -392,6 +415,7 @@ class AppChatService:
# Stream the Agent call (multimodal supported) while starting TTS in parallel
full_content = ""
full_reasoning = ""
total_tokens = 0
text_queue: asyncio.Queue = asyncio.Queue()
@@ -411,15 +435,13 @@ class AppChatService:
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag,
files=processed_files
):
if isinstance(chunk, int):
total_tokens = chunk
elif isinstance(chunk, dict) and chunk.get("type") == "reasoning":
full_reasoning += chunk['content']
yield f"event: reasoning\ndata: {json.dumps({'content': chunk['content']}, ensure_ascii=False)}\n\n"
else:
full_content += chunk
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
@@ -459,14 +481,15 @@ class AppChatService:
# Save the messages
human_meta = {
"files":[],
"files": [],
"history_files": {}
}
assistant_meta = {
"model": api_key_obj.model_name,
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens},
"audio_url": None,
"citations": filtered_citations
"citations": filtered_citations,
"reasoning_content": full_reasoning or None
}
if files:
@@ -484,6 +507,27 @@ class AppChatService:
if stream_audio_url:
assistant_meta["audio_url"] = stream_audio_url
if memory_flag:
connected_config = get_end_user_connected_config(user_id, self.db)
memory_config_id: str = connected_config.get("memory_config_id")
file_list = []
for file in files:
file_dict = file.model_dump()
file_dict["upload_file_id"] = str(file_dict["upload_file_id"]) if file_dict["upload_file_id"] else None
file_list.append(file_dict)
messages = [
{"role": "user", "content": message, "files": file_list},
{"role": "assistant", "content": full_content}
]
if memory_config_id:
await write_long_term(
storage_type,
user_id,
messages,
user_rag_memory_id,
memory_config_id
)
self.conversation_service.add_message(
conversation_id=conversation_id,
role="user",
@@ -618,7 +662,6 @@ class AppChatService:
# 2. Create the orchestrator
orchestrator = MultiAgentOrchestrator(self.db, config)
# 3. Execute the task as a stream
async for event in orchestrator.execute_stream(
message=message,
@@ -631,13 +674,13 @@ class AppChatService:
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
if "sub_usage" in event:
# Intercept sub_usage events and accumulate tokens
if "event: sub_usage" in event:
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "total_tokens" in data:
total_tokens += data["total_tokens"]
total_tokens += data.get("total_tokens", 0)
except Exception:
pass
else:

View File

@@ -13,7 +13,7 @@ import uuid
from typing import Annotated, Any, Dict, List, Optional, Tuple
from fastapi import Depends
from sqlalchemy import and_, delete, func, or_, select
from sqlalchemy import and_, delete, func, or_, select, update as sa_update
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
@@ -401,7 +401,7 @@ class AppService:
def _create_workflow_config(
self,
app_id: uuid.UUID,
data: app_schema.WorkflowConfigCreate,
data,
now: datetime.datetime
):
workflow_cfg = WorkflowConfig(
@@ -678,7 +678,9 @@ class AppService:
self._create_multi_agent_config(app.id, data.multi_agent_config, now)
if app.type == "workflow" and data.workflow_config:
self._create_workflow_config(app.id, data.workflow_config, now)
from app.schemas.workflow_schema import WorkflowConfigCreate
wf_data = WorkflowConfigCreate(**data.workflow_config) if isinstance(data.workflow_config, dict) else data.workflow_config
self._create_workflow_config(app.id, wf_data, now)
self.db.commit()
self.db.refresh(app)
@@ -757,6 +759,17 @@ class AppService:
# Soft-delete the app
app.is_active = False
# Mark all of this app's share records in app_shares as inactive and refresh their updated_at timestamps
stmt = sa_update(AppShare).where(
AppShare.source_app_id == app_id,
AppShare.is_active.is_(True)
).values(
is_active=False,
updated_at=datetime.datetime.now()
)
self.db.execute(stmt)
self.db.commit()
logger.info(
@@ -1347,6 +1360,7 @@ class AppService:
variables=cfg.get("variables", []),
execution_config=cfg.get("execution_config", {}),
triggers=cfg.get("triggers", []),
features=cfg.get("features", {}),
is_active=True,
created_at=now,
updated_at=now,

View File

@@ -534,6 +534,7 @@ class ConversationService:
api_key = api_config.api_key
api_base = api_config.api_base
is_omni = api_config.is_omni
capability = api_config.capability
model_type = config.type
llm = RedBearLLM(
@@ -542,7 +543,8 @@ class ConversationService:
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni
is_omni=is_omni,
support_thinking="thinking" in (capability or []),
),
type=ModelType(model_type)
)

View File

@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig, ModelType
from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput, Citation
from app.schemas.model_schema import ModelInfo
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
from app.services.model_service import ModelApiKeyService
from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
from app.schemas import FileType
logger = get_business_logger()
@@ -459,7 +458,7 @@ class AgentRunService:
statement = opening["statement"]
suggested_questions = opening["suggested_questions"]
# Substitute variables if present (only the {{var_name}} format is supported)
if variables:
for var_name, var_value in variables.items():
@@ -596,6 +595,9 @@ class AgentRunService:
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=effective_params.get("deep_thinking", False),
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
capability=api_key_config.get("capability", []),
)
# 5. Handle the conversation ID: create or validate it; write the opening statement for new conversations
@@ -661,11 +663,6 @@ class AgentRunService:
message=message,
history=history,
context=context,
end_user_id=user_id,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag,
files=processed_files # 传递处理后的文件
)
@@ -695,7 +692,8 @@ class AgentRunService:
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}),
"reasoning_content": result.get("reasoning_content")
},
files=files,
processed_files=processed_files,
@@ -707,6 +705,7 @@ class AgentRunService:
response = {
"message": result["content"],
"reasoning_content": result.get("reasoning_content"),
"conversation_id": conversation_id,
"usage": result.get("usage", {
"prompt_tokens": 0,
@@ -844,7 +843,10 @@ class AgentRunService:
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True
streaming=True,
deep_thinking=effective_params.get("deep_thinking", False),
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
capability=api_key_config.get("capability", []),
)
# 5. Handle the conversation ID: create or validate it; write the opening statement for new conversations
@@ -904,6 +906,7 @@ class AgentRunService:
# 9. Stream the Agent call (multimodal supported) while starting TTS in parallel
full_content = ""
full_reasoning = ""
total_tokens = 0
# Start streaming TTS: synthesize audio as the text is produced
@@ -918,15 +921,13 @@ class AgentRunService:
message=message,
history=history,
context=context,
end_user_id=user_id,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag,
files=processed_files
):
if isinstance(chunk, int):
total_tokens = chunk
elif isinstance(chunk, dict) and chunk.get("type") == "reasoning":
full_reasoning += chunk["content"]
yield self._format_sse_event("reasoning", {"content": chunk["content"]})
else:
full_content += chunk
yield self._format_sse_event("message", {"content": chunk})
@@ -955,7 +956,8 @@ class AgentRunService:
app_id=agent_config.app_id,
user_id=user_id,
meta_data={
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens},
"reasoning_content": full_reasoning or None
},
files=files,
processed_files=processed_files,
@@ -1676,7 +1678,7 @@ class AgentRunService:
"""从 text_queue 取文本按句子切分后喂给 synthesizer"""
import re
buf = ""
sentence_end = re.compile(r'[\u3002\uff01\uff1f\.!?\n]')
sentence_end = re.compile(r'[\u3002\uff01\uff1f.!?\n]')
while True:
chunk = await text_queue.get()
if chunk is None:
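Note that inside a character class `\.` and `.` are equivalent, so the regex change above is purely cosmetic. For readers unfamiliar with the splitter, a hedged sketch of how such a pattern chunks buffered text into sentences (the queue-draining loop in the service is more involved):

```python
import re

# Sentence terminators: CJK 。!? (U+3002, U+FF01, U+FF1F) plus ASCII .!? and newline.
sentence_end = re.compile(r'[\u3002\uff01\uff1f.!?\n]')

def split_sentences(buf: str) -> tuple[list[str], str]:
    """Return the complete sentences in buf plus the unfinished tail."""
    sentences, start = [], 0
    for m in sentence_end.finditer(buf):
        sentences.append(buf[start:m.end()])
        start = m.end()
    return sentences, buf[start:]

done, rest = split_sentences("你好。How are you? Still typ")
print(done)  # ['你好。', 'How are you?']
print(rest)  # ' Still typ'
```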
@@ -1905,6 +1907,7 @@ class AgentRunService:
"conversation_id": result['conversation_id'],
"parameters_used": model_info["parameters"],
"message": result.get("message"),
"reasoning_content": result.get("reasoning_content"),
"usage": usage,
"elapsed_time": elapsed,
"tokens_per_second": (
@@ -2023,7 +2026,7 @@ class AgentRunService:
# The actual model name must come from ModelApiKey, or a model field should be added to ModelConfig
return None
def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]) -> AgentConfig:
def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]) -> tuple[AgentConfig, Any]:
"""创建一个带有覆盖参数的 agent_config浅拷贝只修改 model_parameters
Args:
@@ -2121,6 +2124,7 @@ class AgentRunService:
start_time = time.time()
full_content = ""
full_reasoning = ""
returned_conversation_id = model_conversation_id
audio_url = None
audio_status = None
@@ -2179,6 +2183,18 @@ class AgentRunService:
"content": chunk
}))
# Forward deep-thinking events (tagged with the model identity)
if event_type == "reasoning" and event_data:
reasoning_chunk = event_data.get("content", "")
full_reasoning += reasoning_chunk
await event_queue.put(self._format_sse_event("model_reasoning", {
"model_index": idx,
"model_config_id": model_config_id,
"label": model_label,
"conversation_id": returned_conversation_id,
"content": event_data.get("content", "")
}))
# Extract the features output fields from the end event
if event_type == "end" and event_data:
audio_url = event_data.get("audio_url")
@@ -2210,6 +2226,7 @@ class AgentRunService:
"conversation_id": returned_conversation_id,
"parameters_used": model_info["parameters"],
"message": full_content,
"reasoning_content": full_reasoning or None,
"elapsed_time": elapsed,
"audio_url": audio_url,
"audio_status": audio_status,
@@ -2362,6 +2379,7 @@ class AgentRunService:
"label": r["label"],
"conversation_id": r.get("conversation_id"),
"message": r.get("message"),
"reasoning_content": r.get("reasoning_content"),
"elapsed_time": r.get("elapsed_time", 0),
"audio_url": r.get("audio_url"),
"audio_status": r.get("audio_status"),

View File

@@ -415,6 +415,7 @@ class LLMRouter:
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
support_thinking="thinking" in (api_key_config.capability or []),
temperature=0.3,
max_tokens=500
)
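The `support_thinking` flag this commit threads through every LLM construction site is derived the same way everywhere: the capability list, which may be None on older rows, is checked for a `"thinking"` entry. A minimal sketch of the gating (the config class below is a stand-in for the ORM object):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ApiKeyConfig:
    # Stand-in for the api_key_config ORM row used in the diff.
    capability: Optional[List[str]] = None

def supports_thinking(cfg: ApiKeyConfig) -> bool:
    # capability may be None on legacy rows, hence the `or []` guard.
    return "thinking" in (cfg.capability or [])

assert supports_thinking(ApiKeyConfig(capability=["thinking", "vision"]))
assert not supports_thinking(ApiKeyConfig(capability=None))
```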

View File

@@ -393,6 +393,7 @@ class MasterAgentRouter:
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
support_thinking="thinking" in (api_key_config.capability or []),
extra_params = extra_params
)
@@ -403,6 +404,17 @@ class MasterAgentRouter:
response = await llm.ainvoke(prompt)
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
# Extract token usage
self._last_routing_tokens = 0
if hasattr(response, 'usage_metadata') and response.usage_metadata:
um = response.usage_metadata
self._last_routing_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
elif hasattr(response, 'response_metadata') and response.response_metadata:
token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {})
if isinstance(token_usage, dict):
self._last_routing_tokens = token_usage.get("total_tokens", 0)
logger.info(f"Master Agent 路由 token 消耗: {self._last_routing_tokens}")
# Extract the response content
if hasattr(response, 'content'):
return response.content
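This routing-token extraction (and the near-identical merge-token block later in MultiAgentOrchestrator) probes the two metadata shapes LangChain-style responses usually carry. A hedged, standalone sketch of the same fallback chain, exercised with a dummy response:

```python
def extract_total_tokens(response) -> int:
    """Best-effort total_tokens lookup, mirroring the fallback order above."""
    um = getattr(response, "usage_metadata", None)
    if um:
        return um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
    rm = getattr(response, "response_metadata", None)
    if rm:
        token_usage = rm.get("token_usage") or rm.get("usage", {})
        if isinstance(token_usage, dict):
            return token_usage.get("total_tokens", 0)
    return 0

class FakeResponse:
    usage_metadata = {"input_tokens": 12, "output_tokens": 30, "total_tokens": 42}
    response_metadata = {}

print(extract_total_tokens(FakeResponse()))  # 42
```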

View File

@@ -37,6 +37,7 @@ from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.audit_logger import audit_logger
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -49,10 +50,6 @@ from app.services.memory_konwledges_server import (
)
from app.services.memory_perceptual_service import MemoryPerceptualService
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
logger = get_logger(__name__)
config_logger = get_config_logger()
@@ -68,24 +65,22 @@ class MemoryAgentService:
if str(messages) == 'success':
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# Log the successful operation
if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)})
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)})
return context
else:
logger.warning(f"Write operation failed for group {end_user_id}")
# Log the failed operation
if audit_logger:
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=f"写入失败: {messages[:100]}"
)
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=f"写入失败: {messages[:100]}"
)
raise ValueError(f"写入失败: {messages}")
@@ -338,10 +333,9 @@ class MemoryAgentService:
logger.error(error_msg)
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
@@ -401,10 +395,10 @@ class MemoryAgentService:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
async def read_memory(
@@ -468,12 +462,6 @@ class MemoryAgentService:
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
# Import the audit logger
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
config_load_start = time.time()
try:
# Use a separate database session to avoid transaction failures
@@ -492,16 +480,15 @@ class MemoryAgentService:
logger.error(error_msg)
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
)
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
@@ -515,10 +502,13 @@ class MemoryAgentService:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# Initial state - includes all required fields
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
initial_state = {
"messages": [HumanMessage(content=message)],
"search_switch": search_switch,
"end_user_id": end_user_id,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# Collect node update info
_intermediate_outputs = []
summary = ''
@@ -530,7 +520,7 @@ class MemoryAgentService:
for node_name, node_data in update_event.items():
# if 'save_neo4j' == node_name:
# massages = node_data
print(f"处理节点: {node_name}")
logger.info(f"处理节点: {node_name}")
# Handle the differing return structures of Summary nodes
if 'Summary' in node_name:
@@ -557,6 +547,11 @@ class MemoryAgentService:
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node)
# Perceptual_Retrieve node
perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None)
if perceptual_node and perceptual_node != [] and perceptual_node != {}:
_intermediate_outputs.append(perceptual_node)
# Verify node
verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}:
@@ -633,15 +628,15 @@ class MemoryAgentService:
total_time = time.time() - start_time
logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=True,
duration=duration
)
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=True,
duration=duration
)
return {
"answer": summary,
@@ -651,16 +646,16 @@ class MemoryAgentService:
# Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
)
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:

View File

@@ -280,6 +280,53 @@ class MemoryAPIService:
code=BizCode.MEMORY_READ_FAILED
)
def create_end_user(
self,
workspace_id: uuid.UUID,
other_id: str,
) -> Dict[str, Any]:
"""Create or retrieve an end user for the workspace.
Uses get_or_create semantics: if an end user with the same other_id
already exists in the workspace, returns the existing one.
Args:
workspace_id: Workspace ID from API key authorization
other_id: External user identifier
Returns:
Dict with id, other_id, other_name, and workspace_id
Raises:
BusinessException: If creation fails
"""
logger.info(f"Creating end user - other_id: {other_id}, workspace_id: {workspace_id}")
try:
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(self.db)
end_user = end_user_repo.get_or_create_end_user(
app_id=None,
workspace_id=workspace_id,
other_id=other_id,
)
logger.info(f"End user ready: {end_user.id}")
return {
"id": str(end_user.id),
"other_id": end_user.other_id or "",
"other_name": end_user.other_name or "",
"workspace_id": str(end_user.workspace_id),
}
except Exception as e:
logger.error(f"Failed to create end user for workspace {workspace_id}: {e}")
raise BusinessException(
message=f"Failed to create end user: {str(e)}",
code=BizCode.INTERNAL_ERROR
)
def list_memory_configs(
self,
workspace_id: uuid.UUID,
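Because `create_end_user` above uses get_or_create semantics, repeated calls with the same (workspace_id, other_id) pair are idempotent. A hypothetical usage sketch; the service construction and IDs are illustrative, and `db` stands in for a real SQLAlchemy session:

```python
import uuid

svc = MemoryAPIService(db)  # db: a Session supplied by the app's DI layer
ws = uuid.UUID("00000000-0000-0000-0000-000000000001")

first = svc.create_end_user(workspace_id=ws, other_id="crm-42")
again = svc.create_end_user(workspace_id=ws, other_id="crm-42")
assert first["id"] == again["id"]  # same row both times (get_or_create)
```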

View File

@@ -1,11 +1,12 @@
from sqlalchemy.orm import Session
from typing import List, Optional
from sqlalchemy import desc, nullslast, or_, and_, cast, String
from typing import List, Optional, Dict, Any
import uuid
from fastapi import HTTPException
from app.models.user_model import User
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.end_user_model import EndUser, EndUser as EndUserModel
from app.models.memory_increment_model import MemoryIncrement
from app.repositories import (
@@ -49,44 +50,40 @@ def get_current_workspace_type(
def get_workspace_end_users(
db: Session,
workspace_id: uuid.UUID,
current_user: User
) -> List[EndUser]:
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
返回结果按 created_at 从新到旧排序NULL 值排在最后)
"""
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
# Query the workspace's apps (ORM)
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
if not apps_orm:
business_logger.info("工作空间下没有应用")
return []
# Extract all app_ids
# app_ids = [app.id for app in apps_orm]
# Batch-query all end_users (one query instead of a per-app loop)
# Order by created_at descending (NULLs last), with id as a secondary sort key for determinism
from app.models.end_user_model import EndUser as EndUserModel
from sqlalchemy import desc, nullslast
end_users_orm = db.query(EndUserModel).filter(
EndUserModel.workspace_id == workspace_id
).order_by(
nullslast(desc(EndUserModel.created_at)),
desc(EndUserModel.id)
).all()
# Convert to Pydantic models (only when needed)
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return end_users
except HTTPException:
raise
except Exception as e:
@@ -94,6 +91,85 @@ def get_workspace_end_users(
raise
def get_workspace_end_users_paginated(
db: Session,
workspace_id: uuid.UUID,
current_user: User,
page: int,
pagesize: int,
keyword: Optional[str] = None
) -> Dict[str, Any]:
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
返回结果按 created_at 从新到旧排序NULL 值排在最后)
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
Args:
db: 数据库会话
workspace_id: 工作空间ID
current_user: 当前用户
page: 页码从1开始
pagesize: 每页数量
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id
Returns:
dict: 包含 items宿主列表和 total总记录数的字典
"""
business_logger.info(f"获取工作空间宿主列表(分页): workspace_id={workspace_id}, keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}")
try:
# Build the base query
base_query = db.query(EndUserModel).filter(
EndUserModel.workspace_id == workspace_id
)
# Normalize the search keyword (filter out empty strings and None)
keyword = keyword.strip() if keyword else None
if keyword:
keyword_pattern = f"%{keyword}%"
# other_name matching always applies; id matching only applies to records whose other_name is empty
base_query = base_query.filter(
or_(
EndUserModel.other_name.ilike(keyword_pattern),
and_(
or_(
EndUserModel.other_name.is_(None),
EndUserModel.other_name == "",
),
cast(EndUserModel.id, String).ilike(keyword_pattern),
),
)
)
business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_nameother_name 为空时匹配 id")
# Get the total record count
total = base_query.count()
if total == 0:
business_logger.info("工作空间下没有宿主")
return {"items": [], "total": 0}
# Paginated query
# Order by created_at descending (NULLs last), with id as a secondary sort key for determinism
end_users_orm = base_query.order_by(
nullslast(desc(EndUserModel.created_at)),
desc(EndUserModel.id)
).offset((page - 1) * pagesize).limit(pagesize).all()
# Convert to Pydantic models
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total}")
return {"items": end_users, "total": total}
except HTTPException:
raise
except Exception as e:
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
raise
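The keyword filter above has asymmetric semantics worth spelling out: `other_name` is always matched, while the UUID `id` column is only matched for rows whose `other_name` is NULL or empty. A pure-Python sketch of the same predicate, for illustration only (the real filter runs in SQL via `or_`/`and_`/`cast`):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class EndUserRow:
    id: str
    other_name: Optional[str]

def matches(row: EndUserRow, keyword: str) -> bool:
    # Name match always counts; id match only counts when the row
    # has no usable other_name (None or empty), mirroring the SQL filter.
    kw = keyword.lower()
    if row.other_name and kw in row.other_name.lower():
        return True
    return not row.other_name and kw in row.id.lower()

rows = [EndUserRow("9f3a0d", "Alice"), EndUserRow("77ab42cd", None)]
print([r.id for r in rows if matches(r, "ab42")])  # ['77ab42cd']
```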
def get_workspace_memory_increment(
db: Session,
workspace_id: uuid.UUID,
@@ -277,15 +353,13 @@ async def get_workspace_total_memory_count(
"details": []
}
# 2. Call search_all per host_id to get each total
# 2. Use search_all_batch to fetch all hosts' memory counts in one query
from app.services import memory_storage_service
total_count = 0
details = []
# If end_user_id is provided, query only that user
if end_user_id:
search_result = await memory_storage_service.search_all(end_user_id=end_user_id)
batch_result = await memory_storage_service.search_all_batch([end_user_id])
count = batch_result.get(end_user_id, 0)
# Look up the user's name
from app.repositories.end_user_repository import EndUserRepository
repo = EndUserRepository(db)
@@ -293,42 +367,31 @@ async def get_workspace_total_memory_count(
user_name = end_user.other_name if end_user else None
return {
"total_memory_count": search_result.get("total", 0),
"total_memory_count": count,
"host_count": 1,
"details": [{
"end_user_id": end_user_id,
"count": search_result.get("total", 0),
"count": count,
"name": user_name
}]
}
for host in hosts:
try:
end_user_id_str = str(host.id)
search_result = await memory_storage_service.search_all(
end_user_id=end_user_id_str
)
host_total = search_result.get("total", 0)
total_count += host_total
details.append({
"end_user_id": end_user_id_str,
"count": host_total,
"name": host.other_name # 使用 other_name 字段
})
business_logger.debug(f"EndUser {end_user_id_str} ({host.other_name}) 记忆数: {host_total}")
except Exception as e:
business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}")
# Count a failed host as 0
details.append({
"end_user_id": str(host.id),
"count": 0,
"name": host.other_name  # use the other_name field
})
# Batch-query all hosts' memory counts (a single Neo4j query)
end_user_ids = [str(host.id) for host in hosts]
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
# Build the host-name map
host_name_map = {str(host.id): host.other_name for host in hosts}
total_count = sum(batch_result.values())
details = [
{
"end_user_id": uid,
"count": batch_result.get(uid, 0),
"name": host_name_map.get(uid)
}
for uid in end_user_ids
]
result = {
"total_memory_count": total_count,
@@ -443,6 +506,180 @@ def get_rag_user_kb_total_chunk(
business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_dashboard_yesterday_changes(
db: Session,
workspace_id: uuid.UUID,
storage_type: str,
today_data: dict
) -> dict:
"""
计算各指标相比昨天的变化百分比。
- total_app_change / total_knowledge_change只看活跃记录
百分比 = (截止今日活跃总量 - 截止昨日活跃总量) / 截止昨日活跃总量
- total_memory_change / total_api_call_change
百分比 = (今日总量 - 昨日总量) / 昨日总量
昨日总量为 0 时返回 None。返回值为浮点数例如 0.5 表示增长 50%
Args:
db: 数据库会话
workspace_id: 工作空间ID
storage_type: 存储类型 'neo4j' | 'rag'
today_data: 当前数据,包含 total_memory, total_app, total_knowledge, total_api_call
Returns:
{
"total_memory_change": float | None,
"total_app_change": float | None,
"total_knowledge_change": float | None,
"total_api_call_change": float | None
}
"""
from datetime import datetime
from sqlalchemy import func
from app.models.api_key_model import ApiKey, ApiKeyLog
from app.models.knowledge_model import Knowledge
from app.models.app_model import App
from app.models.appshare_model import AppShare
business_logger.info(f"计算昨日对比百分比: workspace_id={workspace_id}, storage_type={storage_type}")
now_local = datetime.now()
today_start = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
changes = {
"total_memory_change": None,
"total_app_change": None,
"total_knowledge_change": None,
"total_api_call_change": None,
}
def _calc_percentage(today_val, yesterday_val):
"""计算百分比昨日为0时返回None"""
if yesterday_val is None or yesterday_val == 0:
return None
return round((today_val - yesterday_val) / yesterday_val, 4)
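A quick worked example of `_calc_percentage`: 150 today against a baseline of 100 gives (150 - 100) / 100 = 0.5, i.e. +50%, while a zero or missing baseline yields None instead of a division error:

```python
def _calc_percentage(today_val, yesterday_val):
    """Percentage change vs. yesterday; None when the baseline is 0 or None."""
    if yesterday_val is None or yesterday_val == 0:
        return None
    return round((today_val - yesterday_val) / yesterday_val, 4)

print(_calc_percentage(150, 100))  # 0.5   -> +50%
print(_calc_percentage(80, 100))   # -0.2  -> -20%
print(_calc_percentage(10, 0))     # None  -> no usable baseline
```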
# --- total_api_call_change: (cumulative total up to today - cumulative total up to yesterday) / cumulative total up to yesterday ---
try:
api_key_ids = [
row[0] for row in db.query(ApiKey.id).filter(
ApiKey.workspace_id == workspace_id
).all()
]
if api_key_ids:
# Cumulative call count up to now
total_api_until_now = db.query(func.count(ApiKeyLog.id)).filter(
ApiKeyLog.api_key_id.in_(api_key_ids),
ApiKeyLog.created_at < now_local
).scalar() or 0
# Cumulative call count up to yesterday (today_start marks the end of yesterday)
total_api_until_yesterday = db.query(func.count(ApiKeyLog.id)).filter(
ApiKeyLog.api_key_id.in_(api_key_ids),
ApiKeyLog.created_at < today_start
).scalar() or 0
changes["total_api_call_change"] = _calc_percentage(total_api_until_now, total_api_until_yesterday)
else:
changes["total_api_call_change"] = None
except Exception as e:
business_logger.warning(f"计算API调用昨日对比失败: {str(e)}")
# --- total_knowledge_change: active only (status=1) and top-level KBs (parent_id=workspace_id); percentage = (today's active total - yesterday's active total) / yesterday's active total ---
try:
# Active KB total up to today (currently status=1, parent_id=workspace_id)
today_knowledge = db.query(func.count(Knowledge.id)).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1,
Knowledge.parent_id == Knowledge.workspace_id
).scalar() or 0
# Active KB total up to yesterday (created before today, still status=1, parent_id=workspace_id)
yesterday_knowledge = db.query(func.count(Knowledge.id)).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1,
Knowledge.parent_id == Knowledge.workspace_id,
Knowledge.created_at < today_start
).scalar() or 0
changes["total_knowledge_change"] = _calc_percentage(today_knowledge, yesterday_knowledge)
except Exception as e:
business_logger.warning(f"计算知识库昨日对比失败: {str(e)}")
# --- total_app_change: active only (is_active=True); percentage = (today's active total - yesterday's active total) / yesterday's active total ---
try:
# === Own apps ===
today_own_apps = db.query(func.count(App.id)).filter(
App.workspace_id == workspace_id,
App.is_active == True
).scalar() or 0
yesterday_own_apps = db.query(func.count(App.id)).filter(
App.workspace_id == workspace_id,
App.is_active == True,
App.created_at < today_start
).scalar() or 0
# === Apps shared into this workspace ===
today_shared_apps = db.query(func.count(AppShare.id)).filter(
AppShare.target_workspace_id == workspace_id,
AppShare.is_active == True
).scalar() or 0
yesterday_shared_apps = db.query(func.count(AppShare.id)).filter(
AppShare.target_workspace_id == workspace_id,
AppShare.is_active == True,
AppShare.created_at < today_start
).scalar() or 0
today_total_app = today_own_apps + today_shared_apps
yesterday_total_app = yesterday_own_apps + yesterday_shared_apps
changes["total_app_change"] = _calc_percentage(today_total_app, yesterday_total_app)
except Exception as e:
business_logger.warning(f"计算应用数量昨日对比失败: {str(e)}")
# --- total_memory_change: (today's total - yesterday's total) / yesterday's total ---
try:
today_memory = today_data.get("total_memory")
if today_memory is None:
changes["total_memory_change"] = None
elif storage_type == "neo4j":
last_record = db.query(MemoryIncrement).filter(
MemoryIncrement.workspace_id == workspace_id,
MemoryIncrement.created_at < today_start
).order_by(desc(MemoryIncrement.created_at)).first()
if last_record is None or last_record.total_num == 0:
changes["total_memory_change"] = None
else:
changes["total_memory_change"] = _calc_percentage(today_memory, last_record.total_num)
elif storage_type == "rag":
from app.models.document_model import Document
from app.models.end_user_model import EndUser as _EndUser
from app.models.app_model import App as _App
end_user_ids = [
str(eid) for (eid,) in db.query(_EndUser.id)
.join(_App, _EndUser.app_id == _App.id)
.filter(_App.workspace_id == workspace_id)
.all()
]
if not end_user_ids:
changes["total_memory_change"] = None
else:
file_names = [f"{uid}.txt" for uid in end_user_ids]
yesterday_chunk = int(db.query(func.sum(Document.chunk_num)).filter(
Document.file_name.in_(file_names),
Document.created_at < today_start
).scalar() or 0)
if yesterday_chunk == 0:
changes["total_memory_change"] = None
else:
changes["total_memory_change"] = _calc_percentage(today_memory, yesterday_chunk)
except Exception as e:
business_logger.warning(f"计算记忆总量昨日对比失败: {str(e)}")
business_logger.info(f"昨日对比百分比计算完成: {changes}")
return changes
def get_current_user_total_chunk(
end_user_id: str,
db: Session,
@@ -638,7 +875,24 @@ def get_rag_content(
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
continue
# 4. Return the result
# 4. Join all page_content and split the merged text into a conversation list by role
merged_text = "\n".join(page_contents)
conversations = []
if merged_text.strip():
import re
# Match "user:" or "assistant:" anywhere in the text, not just at line starts
parts = re.split(r'(user|assistant):', merged_text)
# parts layout: ['', 'user', ' content...', 'assistant', ' content...', ...]
i = 1
while i < len(parts) - 1:
role = parts[i].strip()
content = parts[i + 1].strip()
# Restore escaped "\n" sequences in content to real newlines
content = content.replace("\\n", "\n")
if role in ("user", "assistant") and content:
conversations.append({"role": role, "content": content})
i += 2
result = {
"page": {
"page": page,
@@ -646,10 +900,10 @@ def get_rag_content(
"total": global_total,
"hasnext": offset_end < global_total,
},
"items": page_contents
"items": conversations
}
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)}")
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(conversations)}对话")
return result
except Exception as e:
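A standalone sketch of the role-splitting step above, showing how `re.split` with a capturing group interleaves the delimiters into the result (the sample text is hypothetical):

```python
import re

merged_text = "user: hi there assistant: hello!\\nHow can I help? user: thanks"

parts = re.split(r'(user|assistant):', merged_text)
# ['', 'user', ' hi there ', 'assistant', ' hello!\\nHow can I help? ', 'user', ' thanks']

conversations = []
i = 1
while i < len(parts) - 1:
    role, content = parts[i].strip(), parts[i + 1].strip()
    content = content.replace("\\n", "\n")  # un-escape stored newlines
    if role in ("user", "assistant") and content:
        conversations.append({"role": role, "content": content})
    i += 2

for turn in conversations:
    print(turn["role"], "->", turn["content"])
```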
@@ -788,4 +1042,60 @@ async def generate_rag_profile(
"tags_count": len(tags),
"personas_count": len(personas),
"insight_generated": bool(insight_sections.get("memory_insight")),
}
def get_dashboard_common_stats(db: Session, workspace_id) -> dict:
"""
获取 dashboard 中 neo4j/rag 分支共享的统计数据:
total_app、total_knowledge、total_api_call
Returns:
dict: {"total_app": int, "total_knowledge": int, "total_api_call": int}
"""
result = {"total_app": 0, "total_knowledge": 0, "total_api_call": 0}
# total_app: count all apps in the current workspace (own apps + apps shared into it)
try:
from app.services import app_service as _app_svc
_, total_app = _app_svc.AppService(db).list_apps(
workspace_id=workspace_id, include_shared=True, pagesize=1
)
result["total_app"] = total_app
except Exception as e:
business_logger.warning(f"获取应用数量失败: {e}")
# total_knowledge: count top-level knowledge bases (parent_id = workspace_id)
try:
from sqlalchemy import func as _func
from app.models.knowledge_model import Knowledge as _Knowledge
total_knowledge = db.query(_func.count(_Knowledge.id)).filter(
_Knowledge.workspace_id == workspace_id,
_Knowledge.status == 1,
_Knowledge.parent_id == _Knowledge.workspace_id
).scalar() or 0
result["total_knowledge"] = total_knowledge
except Exception as e:
business_logger.warning(f"获取知识库数量失败: {e}")
# total_api_call: cumulative historical call count up to now
try:
from sqlalchemy import func as _api_func
from app.models.api_key_model import ApiKey as _ApiKey, ApiKeyLog as _ApiKeyLog
_api_key_ids = [
row[0] for row in db.query(_ApiKey.id).filter(
_ApiKey.workspace_id == workspace_id
).all()
]
if _api_key_ids:
total_api_calls = db.query(_api_func.count(_ApiKeyLog.id)).filter(
_ApiKeyLog.api_key_id.in_(_api_key_ids)
).scalar() or 0
else:
total_api_calls = 0
result["total_api_call"] = total_api_calls
except Exception as e:
business_logger.warning(f"获取API调用统计失败: {e}")
return result
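A hypothetical caller combining the two dashboard helpers above; `db` and `workspace_id` are placeholders for values that come from the request context in practice:

```python
# Hypothetical wiring; both helpers are defined in this module.
stats = get_dashboard_common_stats(db, workspace_id)
today_data = {"total_memory": 1200, **stats}  # total_memory comes from the storage branch

changes = get_dashboard_yesterday_changes(
    db=db,
    workspace_id=workspace_id,
    storage_type="neo4j",
    today_data=today_data,
)
print(stats["total_app"], changes["total_app_change"])  # e.g. 12, 0.2 (+20%)
```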

View File

@@ -232,7 +232,8 @@ class MemoryPerceptualService:
provider=model_config.provider,
api_key=model_config.api_key,
base_url=model_config.api_base,
is_omni=model_config.is_omni
is_omni=model_config.is_omni,
support_thinking="thinking" in (model_config.capability or []),
)
)
return llm, model_config
@@ -243,28 +244,9 @@ class MemoryPerceptualService:
memory_config: MemoryConfig,
file: FileInput
):
memories = self.repository.get_by_url(file.url)
if memories:
business_logger.info(f"Perceptual memory already exists: {file.url}")
if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
file_name=memory_cache.file_name,
file_ext=memory_cache.file_ext,
summary=memory_cache.summary,
meta_data=memory_cache.meta_data
)
self.db.commit()
return memory
else:
for memory in memories:
if memory.end_user_id == uuid.UUID(end_user_id):
return memory
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
if model_config is None or llm is None:
return None
multimodel_service = MultimodalService(self.db, ModelInfo(
model_name=model_config.model_name,
provider=model_config.provider,
@@ -286,15 +268,20 @@ class MemoryPerceptualService:
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
except FileNotFoundError as e:
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
return None
messages = [
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
{"role": RoleType.USER.value, "content": [
{"type": "text", "text": "Summarize the following file"}, file_message
]}
]
result = await llm.ainvoke(messages)
try:
result = await llm.ainvoke(messages)
except Exception as e:
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
return None
content = result.content
final_output = ""
if isinstance(content, list):

View File

@@ -613,37 +613,6 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
return data
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL,
end_user_id=end_user_id,
)
# Check whether the result is empty or too short
if not result or len(result) < 4:
data = {
"total": 0,
"counts": {
"dialogue": 0,
"chunk": 0,
"statement": 0,
"entity": 0,
},
}
return data
data = {
"total": result[-1]["Count"],
"counts": {
"dialogue": result[0]["Count"],
"chunk": result[1]["Count"],
"statement": result[2]["Count"],
"entity": result[3]["Count"],
},
}
return data
async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, Any]:
"""统一知识库类型分布接口。
@@ -695,6 +664,37 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
return result
async def search_all_batch(end_user_ids: List[str]) -> Dict[str, int]:
"""批量查询多个用户的记忆数量简化版本只返回total
Args:
end_user_ids: 用户ID列表
Returns:
Dict[str, int]: 以user_id为key的记忆数量字典
格式: {"user_id": total_count}
"""
if not end_user_ids:
return {}
result = await _neo4j_connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=end_user_ids,
)
# Convert the result rows into a dict keyed by user_id
data = {}
for row in result:
data[row["user_id"]] = row["total"]
# Fill a default of 0 for users with no data
for user_id in end_user_ids:
if user_id not in data:
data[user_id] = 0
return data
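A hedged usage sketch for `search_all_batch`, assuming the SEARCH_FOR_ALL_BATCH Cypher returns rows with user_id and total keys, as the loop above expects:

```python
import asyncio

async def main():
    # IDs are illustrative; in the service they come from EndUser rows.
    counts = await search_all_batch(["user-a", "user-b", "user-c"])
    # Users absent from Neo4j were back-filled with 0, so every key exists.
    for uid, total in counts.items():
        print(uid, total)

asyncio.run(main())
```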
async def analytics_hot_memory_tags(
db: Session,
current_user: User,

View File

@@ -45,12 +45,20 @@ class ModelParameterMerger:
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"n": 1,
"stop": None
"stop": None,
"deep_thinking": False,
"thinking_budget_tokens": None
}
# Merge parameters: defaults -> model config -> agent config
merged = default_params.copy()
# Convert Pydantic objects to dicts
if model_config_params and hasattr(model_config_params, 'model_dump'):
model_config_params = model_config_params.model_dump()
if agent_config_params and hasattr(agent_config_params, 'model_dump'):
agent_config_params = agent_config_params.model_dump()
# Apply model-config parameters
if model_config_params:
for key in default_params:
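A minimal sketch of the three-layer merge this class performs, using plain dicts as stand-ins for the Pydantic objects. Only keys present in the defaults survive; the None-skipping shown here is an assumption, since the diff only shows the loop head:

```python
default_params = {"temperature": 0.7, "max_tokens": 2000, "deep_thinking": False}
model_config_params = {"temperature": 0.3, "unknown_key": "dropped"}
agent_config_params = {"deep_thinking": True}

merged = default_params.copy()
for layer in (model_config_params, agent_config_params):
    for key in default_params:  # keys outside the defaults are ignored
        if layer and key in layer and layer[key] is not None:
            merged[key] = layer[key]

print(merged)  # {'temperature': 0.3, 'max_tokens': 2000, 'deep_thinking': True}
```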

View File

@@ -69,7 +69,8 @@ class ModelConfigService:
return items
@staticmethod
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
def get_model_by_name(db: Session, name: str, provider: str | None = None,
tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""根据名称获取模型配置"""
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
if not model:
@@ -77,7 +78,8 @@ class ModelConfigService:
return model
@staticmethod
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
ModelConfig]:
"""按名称模糊匹配获取模型配置列表"""
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
@@ -91,7 +93,8 @@ class ModelConfigService:
api_base: Optional[str] = None,
model_type: str = "llm",
test_message: str = "Hello",
is_omni: bool = False
is_omni: bool = False,
capability: Optional[list] = None
) -> Dict[str, Any]:
"""验证模型配置是否有效
@@ -122,6 +125,7 @@ class ModelConfigService:
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
support_thinking="thinking" in (capability or []),
temperature=0.7,
max_tokens=100
)
@@ -158,13 +162,13 @@ class ModelConfigService:
# Always use RedBearEmbeddings (Volcano Engine multimodal is supported automatically)
embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"]
# Volcano Engine uses embed_batch; other providers use embed_documents
if provider.lower() == "volcano":
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
else:
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
elapsed_time = time.time() - start_time
return {
@@ -200,11 +204,11 @@ class ModelConfigService:
},
"error": None
}
elif model_type_lower == "image":
# Image-generation model validation
from app.core.models.generation import RedBearImageGenerator
generator = RedBearImageGenerator(model_config)
result = await generator.agenerate(
prompt="a cute panda",
@@ -212,7 +216,7 @@ class ModelConfigService:
)
elapsed_time = time.time() - start_time
logger.info(f"成功生成图片,结果: {result}")
return {
"valid": True,
"message": "图片生成模型配置验证成功",
@@ -224,21 +228,21 @@ class ModelConfigService:
},
"error": None
}
elif model_type_lower == "video":
# Video-generation model validation
from app.core.models.generation import RedBearVideoGenerator
generator = RedBearVideoGenerator(model_config)
result = await generator.agenerate(
prompt="a cute panda playing in bamboo forest",
duration=5
)
elapsed_time = time.time() - start_time
# Video generation is an async task that returns a task ID
task_id = result.get("task_id") if isinstance(result, dict) else None
return {
"valid": True,
"message": "视频生成模型配置验证成功",
@@ -265,7 +269,6 @@ class ModelConfigService:
# Extract detailed error information
error_message = str(e)
error_type = type(e).__name__
print("=========error_message:",error_message.lower())
# Special-case common error types
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
# Region/country restriction (applies to all providers)
@@ -319,7 +322,8 @@ class ModelConfigService:
api_base=api_key_data.api_base,
model_type=model_data.type,
test_message="Hello",
is_omni=model_data.is_omni
is_omni=model_data.is_omni,
capability=model_data.capability
)
if not validation_result["valid"]:
raise BusinessException(
@@ -354,14 +358,16 @@ class ModelConfigService:
return model
@staticmethod
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""更新模型配置"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
@@ -370,25 +376,27 @@ class ModelConfigService:
return model
@staticmethod
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
tenant_id: uuid.UUID) -> ModelConfig:
"""创建组合模型"""
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# Verify that every API Key exists and its type matches
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
# Check the model-config types associated with the API Key
for model_config in api_key.model_configs:
# chat and llm types are mutually compatible
compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type
request_type = model_data.type
if not (config_type == request_type or
if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)):
raise BusinessException(
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
@@ -399,7 +407,7 @@ class ModelConfigService:
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
# BizCode.INVALID_PARAMETER
# )
# Create the composite model
model_config_data = {
"tenant_id": tenant_id,
@@ -418,49 +426,51 @@ class ModelConfigService:
model = ModelConfigRepository.create(db, model_config_data)
db.flush()
# Associate the API Keys
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if api_key:
model.api_keys.append(api_key)
db.commit()
db.refresh(model)
return model
@staticmethod
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
tenant_id: uuid.UUID) -> ModelConfig:
"""更新组合模型"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
if not existing_model.is_composite:
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
# Verify that every API Key exists and its type matches
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
for model_config in api_key.model_configs:
compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type
request_type = existing_model.type
if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)):
raise BusinessException(
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
BizCode.INVALID_PARAMETER
)
# Update basic info
existing_model.name = model_data.name
# existing_model.type = model_data.type
@@ -471,14 +481,14 @@ class ModelConfigService:
existing_model.is_public = model_data.is_public
if "load_balance_strategy" in model_data.model_fields_set:
existing_model.load_balance_strategy = model_data.load_balance_strategy
# Update the API Key associations
existing_model.api_keys.clear()
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if api_key:
existing_model.api_keys.append(api_key)
db.commit()
db.refresh(existing_model)
return existing_model
@@ -532,7 +542,7 @@ class ModelApiKeyService:
"""根据provider为多个ModelConfig创建API Key"""
created_keys = []
failed_models = []  # track models that failed validation
for model_config_id in data.model_config_ids:
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config:
@@ -540,10 +550,10 @@ class ModelApiKeyService:
data.is_omni = model_config.is_omni
data.capability = model_config.capability
# Get model_name from ModelBase
model_name = model_config.model_base.name if model_config.model_base else model_config.name
# Check for an existing API Key (including soft-deleted ones); tenant_id must be taken into account
existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs
@@ -553,7 +563,7 @@ class ModelApiKeyService:
ModelApiKey.model_name == model_name,
ModelConfig.tenant_id == model_config.tenant_id
).first()
if existing_key:
# If it already exists, reactivate and update it
if existing_key.is_active:
@@ -566,14 +576,14 @@ class ModelApiKeyService:
existing_key.model_name = model_name
existing_key.capability = data.capability
existing_key.is_omni = data.is_omni
# Check whether this model config is already associated
if model_config not in existing_key.model_configs:
existing_key.model_configs.append(model_config)
created_keys.append(existing_key)
continue
# Validate the configuration
validation_result = await ModelConfigService.validate_model_config(
db=db,
@@ -583,13 +593,14 @@ class ModelApiKeyService:
api_base=data.api_base,
model_type=model_config.type,
test_message="Hello",
is_omni=data.is_omni
is_omni=data.is_omni,
capability=model_config.capability
)
if not validation_result["valid"]:
# Record models that failed validation, but do not raise
failed_models.append(model_name)
continue
# Create the API Key
api_key_data = ModelApiKeyCreate(
model_config_ids=[model_config_id],
@@ -606,12 +617,12 @@ class ModelApiKeyService:
)
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
created_keys.append(api_key_obj)
if created_keys:
db.commit()
for key in created_keys:
db.refresh(key)
return created_keys, failed_models
@staticmethod
@@ -626,7 +637,7 @@ class ModelApiKeyService:
api_key_data.is_omni = model_config.is_omni
if api_key_data.capability is None:
api_key_data.capability = model_config.capability
# Check whether the API Key already exists (including soft-deleted ones); tenant_id must be taken into account
existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs
@@ -650,15 +661,15 @@ class ModelApiKeyService:
existing_key.model_name = api_key_data.model_name
existing_key.capability = api_key_data.capability
existing_key.is_omni = api_key_data.is_omni
# Check whether this model config is already associated
if model_config not in existing_key.model_configs:
existing_key.model_configs.append(model_config)
db.commit()
db.refresh(existing_key)
return existing_key
# Validate the configuration
validation_result = await ModelConfigService.validate_model_config(
db=db,
@@ -668,7 +679,8 @@ class ModelApiKeyService:
api_base=api_key_data.api_base,
model_type=model_config.type,
test_message="Hello",
is_omni=api_key_data.is_omni
is_omni=api_key_data.is_omni,
capability=model_config.capability
)
if not validation_result["valid"]:
raise BusinessException(
@@ -691,7 +703,7 @@ class ModelApiKeyService:
# Fetch the associated model config to obtain the model type
if existing_api_key.model_configs:
model_config = existing_api_key.model_configs[0]
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=api_key_data.model_name or existing_api_key.model_name,
@@ -700,7 +712,8 @@ class ModelApiKeyService:
api_base=api_key_data.api_base or existing_api_key.api_base,
model_type=model_config.type,
test_message="Hello",
is_omni=model_config.is_omni
is_omni=model_config.is_omni,
capability=model_config.capability
)
if not validation_result["valid"]:
raise BusinessException(
@@ -729,15 +742,15 @@ class ModelApiKeyService:
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config:
return None
api_keys = [key for key in model_config.api_keys if key.is_active]
if not api_keys:
return None
# For round-robin, pick the least-used key; ties go to the earliest-used
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
# Otherwise return the first one
return api_keys[0]
@@ -760,20 +773,19 @@ class ModelApiKeyService:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
class ModelBaseService:
"""基础模型服务"""
@staticmethod
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
models = ModelBaseRepository.get_list(db, query)
provider_groups = {}
for m in models:
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
if tenant_id:
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
provider = m.provider
if provider not in provider_groups:
provider_groups[provider] = {
@@ -781,7 +793,7 @@ class ModelBaseService:
"models": []
}
provider_groups[provider]["models"].append(model_dict)
return list(provider_groups.values())
@staticmethod
@@ -823,10 +835,10 @@ class ModelBaseService:
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
if not model_base:
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
model_config_data = {
"model_id": model_base_id,
"tenant_id": tenant_id,

View File

@@ -287,6 +287,11 @@ class MultiAgentOrchestrator:
sub_conversation_id = None
total_tokens = 0
# Add the tokens consumed by the Master Agent's routing decision
total_tokens += task_analysis.get("routing_tokens", 0)
# Add the tokens consumed by the Master Agent's result merging
total_tokens += getattr(self, '_last_merge_tokens', 0)
if isinstance(results, dict):
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
# Extract token info
@@ -358,12 +363,16 @@ class MultiAgentOrchestrator:
variables=variables
)
# Get the tokens consumed by the routing decision
routing_tokens = getattr(self.router, '_last_routing_tokens', 0)
logger.info(
"Master Agent 分析完成",
extra={
"selected_agent": routing_decision.get("selected_agent_id"),
"confidence": routing_decision.get("confidence"),
"strategy": routing_decision.get("strategy")
"strategy": routing_decision.get("strategy"),
"routing_tokens": routing_tokens
}
)
@@ -372,7 +381,8 @@ class MultiAgentOrchestrator:
"variables": variables or {},
"sub_agents": self.config.sub_agents,
"initial_context": variables or {},
"routing_decision": routing_decision
"routing_decision": routing_decision,
"routing_tokens": routing_tokens
}
async def _execute_sequential(
@@ -1032,6 +1042,11 @@ class MultiAgentOrchestrator:
# 5. Stream the sub-agent execution
sub_conversation_id = None
# Tokens consumed by the Master Agent's routing decision, forwarded upstream via a sub_usage event
routing_tokens = task_analysis.get("routing_tokens", 0)
if routing_tokens > 0:
yield self._format_sse_event("sub_usage", {"total_tokens": routing_tokens})
async for event in self._execute_sub_agent_stream(
agent_data["config"],
message,
@@ -1054,6 +1069,7 @@ class MultiAgentOrchestrator:
except:
pass
# Pass all events straight through (including sub_usage; aggregation is handled uniformly upstream)
yield event
# 6. If there is a conversation ID, emit an event containing it
@@ -2600,6 +2616,7 @@ class MultiAgentOrchestrator:
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
support_thinking="thinking" in (api_key_config.capability or []),
temperature=0.7,  # medium temperature for the merge task
max_tokens=2000
)
@@ -2612,6 +2629,17 @@ class MultiAgentOrchestrator:
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
# Extract the tokens consumed by the merge
merge_tokens = 0
if hasattr(response, 'usage_metadata') and response.usage_metadata:
um = response.usage_metadata
merge_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
elif hasattr(response, 'response_metadata') and response.response_metadata:
token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {})
if isinstance(token_usage, dict):
merge_tokens = token_usage.get("total_tokens", 0)
self._last_merge_tokens = merge_tokens
# Extract the response content
if hasattr(response, 'content'):
merged_response = response.content
@@ -2621,7 +2649,8 @@ class MultiAgentOrchestrator:
logger.info(
"Master Agent 整合完成",
extra={
"merged_length": len(merged_response)
"merged_length": len(merged_response),
"merge_tokens": merge_tokens
}
)
@@ -2766,6 +2795,7 @@ class MultiAgentOrchestrator:
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
support_thinking="thinking" in (api_key_config.capability or []),
temperature=0.7,
max_tokens=2000,
extra_params={"streaming": True} # 启用流式输出

View File

@@ -441,13 +441,13 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL:
return True, {
"type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
"text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
}
else:
# Local file: extract its text content
server_url = settings.FILE_LOCAL_SERVER_URL
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
text = await self._extract_document_text(file)
text = await self.extract_document_text(file)
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id
).first()
@@ -545,7 +545,7 @@ class MultimodalService:
server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}"
async def _extract_document_text(self, file: FileInput) -> str:
async def extract_document_text(self, file: FileInput) -> str:
"""
Extract the document's text content

View File

@@ -185,7 +185,8 @@ class PromptOptimizerService:
provider=api_config.provider,
api_key=api_config.api_key,
base_url=api_config.api_base,
is_omni=api_config.is_omni
is_omni=api_config.is_omni,
support_thinking="thinking" in (api_config.capability or []),
), type=ModelType(model_config.type))
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')

View File

@@ -1,26 +1,24 @@
"""基于分享链接的聊天服务"""
import uuid
import time
import asyncio
import json
import time
import uuid
from typing import Optional, Dict, Any, AsyncGenerator
from deprecated import deprecated
from sqlalchemy.orm import Session
from app.repositories.model_repository import ModelApiKeyRepository
from app.services.memory_konwledges_server import write_rag
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.logging_config import get_business_logger
from app.models import MultiAgentConfig
from app.models import ReleaseShare, AppRelease, Conversation
from app.repositories import knowledge_repository
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService
from app.services.release_share_service import ReleaseShareService
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig
from app.repositories import knowledge_repository
import json
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.services.release_share_service import ReleaseShareService
logger = get_business_logger()
@@ -118,6 +116,7 @@ class SharedChatService:
return conversation
@deprecated("Use the chat method under app_chat_service instead.")
async def chat(
self,
share_token: str,
@@ -136,10 +135,7 @@ class SharedChatService:
config_id = actual_config_id
from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.model_parameter_merger import ModelParameterMerger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
start_time = time.time()
actual_config_id = None
@@ -252,7 +248,9 @@ class SharedChatService:
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
capability=api_key_obj.capability or [],
)
# Load message history
@@ -273,11 +271,6 @@ class SharedChatService:
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
)
# Save the messages
@@ -324,6 +317,7 @@ class SharedChatService:
"elapsed_time": elapsed_time
}
@deprecated("Use the chat method under app_chat_service instead.")
async def chat_stream(
self,
share_token: str,
@@ -341,8 +335,6 @@ class SharedChatService:
from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
import json
start_time = time.time()
@@ -460,7 +452,10 @@ class SharedChatService:
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True
streaming=True,
deep_thinking=model_parameters.get("deep_thinking", False),
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
capability=api_key_obj.capability or [],
)
# Load message history
@@ -486,14 +481,11 @@ class SharedChatService:
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
):
if isinstance(chunk, int):
total_tokens = chunk
elif isinstance(chunk, dict) and chunk.get("type") == "reasoning":
yield f"event: reasoning\ndata: {json.dumps({'content': chunk['content']}, ensure_ascii=False)}\n\n"
else:
full_content += chunk
# Emit a message-chunk event
@@ -585,6 +577,7 @@ class SharedChatService:
return conversations, total
@deprecated("Use the chat method under app_chat_service instead.")
async def multi_agent_chat(
self,
share_token: str,
@@ -680,6 +673,7 @@ class SharedChatService:
"elapsed_time": elapsed_time
}
@deprecated("Use the chat method under app_chat_service instead.")
async def multi_agent_chat_stream(
self,
share_token: str,

Some files were not shown because too many files have changed in this diff.