Compare commits

..

1 Commits

Author SHA1 Message Date
lanceyq
82c6d1a90f [feat] Context manager: Used to measure the execution time of code blocks 2026-03-31 14:56:26 +08:00
353 changed files with 5562 additions and 15252 deletions

View File

@@ -1,157 +0,0 @@
name: Release Notify Workflow
on:
pull_request:
types: [closed]
jobs:
notify:
if: >
github.event.pull_request.merged == true &&
startsWith(github.event.pull_request.base.ref, 'release')
runs-on: ubuntu-latest
steps:
# 防止 GitHub HEAD 未同步
- run: sleep 3
# 1⃣ 获取分支 HEAD
- name: Get HEAD
id: head
run: |
HEAD_SHA=$(curl -s \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
https://api.github.com/repos/${{ github.repository }}/git/ref/heads/${{ github.event.pull_request.base.ref }} \
| jq -r '.object.sha')
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
# 2⃣ 判断是否最终PR
- name: Check Latest
id: check
run: |
if [ "${{ github.event.pull_request.merge_commit_sha }}" = "${{ steps.head.outputs.head_sha }}" ]; then
echo "ok=true" >> $GITHUB_OUTPUT
else
echo "ok=false" >> $GITHUB_OUTPUT
fi
# 3⃣ 尝试从 PR body 提取 Sourcery 摘要
- name: Extract Sourcery Summary
if: steps.check.outputs.ok == 'true'
id: sourcery
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
python3 << 'PYEOF'
import os, re
body = os.environ.get("PR_BODY", "") or ""
match = re.search(
r"## Summary by Sourcery\s*\n(.*?)(?=\n## |\Z)",
body,
re.DOTALL
)
if match:
summary = match.group(1).strip()
found = "true"
else:
summary = ""
found = "false"
with open("sourcery_summary.txt", "w", encoding="utf-8") as f:
f.write(summary)
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
gh.write(f"found={found}\n")
gh.write("summary<<EOF\n")
gh.write(summary + "\n")
gh.write("EOF\n")
PYEOF
# 4⃣ Fallback: 获取 commits + 通义千问总结
- name: Get Commits
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
run: |
curl -s \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
${{ github.event.pull_request.commits_url }} \
| jq -r '.[].commit.message' | head -n 20 > commits.txt
- name: AI Summary (Qwen Fallback)
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
id: qwen
env:
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
run: |
python3 << 'PYEOF'
import json, os, urllib.request
with open("commits.txt", "r") as f:
commits = f.read().strip()
prompt = "请用中文总结以下代码提交输出3-5条要点面向测试人员。直接输出编号列表不要输出标题或前言\n" + commits
payload = {"model": "qwen-plus", "input": {"prompt": prompt}}
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
req = urllib.request.Request(
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
data=data,
headers={
"Authorization": "Bearer " + os.environ["DASHSCOPE_API_KEY"],
"Content-Type": "application/json"
}
)
resp = urllib.request.urlopen(req)
result = json.loads(resp.read().decode())
summary = result.get("output", {}).get("text", "AI 摘要生成失败")
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
gh.write("summary<<EOF\n")
gh.write(summary + "\n")
gh.write("EOF\n")
PYEOF
# 5⃣ 企业微信通知Markdown
- name: Notify WeChat
if: steps.check.outputs.ok == 'true'
env:
WECHAT_WEBHOOK: ${{ secrets.WECHAT_WEBHOOK }}
BRANCH: ${{ github.event.pull_request.base.ref }}
AUTHOR: ${{ github.event.pull_request.user.login }}
PR_TITLE: ${{ github.event.pull_request.title }}
PR_URL: ${{ github.event.pull_request.html_url }}
SOURCERY_FOUND: ${{ steps.sourcery.outputs.found }}
SOURCERY_SUMMARY: ${{ steps.sourcery.outputs.summary }}
QWEN_SUMMARY: ${{ steps.qwen.outputs.summary }}
run: |
python3 << 'PYEOF'
import json, os, urllib.request
if os.environ.get("SOURCERY_FOUND") == "true":
label = "Summary by Sourcery"
summary = os.environ.get("SOURCERY_SUMMARY", "")
else:
label = "AI变更摘要"
summary = os.environ.get("QWEN_SUMMARY", "AI 摘要生成失败")
content = (
"## 🚀 Release 发布通知\n"
"> 📦 **分支**: " + os.environ["BRANCH"] + "\n"
"> 👤 **提交人**: " + os.environ["AUTHOR"] + "\n"
"> 📝 **标题**: " + os.environ["PR_TITLE"] + "\n\n"
"### 🧠 " + label + "\n" +
summary + "\n\n"
"---\n"
"🔗 [查看PR详情](" + os.environ["PR_URL"] + ")"
)
payload = {"msgtype": "markdown", "markdown": {"content": content}}
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
req = urllib.request.Request(
os.environ["WECHAT_WEBHOOK"],
data=data,
headers={"Content-Type": "application/json"}
)
resp = urllib.request.urlopen(req)
print(resp.read().decode())
PYEOF

View File

@@ -1,36 +0,0 @@
name: Sync to Gitee
on:
push:
branches:
- main # Production
- develop # Integration
- 'release/*' # Release preparation
- 'hotfix/*' # Urgent fixes
tags:
- '*' # All version tags (v1.0.0, etc.)
jobs:
sync:
runs-on: ubuntu-latest
steps:
- name: Checkout Source Code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Sync to Gitee
run: |
GITEE_URL="https://${{ secrets.GITEE_USERNAME }}:${{ secrets.GITEE_TOKEN }}@gitee.com/hangzhou-hongxiong-intelligent_1/MemoryBear.git"
git remote add gitee "$GITEE_URL"
# 遍历并推送所有分支
for branch in $(git branch -r | grep -v HEAD | sed 's/origin\///'); do
echo "Syncing branch: $branch"
git push -f gitee "origin/$branch:refs/heads/$branch"
done
# 推送所有标签
echo "Syncing tags..."
git push gitee --tags --force

2
.gitignore vendored
View File

@@ -18,7 +18,6 @@ examples/
.kiro
.vscode
.idea
.claude
# Temporary outputs
.DS_Store
@@ -27,7 +26,6 @@ time.log
celerybeat-schedule.db
search_results.json
redbear-mem-metrics/
redbear-mem-benchmark/
pitch-deck/
api/migrations/versions

View File

@@ -2,10 +2,6 @@
# MemoryBear empowers AI with human-like memory capabilities
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE)
[![Python](https://img.shields.io/badge/Python-3.12+-green?logo=python&logoColor=white)](https://www.python.org/)
[![Gitee Sync](https://img.shields.io/github/actions/workflow/status/SuanmoSuanyangTechnology/MemoryBear/sync-to-gitee.yml?label=Gitee%20Sync&logo=gitee&logoColor=white)](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
[中文](./README_CN.md) | English
### [Installation Guide](#memorybear-installation-guide)

View File

@@ -2,10 +2,6 @@
# MemoryBear 让AI拥有如同人类一样的记忆
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE)
[![Python](https://img.shields.io/badge/Python-3.12+-green?logo=python&logoColor=white)](https://www.python.org/)
[![Gitee Sync](https://img.shields.io/github/actions/workflow/status/SuanmoSuanyangTechnology/MemoryBear/sync-to-gitee.yml?label=Gitee%20Sync&logo=gitee&logoColor=white)](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
中文 | [English](./README.md)
### [安装教程](#memorybear安装教程)

View File

@@ -111,17 +111,11 @@ celery_app.conf.update(
# Clustering tasks → memory_tasks queue (使用相同的 worker避免 macOS fork 问题)
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
# Metadata extraction → memory_tasks queue
'app.tasks.extract_user_metadata': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
# GraphRAG tasks → graphrag_tasks queue (独立队列,避免阻塞文档解析)
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'graphrag_tasks'},
'app.core.rag.tasks.build_graphrag_for_document': {'queue': 'graphrag_tasks'},
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},

View File

@@ -14,6 +14,7 @@ from . import (
document_controller,
emotion_config_controller,
emotion_controller,
end_user_controller,
file_controller,
file_storage_controller,
home_page_controller,
@@ -98,5 +99,6 @@ manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router)
manager_router.include_router(i18n_controller.router)
manager_router.include_router(end_user_controller.router)
__all__ = ["manager_router"]

View File

@@ -292,19 +292,10 @@ def get_opening(
):
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
workspace_id = current_user.current_workspace_id
# 根据应用类型获取 features
from app.models.app_model import App as AppModel
app = db.get(AppModel, app_id)
if app and app.type == "workflow":
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
features = cfg.features or {}
else:
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
features = cfg.features or {}
if hasattr(features, "model_dump"):
features = features.model_dump()
opening = features.get("opening_statement", {})
return success(data=app_schema.OpeningResponse(
enabled=opening.get("enabled", False),
@@ -1079,14 +1070,6 @@ async def update_workflow_config(
current_user: Annotated[User, Depends(get_current_user)]
):
workspace_id = current_user.current_workspace_id
if payload.variables:
from app.services.workflow_service import WorkflowService
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
[v.model_dump() for v in payload.variables]
)
# Patch default values back into VariableDefinition objects
for var_def, resolved_def in zip(payload.variables, resolved):
var_def.default = resolved_def.get("default", var_def.default)
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=WorkflowConfigSchema.model_validate(cfg))
@@ -1250,11 +1233,9 @@ async def export_app(
async def import_app(
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
app_id: Optional[str] = Form(None),
current_user: User = Depends(get_current_user)
):
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
传入 app_id 时覆盖该应用的配置(类型必须一致),否则创建新应用。
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
"""
if not file.filename.lower().endswith((".yaml", ".yml")):
@@ -1265,15 +1246,13 @@ async def import_app(
if not dsl or "app" not in dsl:
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
target_app_id = uuid.UUID(app_id) if app_id else None
result_app, warnings = AppDslService(db).import_dsl(
new_app, warnings = AppDslService(db).import_dsl(
dsl=dsl,
workspace_id=current_user.current_workspace_id,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
app_id=target_app_id,
)
return success(
data={"app": app_schema.App.model_validate(result_app), "warnings": warnings},
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
)

View File

@@ -53,12 +53,10 @@ async def login_for_access_token(
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
if form_data.invite:
auth_service.bind_workspace_with_invite(
db=db,
auth_service.bind_workspace_with_invite(db=db,
user=user,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
)
workspace_id=invite_info.workspace_id)
except BusinessException as e:
# 用户不存在且有邀请码,尝试注册
if e.code == BizCode.USER_NOT_FOUND:
@@ -136,7 +134,7 @@ async def refresh_token(
# 检查用户是否存在
user = auth_service.get_user_by_id(db, userId)
if not user:
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NO_ACCESS)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 检查 refresh token 黑名单
if settings.ENABLE_SINGLE_SESSION:

View File

@@ -23,7 +23,6 @@ from app.models.user_model import User
from app.schemas import chunk_schema
from app.schemas.response_schema import ApiResponse
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.services.model_service import ModelApiKeyService
# Obtain a dedicated API logger
api_logger = get_api_logger()
@@ -461,20 +460,18 @@ async def retrieve_chunks(
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
# Prepare to configure chat_mdl、embedding_model、vision_model information
chat_model = Base(
key=llm_key.api_key,
model_name=llm_key.model_name,
base_url=llm_key.api_base
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
embedding_model = OpenAIEmbed(
key=emb_key.api_key,
model_name=emb_key.model_name,
base_url=emb_key.api_base
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
rs.insert(0, doc)
return success(data=jsonable_encoder(rs), msg="retrieval successful")

View File

@@ -314,10 +314,8 @@ async def parse_documents(
)
# 4. Check if the file exists
api_logger.debug(f"Constructed file path: {file_path}")
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
if not os.path.exists(file_path):
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"

View File

@@ -0,0 +1,48 @@
"""End User 管理接口 - 无需认证"""
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
from app.schemas.memory_api_schema import (
CreateEndUserRequest,
CreateEndUserResponse,
)
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
router = APIRouter(prefix="/end_users", tags=["End Users"])
logger = get_business_logger()
@router.post("")
async def create_end_user(
data: CreateEndUserRequest,
db: Session = Depends(get_db),
):
"""
Create an end user.
Creates a new end user for the given workspace.
If an end user with the same other_id already exists in the workspace,
returns the existing one.
"""
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
end_user_repo = EndUserRepository(db)
end_user = end_user_repo.get_or_create_end_user(
app_id=None,
workspace_id=data.workspace_id,
other_id=data.other_id,
)
logger.info(f"End user ready: {end_user.id}")
result = {
"id": str(end_user.id),
"other_id": end_user.other_id or "",
"other_name": end_user.other_name or "",
"workspace_id": str(end_user.workspace_id),
}
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")

View File

@@ -3,10 +3,9 @@ from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.response_utils import success
from app.db import get_db, SessionLocal
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.repositories.home_page_repository import HomePageRepository
from app.schemas.response_schema import ApiResponse
from app.services.home_page_service import HomePageService
@@ -32,32 +31,9 @@ def get_workspace_list(
@router.get("/version", response_model=ApiResponse)
def get_system_version():
"""获取系统版本号 + 说明"""
current_version = None
version_info = None
# 1⃣ 优先从数据库获取最新已发布的版本
try:
db = SessionLocal()
try:
current_version, version_info = HomePageRepository.get_latest_version_introduction(db)
finally:
db.close()
except Exception as e:
pass
# 2⃣ 降级:使用环境变量中的版本号
if not current_version:
"""获取系统版本号+说明"""
current_version = settings.SYSTEM_VERSION
version_info = HomePageService.load_version_introduction(current_version)
# 3⃣ 如果数据库和 JSON 都没有,返回基本信息
if not version_info:
version_info = {
"introduction": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []},
"introduction_en": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []}
}
return success(
data={
"version": current_version,

View File

@@ -352,7 +352,6 @@ async def delete_knowledge(
# 2. Soft-delete knowledge base
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
db_knowledge.status = 2
db_knowledge.updated_at = datetime.datetime.now()
db.commit()
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
return success(msg="The knowledge base has been successfully deleted")

View File

@@ -1,5 +1,5 @@
import asyncio
import uuid
import time
from contextlib import contextmanager
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
@@ -18,6 +18,18 @@ from app.core.logging_config import get_api_logger
# 获取API专用日志器
api_logger = get_api_logger()
@contextmanager
def timer(label: str, user_count: int = 0):
"""上下文管理器:用于测量代码块执行时间"""
start = time.perf_counter()
try:
yield
finally:
elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒
extra_info = f", 用户数: {user_count}" if user_count > 0 else ""
api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}")
router = APIRouter(
prefix="/dashboard",
tags=["Dashboard"],
@@ -49,67 +61,76 @@ def get_workspace_total_end_users(
@router.get("/end_users", response_model=ApiResponse)
async def get_workspace_end_users(
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID可选默认当前用户工作空间"),
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id"),
page: int = Query(1, ge=1, description="页码从1开始"),
pagesize: int = Query(10, ge=1, description="每页数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取工作空间的宿主列表(分页查询,支持模糊搜索
获取工作空间的宿主列表(高性能优化版本 v2
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
优化策略:
1. 批量查询 end_users一次查询而非循环
2. 并发查询所有用户的记忆数量Neo4j
3. RAG 模式使用批量查询(一次 SQL
4. 只返回必要字段减少数据传输
5. 添加短期缓存减少重复查询
6. 并发执行配置查询和记忆数量查询
Args:
workspace_id: 工作空间ID可选默认当前用户工作空间
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id
page: 页码从1开始默认1
pagesize: 每页数量默认10
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含宿主列表和分页信息
返回格式:
{
"end_user": {"id": "uuid", "other_name": "名称"},
"memory_num": {"total": 数量},
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
}
"""
# 如果未提供 workspace_id使用当前用户的工作空间
if workspace_id is None:
import asyncio
import json
# from app.aioRedis import aio_redis_get, aio_redis_set
# 总耗时统计
total_start = time.perf_counter()
workspace_id = current_user.current_workspace_id
# # 尝试从缓存获取30秒缓存- 暂时注释以便进行性能测试
# with timer("Redis缓存读取"):
# cache_key = f"end_users:workspace:{workspace_id}"
# try:
# cached_data = await aio_redis_get(cache_key)
# if cached_data:
# api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
# return success(data=json.loads(cached_data), msg="宿主列表获取成功")
# except Exception as e:
# api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
# 获取当前空间类型
with timer("获取空间类型"):
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
# 获取分页的 end_users
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
# 获取 end_users(已优化为批量查询)
with timer("获取用户列表"):
end_users = memory_dashboard_service.get_workspace_end_users(
db=db,
workspace_id=workspace_id,
current_user=current_user,
page=page,
pagesize=pagesize,
keyword=keyword
current_user=current_user
)
end_users = end_users_result.get("items", [])
total = end_users_result.get("total", 0)
if not end_users:
api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}")
return success(data={
"items": [],
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"hasnext": (page * pagesize) < total
}
}, msg="宿主列表获取成功")
api_logger.info("工作空间下没有宿主")
# # 缓存空结果,避免重复查询 - 暂时注释
# try:
# await aio_redis_set(cache_key, json.dumps([]), expire=30)
# except Exception as e:
# api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
return success(data=[], msg="宿主列表获取成功")
end_user_ids = [str(user.id) for user in end_users]
user_count = len(end_user_ids)
api_logger.info(f"需要处理的用户数: {user_count}")
# 并发执行两个独立的查询任务
async def get_memory_configs():
"""获取记忆配置(在线程池中执行同步查询)"""
with timer("功能模块-获取记忆配置", user_count):
try:
return await asyncio.to_thread(
get_end_users_connected_configs_batch,
@@ -121,8 +142,10 @@ async def get_workspace_end_users(
async def get_memory_nums():
"""获取记忆数量"""
with timer(f"功能模块-获取记忆数量[{current_workspace_type}]", user_count):
if current_workspace_type == "rag":
# RAG 模式:批量查询
with timer(" - RAG批量查询chunks"):
try:
chunk_map = await asyncio.to_thread(
memory_dashboard_service.get_users_total_chunk_batch,
@@ -134,17 +157,31 @@ async def get_workspace_end_users(
return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j":
# Neo4j 模式:批量查询(简化版本只返回total
# Neo4j 模式:并发查询(带并发限制
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
MAX_CONCURRENT_QUERIES = 10
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
async def get_neo4j_memory_num(end_user_id: str):
async with semaphore:
single_start = time.perf_counter()
try:
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
return {uid: {"total": count} for uid, count in batch_result.items()}
result = await memory_storage_service.search_all(end_user_id)
elapsed = (time.perf_counter() - single_start) * 1000
api_logger.info(f" - Neo4j单用户查询[{end_user_id}]: {elapsed:.2f}ms")
return result
except Exception as e:
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
return {"total": 0}
with timer(" - Neo4j并发查询所有用户"):
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
return {uid: {"total": 0} for uid in end_user_ids}
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
with timer("触发Celery初始化任务"):
try:
from app.celery_app import celery_app as _celery_app
_celery_app.send_task(
@@ -160,17 +197,19 @@ async def get_workspace_end_users(
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
# 并发执行配置查询和记忆数量查询
with timer("并发执行两个功能模块"):
memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(),
get_memory_nums()
)
# 构建结果列表
items = []
# 构建结果(优化:使用列表推导式)
with timer("构建返回结果"):
result = []
for end_user in end_users:
user_id = str(end_user.id)
config_info = memory_configs_map.get(user_id, {})
items.append({
result.append({
'end_user': {
'id': user_id,
'other_name': end_user.other_name
@@ -182,6 +221,13 @@ async def get_workspace_end_users(
}
})
# # 写入缓存30秒过期- 暂时注释以便进行性能测试
# with timer("Redis缓存写入"):
# try:
# await aio_redis_set(cache_key, json.dumps(result), expire=30)
# except Exception as e:
# api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
# 触发社区聚类补全任务(异步,不阻塞接口响应)
try:
from app.tasks import init_community_clustering_for_users
@@ -190,18 +236,9 @@ async def get_workspace_end_users(
except Exception as e:
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
# 构建分页响应
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"hasnext": (page * pagesize) < total
}
}
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total}")
total_elapsed = (time.perf_counter() - total_start) * 1000
api_logger.info(f"[性能统计] 接口总耗时: {total_elapsed:.2f}ms, 用户数: {user_count}")
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功")
@@ -591,7 +628,7 @@ async def dashboard_data(
"total_api_call": None
}
# 1. 获取记忆总量total_memory—— neo4j 独有逻辑:查询 neo4j 存储节点
# 1. 获取记忆总量total_memory
try:
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
db=db,
@@ -600,32 +637,48 @@ async def dashboard_data(
end_user_id=end_user_id
)
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}")
# total_app: 统计当前空间下的所有app数量
# 包含自有app + 被分享给本工作空间的app
from app.services import app_service as _app_svc
_, total_app = _app_svc.AppService(db).list_apps(
workspace_id=workspace_id, include_shared=True, pagesize=1
)
neo4j_data["total_app"] = total_app
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
except Exception as e:
api_logger.warning(f"获取记忆总量失败: {str(e)}")
# 2. 获取共享统计数据total_app、total_knowledge、total_api_call
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
neo4j_data.update(common_stats)
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
# 计算昨日对比
# 2. 获取知识库类型统计total_knowledge
try:
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
db=db,
workspace_id=workspace_id,
storage_type=storage_type,
today_data=neo4j_data
from app.services.memory_agent_service import MemoryAgentService
memory_agent_service = MemoryAgentService()
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,
only_active=True,
current_workspace_id=workspace_id,
db=db
)
neo4j_data.update(changes)
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
except Exception as e:
api_logger.warning(f"计算neo4j昨日对比失败: {str(e)}")
neo4j_data.update({
"total_memory_change": None,
"total_app_change": None,
"total_knowledge_change": None,
"total_api_call_change": None,
})
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
# 3. 获取API调用统计total_api_call
try:
# 使用 AppStatisticsService 获取真实的API调用统计
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
# 计算总调用次数
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
neo4j_data["total_api_call"] = total_api_calls
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
except Exception as e:
api_logger.error(f"获取API调用统计失败: {str(e)}")
neo4j_data["total_api_call"] = 0
result["neo4j_data"] = neo4j_data
api_logger.info("成功获取neo4j_data")
@@ -639,36 +692,43 @@ async def dashboard_data(
"total_api_call": None
}
# 1. 获取记忆总量total_memory—— rag 独有逻辑:查询 document 表的 chunk_num
# 获取RAG相关数据
try:
# total_memory: 只统计用户知识库permission_id='Memory'的chunk数
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
api_logger.info(f"成功获取RAG记忆总量: {total_chunk}")
except Exception as e:
api_logger.warning(f"获取RAG记忆总量失败: {str(e)}")
# 2. 获取共享统计数据total_app、total_knowledge、total_api_call
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
rag_data.update(common_stats)
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
# 计算昨日对比
try:
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
db=db,
workspace_id=workspace_id,
storage_type=storage_type,
today_data=rag_data
# total_app: 统计当前空间下的所有app数量
# 包含自有app + 被分享给本工作空间的app
from app.services import app_service as _app_svc
_, total_app = _app_svc.AppService(db).list_apps(
workspace_id=workspace_id, include_shared=True, pagesize=1
)
rag_data.update(changes)
rag_data["total_app"] = total_app
# total_knowledge: 使用 total_kb总知识库数
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
rag_data["total_knowledge"] = total_kb
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
try:
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
# 计算总调用次数
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
rag_data["total_api_call"] = total_api_calls
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"计算RAG昨日对比失败: {str(e)}")
rag_data.update({
"total_memory_change": None,
"total_app_change": None,
"total_knowledge_change": None,
"total_api_call_change": None,
})
api_logger.warning(f"获取RAG模式API调用统计失败使用默认值: {str(e)}")
rag_data["total_api_call"] = 0
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
result["rag_data"] = rag_data
api_logger.info("成功获取rag_data")

View File

@@ -26,7 +26,7 @@ from app.services.memory_storage_service import (
analytics_hot_memory_tags,
analytics_recent_activity_stats,
kb_type_distribution,
search_all_batch,
search_all,
search_chunk,
search_detials,
search_dialogue,
@@ -409,10 +409,7 @@ async def search_all_num(
) -> dict:
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
try:
if not end_user_id:
return success(data={"total": 0}, msg="查询成功")
batch_result = await search_all_batch([end_user_id])
result = {"total": batch_result.get(end_user_id, 0)}
result = await search_all(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search all failed: {str(e)}")

View File

@@ -163,7 +163,6 @@ def _get_ontology_service(
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
support_thinking="thinking" in (api_key_config.capability or []),
max_retries=3,
timeout=60.0
)

View File

@@ -124,11 +124,10 @@ async def get_prompt_opt(
skill=data.skill
):
# chunk 是 prompt 的增量内容
yield f"event:message\ndata: {json.dumps(chunk, ensure_ascii=False)}\n\n"
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
except Exception as e:
yield f"event:error\ndata: {json.dumps(
{"error": str(e)},
ensure_ascii=False
{"error": str(e)}
)}\n\n"
yield "event:end\ndata: {}\n\n"

View File

@@ -453,10 +453,31 @@ async def chat(
# 流式返回
agent_config = agent_config_4_app_release(release)
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
agent_config.model_parameters["deep_thinking"] = False
if payload.stream:
# async def event_generator():
# async for event in service.chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
async def event_generator():
async for event in app_chat_service.agnet_chat_stream(
message=payload.message,
@@ -482,6 +503,20 @@ async def chat(
"X-Accel-Buffering": "no"
}
)
# 非流式返回
# result = await service.chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
result = await app_chat_service.agnet_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
@@ -540,6 +575,48 @@ async def chat(
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
# 多 Agent 流式返回
# if payload.stream:
# async def event_generator():
# async for event in service.multi_agent_chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
# # 多 Agent 非流式返回
# result = await service.multi_agent_chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW:
config = workflow_config_4_app_release(release)
if not config.id:
@@ -637,8 +714,7 @@ async def config_query(
"app_type": release.app.type,
"variables": release.config.get("variables"),
"memory": release.config.get("memory", {}).get("enabled"),
"features": release.config.get("features"),
"model_parameters": release.config.get("model_parameters")
"features": release.config.get("features")
}
elif release.app.type == AppType.MULTI_AGENT:
content = {

View File

@@ -4,7 +4,7 @@
认证方式: API Key
"""
from fastapi import APIRouter
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
# 创建 V1 API 路由器
service_router = APIRouter()
@@ -16,6 +16,5 @@ service_router.include_router(rag_api_document_controller.router)
service_router.include_router(rag_api_file_controller.router)
service_router.include_router(rag_api_chunk_controller.router)
service_router.include_router(memory_api_controller.router)
service_router.include_router(end_user_api_controller.router)
__all__ = ["service_router"]

View File

@@ -14,7 +14,6 @@ from app.core.response_utils import success
from app.db import get_db
from app.models.app_model import App
from app.models.app_model import AppType
from app.models.app_release_model import AppRelease
from app.repositories import knowledge_repository
from app.repositories.end_user_repository import EndUserRepository
from app.schemas import AppChatRequest, conversation_schema
@@ -62,18 +61,18 @@ async def list_apps():
# return success(data={"received": True}, msg="消息已接收")
def _checkAppConfig(release: AppRelease):
if release.type == AppType.AGENT:
if not release.config:
def _checkAppConfig(app: App):
if app.type == AppType.AGENT:
if not app.current_release.config:
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif release.type == AppType.MULTI_AGENT:
if not release.config:
elif app.type == AppType.MULTI_AGENT:
if not app.current_release.config:
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif release.type == AppType.WORKFLOW:
if not release.config:
elif app.type == AppType.WORKFLOW:
if not app.current_release.config:
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
else:
raise BusinessException("不支持的应用类型", BizCode.APP_TYPE_NOT_SUPPORTED)
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
@router.post("/chat")
@@ -87,22 +86,10 @@ async def chat(
app_service: Annotated[AppService, Depends(get_app_service)] = None,
message: str = Body(..., description="聊天消息内容"),
):
"""
Agent/Workflow 聊天接口
- 不传 version使用当前生效版本current_release回滚后为回滚目标版本
- 传 version=release_id使用指定版本uuid的历史快照例如 {"version": "{{release_id}}"}
"""
body = await request.json()
payload = AppChatRequest(**body)
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
# 版本切换:指定 release_id 时查找对应历史快照,否则使用当前激活版本
if payload.version is not None:
active_release = app_service.get_release_by_id(app.id, payload.version)
else:
active_release = app.current_release
other_id = payload.user_id
workspace_id = api_key_auth.workspace_id
end_user_repo = EndUserRepository(db)
@@ -140,7 +127,7 @@ async def chat(
storage_type = 'neo4j'
app_type = app.type
# check app config
_checkAppConfig(active_release)
_checkAppConfig(app)
# 获取或创建会话(提前验证)
conversation = conversation_service.create_or_get_conversation(
@@ -155,13 +142,8 @@ async def chat(
# print("="*50)
# print(app.current_release.default_model_config_id)
agent_config = agent_config_4_app_release(active_release)
agent_config = agent_config_4_app_release(app.current_release)
# print(agent_config.default_model_config_id)
# thinking 开关:仅当 agent 配置了 deep_thinking 且请求 thinking=True 时才启用
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
agent_config.model_parameters["deep_thinking"] = False
# 流式返回
if payload.stream:
async def event_generator():
@@ -207,7 +189,7 @@ async def chat(
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
# 多 Agent 流式返回
config = multi_agent_config_4_app_release(active_release)
config = multi_agent_config_4_app_release(app.current_release)
if payload.stream:
async def event_generator():
async for event in app_chat_service.multi_agent_chat_stream(
@@ -250,7 +232,7 @@ async def chat(
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.WORKFLOW:
# 多 Agent 流式返回
config = workflow_config_4_app_release(active_release)
config = workflow_config_4_app_release(app.current_release)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
@@ -266,7 +248,7 @@ async def chat(
user_rag_memory_id=user_rag_memory_id,
app_id=app.id,
workspace_id=workspace_id,
release_id=active_release.id,
release_id=app.current_release.id,
public=True
):
event_type = event.get("event", "message")
@@ -301,7 +283,7 @@ async def chat(
files=payload.files,
app_id=app.id,
workspace_id=workspace_id,
release_id=active_release.id
release_id=app.current_release.id
)
logger.debug(
"工作流试运行返回结果",
@@ -315,4 +297,6 @@ async def chat(
msg="工作流任务执行成功"
)
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)

View File

@@ -1,92 +0,0 @@
"""End User 服务接口 - 基于 API Key 认证"""
import uuid
from fastapi import APIRouter, Body, Depends, Request
from sqlalchemy.orm import Session
from app.core.api_key_auth import require_api_key
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
from app.services.memory_config_service import MemoryConfigService
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
logger = get_business_logger()
@router.post("/create")
@require_api_key(scopes=["memory"])
async def create_end_user(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(..., description="Request body"),
):
"""
Create or retrieve an end user for the workspace.
Creates a new end user and connects it to a memory configuration.
If an end user with the same other_id already exists in the workspace,
returns the existing one.
Optionally accepts a memory_config_id to connect the end user to a specific
memory configuration. If not provided, falls back to the workspace default config.
"""
body = await request.json()
payload = CreateEndUserRequest(**body)
workspace_id = api_key_auth.workspace_id
logger.info("Create end user request - other_id: %s, workspace_id: %s", payload.other_id, workspace_id)
# Resolve memory_config_id: explicit > workspace default
memory_config_id = None
config_service = MemoryConfigService(db)
if payload.memory_config_id:
try:
memory_config_id = uuid.UUID(payload.memory_config_id)
except ValueError:
raise BusinessException(
f"Invalid memory_config_id format: {payload.memory_config_id}",
BizCode.INVALID_PARAMETER
)
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
if not config:
raise BusinessException(
f"Memory config not found: {payload.memory_config_id}",
BizCode.MEMORY_CONFIG_NOT_FOUND
)
memory_config_id = config.config_id
else:
default_config = config_service.get_workspace_default_config(workspace_id)
if default_config:
memory_config_id = default_config.config_id
logger.info(f"Using workspace default memory config: {memory_config_id}")
else:
logger.warning(f"No default memory config found for workspace: {workspace_id}")
end_user_repo = EndUserRepository(db)
end_user = end_user_repo.get_or_create_end_user_with_config(
app_id=api_key_auth.resource_id,
workspace_id=workspace_id,
other_id=payload.other_id,
memory_config_id=memory_config_id,
)
logger.info(f"End user ready: {end_user.id}")
result = {
"id": str(end_user.id),
"other_id": end_user.other_id or "",
"other_name": end_user.other_name or "",
"workspace_id": str(end_user.workspace_id),
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
}
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")

View File

@@ -6,8 +6,6 @@ from app.core.response_utils import success
from app.db import get_db
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_api_schema import (
CreateEndUserRequest,
CreateEndUserResponse,
ListConfigsResponse,
MemoryReadRequest,
MemoryReadResponse,
@@ -115,31 +113,3 @@ async def list_memory_configs(
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
@router.post("/end_users")
@require_api_key(scopes=["memory"])
async def create_end_user(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Create an end user.
Creates a new end user for the authorized workspace.
If an end user with the same other_id already exists, returns the existing one.
"""
body = await request.json()
payload = CreateEndUserRequest(**body)
logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)
result = memory_api_service.create_end_user(
workspace_id=api_key_auth.workspace_id,
other_id=payload.other_id,
)
logger.info(f"End user ready: {result['id']}")
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")

View File

@@ -11,14 +11,17 @@ LangChain Agent 封装
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType, ModelProvider
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langgraph.errors import GraphRecursionError
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
logger = get_business_logger()
@@ -38,10 +41,7 @@ class LangChainAgent:
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数None 表示自动计算)
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
deep_thinking: bool = False, # 是否启用深度思考模式
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
):
"""初始化 LangChain Agent
@@ -64,7 +64,6 @@ class LangChainAgent:
self.streaming = streaming
self.is_omni = is_omni
self.max_tool_consecutive_calls = max_tool_consecutive_calls
self.deep_thinking = deep_thinking and ("thinking" in (capability or []))
# 工具调用计数器:记录每个工具的连续调用次数
self.tool_call_counter: Dict[str, int] = {}
@@ -87,13 +86,6 @@ class LangChainAgent:
f"auto_calculated={max_iterations is None}"
)
# 根据 capability 校验是否真正支持深度思考
actual_deep_thinking = self.deep_thinking
if deep_thinking and not actual_deep_thinking:
logger.warning(
f"模型 {model_name} 不支持深度思考capability 中无 'thinking'),已自动关闭 deep_thinking"
)
# 创建 RedBearLLM支持多提供商
model_config = RedBearModelConfig(
model_name=model_name,
@@ -101,13 +93,10 @@ class LangChainAgent:
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
deep_thinking=actual_deep_thinking,
thinking_budget_tokens=thinking_budget_tokens if actual_deep_thinking else None,
support_thinking="thinking" in (capability or []),
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
"streaming": streaming
"streaming": streaming # 使用参数控制流式
}
)
@@ -237,9 +226,10 @@ class LangChainAgent:
Returns:
List[BaseMessage]: 消息列表
"""
messages:list = [SystemMessage(content=self.system_prompt)]
messages = []
# 添加系统提示词
messages.append(SystemMessage(content=self.system_prompt))
# 添加历史消息
if history:
@@ -264,33 +254,6 @@ class LangChainAgent:
return messages
@staticmethod
def _extract_tokens_from_message(msg) -> int:
"""从 AIMessage 或类似对象中提取 total_tokens兼容多种 provider 格式
支持的格式:
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
- response_metadata.usage.total_tokens (部分 provider)
- usage_metadata.total_tokens (LangChain 新版)
"""
total = 0
# 1. response_metadata
response_meta = getattr(msg, "response_metadata", None)
if response_meta and isinstance(response_meta, dict):
# 尝试 token_usage 路径
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
if isinstance(token_usage, dict):
total = token_usage.get("total_tokens", 0)
# 2. usage_metadataLangChain 新版 AIMessage 属性)
if not total:
usage_meta = getattr(msg, "usage_metadata", None)
if usage_meta:
if isinstance(usage_meta, dict):
total = usage_meta.get("total_tokens", 0)
else:
total = getattr(usage_meta, "total_tokens", 0)
return total or 0
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
构建多模态消息内容
@@ -325,23 +288,17 @@ class LangChainAgent:
return content_parts
@staticmethod
def _extract_reasoning_content(msg) -> str:
"""从 AIMessage 中提取深度思考内容reasoning_content
所有 provider 统一通过 additional_kwargs.reasoning_content 传递:
- DeepSeek-R1 / QwQ: 原生字段
- Volcano (Doubao-thinking): 由 VolcanoChatOpenAI 从 delta.reasoning_content 注入
"""
additional = getattr(msg, "additional_kwargs", None) or {}
return additional.get("reasoning_content") or additional.get("reasoning", "")
async def chat(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
end_user_id: Optional[str] = None,
config_id: Optional[str] = None, # 添加这个参数
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> Dict[str, Any]:
"""执行对话
@@ -349,12 +306,31 @@ class LangChainAgent:
message: 用户消息
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
context: 上下文信息(如知识库检索结果)
files: 多模态文件
Returns:
Dict: 包含 content 和元数据的字典
"""
message_chat = message
start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
@@ -378,7 +354,7 @@ class LangChainAgent:
{"messages": messages},
config={"recursion_limit": self.max_iterations}
)
except (RecursionError, GraphRecursionError) as e:
except RecursionError as e:
logger.warning(
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
extra={"error": str(e)}
@@ -401,7 +377,6 @@ class LangChainAgent:
logger.debug(f"输出消息数量: {len(output_messages)}")
total_tokens = 0
reasoning_content = ""
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
logger.debug(f"找到 AI 消息content 类型: {type(msg.content)}")
@@ -436,13 +411,16 @@ class LangChainAgent:
else:
content = str(msg.content)
logger.debug(f"转换为字符串: {content[:100]}...")
total_tokens = self._extract_tokens_from_message(msg)
reasoning_content = self._extract_reasoning_content(msg) if self.deep_thinking else ""
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
break
logger.info(f"最终提取的内容长度: {len(content)}")
elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
response = {
"content": content,
"model": self.model_name,
@@ -453,8 +431,6 @@ class LangChainAgent:
"total_tokens": total_tokens
}
}
if reasoning_content:
response["reasoning_content"] = reasoning_content
logger.debug(
"Agent 调用完成",
@@ -475,20 +451,22 @@ class LangChainAgent:
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
) -> AsyncGenerator[str | int | dict[str, str], None]:
end_user_id: Optional[str] = None,
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> AsyncGenerator[str, None]:
"""执行流式对话
Args:
message: 用户消息
history: 历史消息列表
context: 上下文信息
files: 多模态文件
Yields:
str: 消息内容块
int: token 统计
Dict: 深度思考内容 {"type": "reasoning", "content": "..."}
"""
logger.info("=" * 80)
logger.info(" chat_stream 方法开始执行")
@@ -496,6 +474,23 @@ class LangChainAgent:
logger.info(f" Has tools: {bool(self.tools)}")
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80)
message_chat = message
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
@@ -505,19 +500,17 @@ class LangChainAgent:
)
chunk_count = 0
yielded_content = False
# 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出")
full_content = ''
full_reasoning = ''
try:
last_event = {}
async for event in self.agent.astream_events(
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
):
last_event = event
chunk_count += 1
kind = event.get("event")
@@ -526,18 +519,12 @@ class LangChainAgent:
# LLM 流式输出
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"):
# 提取深度思考内容(仅在启用深度思考时)
if self.deep_thinking:
reasoning_chunk = self._extract_reasoning_content(chunk)
if reasoning_chunk:
full_reasoning += reasoning_chunk
yield {"type": "reasoning", "content": reasoning_chunk}
# 处理多模态响应content 可能是字符串或列表
chunk_content = chunk.content
if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content
yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分
for item in chunk_content:
@@ -548,32 +535,29 @@ class LangChainAgent:
if text:
full_content += text
yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
elif isinstance(item, str):
full_content += item
yield item
yielded_content = True
elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
chunk = event.get("data", {}).get("chunk")
if chunk:
if hasattr(chunk, "content"):
# 提取深度思考内容(仅在启用深度思考时)
if self.deep_thinking:
reasoning_chunk = self._extract_reasoning_content(chunk)
if reasoning_chunk:
full_reasoning += reasoning_chunk
yield {"type": "reasoning", "content": reasoning_chunk}
chunk_content = chunk.content
if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content
yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分
for item in chunk_content:
@@ -584,18 +568,22 @@ class LangChainAgent:
if text:
full_content += text
yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
elif isinstance(item, str):
full_content += item
yield item
yielded_content = True
elif isinstance(chunk, str):
full_content += chunk
yield chunk
yielded_content = True
# 记录工具调用(可选)
elif kind == "on_tool_start":
@@ -605,20 +593,19 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
stream_total_tokens = self._extract_tokens_from_message(msg)
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
yield stream_total_tokens
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get(
"total_tokens",
0
) if response_meta else 0
yield total_tokens
break
except GraphRecursionError:
logger.warning(
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),模型可能不支持正确的工具调用停止判断"
)
if not full_content:
yield "抱歉,我在处理您的请求时遇到了问题(已达最大处理步骤限制)。请尝试简化问题或更换模型后重试。"
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise

View File

@@ -19,7 +19,6 @@ class BizCode(IntEnum):
TENANT_NOT_FOUND = 3002
WORKSPACE_NO_ACCESS = 3003
WORKSPACE_INVITE_NOT_FOUND = 3004
WORKSPACE_ACCESS_DENIED = 3005
# API Key 管理3xxx
API_KEY_NOT_FOUND = 3007
API_KEY_DUPLICATE_NAME = 3008
@@ -41,7 +40,6 @@ class BizCode(IntEnum):
FILE_NOT_FOUND = 4006
APP_NOT_FOUND = 4007
RELEASE_NOT_FOUND = 4008
USER_NO_ACCESS = 4009
# 冲突/状态5xxx
DUPLICATE_NAME = 5001
@@ -115,11 +113,8 @@ HTTP_MAPPING = {
BizCode.FORBIDDEN: 403,
BizCode.TENANT_NOT_FOUND: 400,
BizCode.WORKSPACE_NO_ACCESS: 403,
BizCode.WORKSPACE_INVITE_NOT_FOUND: 400,
BizCode.WORKSPACE_ACCESS_DENIED: 403,
BizCode.NOT_FOUND: 400,
BizCode.USER_NOT_FOUND: 200,
BizCode.USER_NO_ACCESS: 401,
BizCode.WORKSPACE_NOT_FOUND: 400,
BizCode.MODEL_NOT_FOUND: 400,
BizCode.KNOWLEDGE_NOT_FOUND: 400,

View File

@@ -1,408 +0,0 @@
"""
Perceptual Memory Retrieval Node & Service
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
with BM25+embedding fusion reranking.
Also provides the perceptual_retrieve_node for use as a LangGraph node.
"""
import asyncio
import math
from typing import List, Dict, Any, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.repositories.neo4j.graph_search import (
search_perceptual,
search_perceptual_by_embedding,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = get_agent_logger(__name__)
class PerceptualSearchService:
"""
感知记忆检索服务。
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
调用方只需提供 query / keywords、end_user_id、memory_config即可获得
格式化并排序后的感知记忆列表和拼接文本。
Usage:
service = PerceptualSearchService(end_user_id=..., memory_config=...)
results = await service.search(query="...", keywords=[...], limit=10)
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
"""
DEFAULT_ALPHA = 0.6
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
def __init__(
self,
end_user_id: str,
memory_config: Any,
alpha: float = DEFAULT_ALPHA,
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
):
self.end_user_id = end_user_id
self.memory_config = memory_config
self.alpha = alpha
self.content_score_threshold = content_score_threshold
async def search(
self,
query: str,
keywords: Optional[List[str]] = None,
limit: int = 10,
) -> Dict[str, Any]:
"""
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
Args:
query: 原始用户查询(用于向量检索和 BM25 补查)
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
limit: 最大返回数量
Returns:
{
"memories": [格式化后的记忆 dict, ...],
"content": "拼接的纯文本摘要",
"keyword_raw": int,
"embedding_raw": int,
}
"""
if keywords is None:
keywords = [query] if query else []
connector = Neo4jConnector()
try:
kw_task = self._keyword_search(connector, keywords, limit)
emb_task = self._embedding_search(connector, query, limit)
kw_results, emb_results = await asyncio.gather(
kw_task, emb_task, return_exceptions=True
)
if isinstance(kw_results, Exception):
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
kw_results = []
if isinstance(emb_results, Exception):
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
emb_results = []
# 补查 BM25找出 embedding 命中但 keyword 未命中的 id
# 用原始 query 对这些节点补查全文索引拿 BM25 score
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
if emb_only_ids and query:
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
# 把补查到的 BM25 score 注入到 embedding 结果中
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
for r in emb_results:
rid = r.get("id", "")
if rid in backfill_map:
r["bm25_backfill_score"] = backfill_map[rid]
logger.info(
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
f"{len(backfill_map)} got BM25 scores"
)
reranked = self._rerank(kw_results, emb_results, limit)
memories = []
content_parts = []
for record in reranked:
fmt = self._format_result(record)
fmt["score"] = round(record.get("content_score", 0), 4)
memories.append(fmt)
content_parts.append(self._build_content_text(fmt))
logger.info(
f"[PerceptualSearch] {len(memories)} results after rerank "
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
)
return {
"memories": memories,
"content": "\n\n".join(content_parts),
"keyword_raw": len(kw_results),
"embedding_raw": len(emb_results),
}
finally:
await connector.close()
async def _bm25_backfill(
self,
connector: Neo4jConnector,
query: str,
target_ids: set,
limit: int,
) -> List[dict]:
"""
对指定 id 集合补查全文索引 BM25 score。
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
"""
escaped = escape_lucene_query(query)
if not escaped.strip():
return []
try:
r = await search_perceptual(
connector=connector, query=escaped,
end_user_id=self.end_user_id,
limit=limit * 5, # 多查一些以提高命中率
)
all_hits = r.get("perceptuals", [])
return [h for h in all_hits if h.get("id") in target_ids]
except Exception as e:
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
return []
async def _keyword_search(
self,
connector: Neo4jConnector,
keywords: List[str],
limit: int,
) -> List[dict]:
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
seen_ids: set = set()
all_results: List[dict] = []
async def _one(kw: str):
escaped = escape_lucene_query(kw)
if not escaped.strip():
return []
r = await search_perceptual(
connector=connector, query=escaped,
end_user_id=self.end_user_id, limit=limit,
)
return r.get("perceptuals", [])
tasks = [_one(kw) for kw in keywords[:10]]
batch = await asyncio.gather(*tasks, return_exceptions=True)
for result in batch:
if isinstance(result, Exception):
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
continue
for rec in result:
rid = rec.get("id", "")
if rid and rid not in seen_ids:
seen_ids.add(rid)
all_results.append(rec)
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
return all_results[:limit]
async def _embedding_search(
self,
connector: Neo4jConnector,
query_text: str,
limit: int,
) -> List[dict]:
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
try:
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
with get_db_context() as db:
cfg = MemoryConfigService(db).get_embedder_config(
str(self.memory_config.embedding_model_id)
)
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
r = await search_perceptual_by_embedding(
connector=connector, embedder_client=client,
query_text=query_text, end_user_id=self.end_user_id,
limit=limit,
)
return r.get("perceptuals", [])
except Exception as e:
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
return []
def _rerank(
self,
keyword_results: List[dict],
embedding_results: List[dict],
limit: int,
) -> List[dict]:
"""BM25 + embedding 融合排序。
对 embedding 结果中带有 bm25_backfill_score 的条目,
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
"""
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
emb_backfill_items = []
for item in embedding_results:
backfill_score = item.get("bm25_backfill_score")
if backfill_score is not None and item.get("id"):
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
# 合并后统一归一化 BM25 scores
all_bm25_items = keyword_results + emb_backfill_items
all_bm25_items = self._normalize_scores(all_bm25_items)
# 建立 id -> normalized BM25 score 的映射
bm25_norm_map: Dict[str, float] = {}
for item in all_bm25_items:
item_id = item.get("id", "")
if item_id:
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
# 归一化 embedding scores
embedding_results = self._normalize_scores(embedding_results)
# 合并
combined: Dict[str, dict] = {}
for item in keyword_results:
item_id = item.get("id", "")
if not item_id:
continue
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = 0.0
for item in embedding_results:
item_id = item.get("id", "")
if not item_id:
continue
if item_id in combined:
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
for item in combined.values():
bm25 = float(item.get("bm25_score", 0) or 0)
emb = float(item.get("embedding_score", 0) or 0)
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
results = list(combined.values())
before = len(results)
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
results.sort(key=lambda x: x["content_score"], reverse=True)
results = results[:limit]
logger.info(
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
)
return results
@staticmethod
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
"""Z-score + sigmoid 归一化。"""
if not items:
return items
scores = [float(it.get(field, 0) or 0) for it in items]
if len(scores) <= 1:
for it in items:
it[f"normalized_{field}"] = 1.0
return items
mean = sum(scores) / len(scores)
var = sum((s - mean) ** 2 for s in scores) / len(scores)
std = math.sqrt(var)
if std == 0:
for it in items:
it[f"normalized_{field}"] = 1.0
else:
for it, s in zip(items, scores):
z = (s - mean) / std
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
return items
@staticmethod
def _format_result(record: dict) -> dict:
return {
"id": record.get("id", ""),
"perceptual_type": record.get("perceptual_type", ""),
"file_name": record.get("file_name", ""),
"file_path": record.get("file_path", ""),
"summary": record.get("summary", ""),
"topic": record.get("topic", ""),
"domain": record.get("domain", ""),
"keywords": record.get("keywords", []),
"created_at": str(record.get("created_at", "")),
"file_type": record.get("file_type", ""),
"score": record.get("score", 0),
}
@staticmethod
def _build_content_text(formatted: dict) -> str:
parts = []
if formatted["summary"]:
parts.append(formatted["summary"])
if formatted["topic"]:
parts.append(f"[主题: {formatted['topic']}]")
if formatted["keywords"]:
kw_list = formatted["keywords"]
if isinstance(kw_list, list):
parts.append(f"[关键词: {', '.join(kw_list)}]")
if formatted["file_name"]:
parts.append(f"[文件: {formatted['file_name']}]")
return " ".join(parts)
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
"""Extract search keywords from problem extension results."""
keywords = []
context = problem_extension.get("context", {})
if isinstance(context, dict):
for original_q, extended_qs in context.items():
keywords.append(original_q)
if isinstance(extended_qs, list):
keywords.extend(extended_qs)
return keywords
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
"""
LangGraph node: perceptual memory retrieval.
Uses PerceptualSearchService to run keyword + embedding search with
BM25 fusion reranking, then writes results to state['perceptual_data'].
"""
end_user_id = state.get("end_user_id", "")
problem_extension = state.get("problem_extension", {})
original_query = state.get("data", "")
memory_config = state.get("memory_config", None)
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
keywords = _extract_keywords_from_problems(problem_extension)
if not keywords:
keywords = [original_query] if original_query else []
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
search_result = await service.search(
query=original_query,
keywords=keywords,
limit=10,
)
result = {
"memories": search_result["memories"],
"content": search_result["content"],
"_intermediate": {
"type": "perceptual_retrieve",
"title": "感知记忆检索",
"data": search_result["memories"],
"query": original_query,
"result_count": len(search_result["memories"]),
},
}
return {"perceptual_data": result}

View File

@@ -263,6 +263,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
print(time.time() - start)
result = {
"context": aggregated_dict,
"original": data,

View File

@@ -1,11 +1,7 @@
import asyncio
import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
PerceptualSearchService,
)
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
@@ -343,45 +339,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
try:
if storage_type != "rag":
async def _perceptual_search():
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
return await service.search(query=data, limit=5)
hybrid_task = SearchService().execute_hybrid_search(
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
**search_params,
memory_config=memory_config,
expand_communities=False,
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
)
perceptual_task = _perceptual_search()
gather_results = await asyncio.gather(
hybrid_task, perceptual_task, return_exceptions=True
)
hybrid_result = gather_results[0]
perceptual_results = gather_results[1]
# 处理 hybrid search 异常
if isinstance(hybrid_result, Exception):
raise hybrid_result
retrieve_info, question, raw_results = hybrid_result
# 处理感知记忆结果
if isinstance(perceptual_results, Exception):
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
perceptual_results = []
# 拼接感知记忆内容到 retrieve_info
if perceptual_results and isinstance(perceptual_results, dict):
perceptual_content = perceptual_results.get("content", "")
if perceptual_content:
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
count = len(perceptual_results.get("memories", []))
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
# 调试:打印 community 检索结果数量
if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {})
@@ -409,7 +371,10 @@ async def Input_Summary(state: ReadState) -> ReadState:
"error": str(e)
}
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)
return {"summary": summary}
@@ -447,20 +412,8 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str)
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
aimessages = await summary_llm(
state,
history,
retrieve_info_str,
'direct_summary_prompt.jinja2',
'retrieve_summary', RetrieveSummaryResponse,
"1"
)
aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
@@ -505,12 +458,6 @@ async def Summary(state: ReadState) -> ReadState:
retrieve_info_str += i + '\n'
history = await summary_history(state)
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = {
"query": query,
"history": history,
@@ -561,13 +508,6 @@ async def Summary_fails(state: ReadState) -> ReadState:
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = {
"query": query,
"history": history,

View File

@@ -15,10 +15,7 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
Problem_Extension,
)
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve_nodes,
)
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
perceptual_retrieve_node,
retrieve,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary,
@@ -56,9 +53,8 @@ async def make_read_graph():
workflow.add_node("Split_The_Problem", Split_The_Problem)
workflow.add_node("Problem_Extension", Problem_Extension)
workflow.add_node("Input_Summary", Input_Summary)
workflow.add_node("Retrieve", retrieve_nodes)
# workflow.add_node("Retrieve", retrieve)
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
# workflow.add_node("Retrieve", retrieve_nodes)
workflow.add_node("Retrieve", retrieve)
workflow.add_node("Verify", Verify)
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary)
@@ -69,15 +65,14 @@ async def make_read_graph():
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension")
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
workflow.add_edge("Problem_Extension", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
'''-----'''
# workflow.add_edge("Retrieve", END)
# Compile workflow
@@ -85,5 +80,7 @@ async def make_read_graph():
yield graph
except Exception as e:
logger.error(f"创建工作流失败: {e}")
print(f"创建工作流失败: {e}")
raise
finally:
print("工作流创建完成")

View File

@@ -12,6 +12,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
@@ -20,6 +21,25 @@ logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
"""
Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
ai_message: AI's response message content
user_rag_memory_id: RAG memory identifier for storage location
"""
# RAG mode: combine messages into string format (maintain original logic)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(
storage_type,
end_user_id,
@@ -98,7 +118,7 @@ async def write(
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def term_memory_save(end_user_id, strategy_type, scope):
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
"""
Save long-term memory data to database
@@ -107,8 +127,10 @@ async def term_memory_save(end_user_id, strategy_type, scope):
to long-term memory storage.
Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
end_user_id: User identifier for memory association
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
scope: Scope/window size for memory processing
"""
with get_db_context() as db_session:
@@ -116,10 +138,7 @@ async def term_memory_save(end_user_id, strategy_type, scope):
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if not result:
logger.warning(f"No write data found for user {end_user_id}")
return
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data) == scope:
@@ -132,6 +151,9 @@ async def term_memory_save(end_user_id, strategy_type, scope):
logger.info(f'写入短长期:')
"""Window-based dialogue processing"""
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
"""
Process dialogue based on window size and write to Neo4j
@@ -145,33 +167,40 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage
"""
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
if is_end_user_has_history:
end_user_visit_count, redis_messages = is_end_user_has_history
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
return
end_user_visit_count += 1
if end_user_visit_count < scope:
redis_messages.extend(langchain_messages)
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
else:
scope = scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J')
redis_messages.extend(langchain_messages)
formatted_messages = redis_messages
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
write_message_task.delay(
end_user_id, # end_user_id: User ID
redis_messages, # message: JSON string format message list
config_id, # config_id: Configuration ID string
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
await write(
AgentMemory_Long_Term.STORAGE_NEO4J,
end_user_id,
"",
"",
None,
end_user_id,
config_id,
formatted_messages
)
count_store.update_sessions_count(end_user_id, 0, [])
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""Time-based memory processing"""
async def memory_long_term_storage(end_user_id, memory_config, time):
@@ -262,7 +291,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
return result_dict
except Exception as e:
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
return {
"is_same_event": False,

View File

@@ -1,25 +1,49 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
from app.db import get_db_context
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import write_rag
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
async def long_term_storage(
long_term_type: str,
langchain_messages: list,
memory_config_id: str,
end_user_id: str,
scope: int = 6
):
@asynccontextmanager
async def make_write_graph():
"""
Create a write graph workflow for memory operations.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()
yield graph
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
"""
Handle long-term memory storage with different strategies
@@ -29,39 +53,33 @@ async def long_term_storage(
Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store
memory_config_id: Memory configuration identifier
memory_config: Memory configuration identifier
end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6)
"""
if langchain_messages is None:
langchain_messages = []
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, langchain_messages)
# 获取数据库会话
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config_id, # 改为整数
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
# Dialogue window with 6 rounds of conversation
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
# Time-based strategy
"""Time-based strategy"""
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
# Aggregate judgment
"""Strategy 3: Aggregate judgment"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(
storage_type: str,
end_user_id: str,
messages: list[dict],
user_rag_memory_id: str,
actual_config_id: str
):
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
"""
Write long-term memory with different storage types
@@ -71,24 +89,44 @@ async def write_long_term(
Args:
storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier
messages: message list
message_chat: User message content
aimessages: AI response messages
user_rag_memory_id: RAG memory identifier
actual_config_id: Actual configuration ID
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
message_content = []
for message in messages:
message_content.append(f'{message.get("role")}:{message.get("content")}')
messages_string = "\n".join(message_content)
await write_rag(end_user_id, messages_string, user_rag_memory_id)
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
await long_term_storage(long_term_type=CHUNK,
langchain_messages=messages,
memory_config_id=actual_config_id,
end_user_id=end_user_id,
scope=SCOPE)
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "好耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -10,6 +10,7 @@ from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__)
# 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化)
@@ -75,8 +76,7 @@ async def expand_communities_to_statements(
if s.get("statement") and s["statement"] not in existing_lines
]
cleaned = _clean_expand_fields(expanded_stmts)
logger.info(
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
return cleaned, new_texts
@@ -180,7 +180,7 @@ class SearchService:
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config=None,
memory_config = None,
expand_communities: bool = True,
) -> Tuple[str, str, Optional[dict]]:
"""
@@ -269,6 +269,7 @@ class SearchService:
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines
clean_content = '\n'.join([c for c in content_list if c])

View File

@@ -1,3 +1,4 @@
import os
from collections import defaultdict
from pathlib import Path
from typing import Annotated, TypedDict
@@ -51,7 +52,6 @@ class ReadState(TypedDict):
embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象
retrieve: dict
perceptual_data: dict
RetrieveSummary: dict
InputSummary: dict
verify: dict

View File

@@ -3,7 +3,6 @@ import uuid
from app.core.config import settings
from typing import List, Dict, Any, Optional, Union
from app.core.logging_config import get_logger
from app.core.memory.agent.utils.redis_base import (
serialize_messages,
deserialize_messages,
@@ -15,7 +14,7 @@ from app.core.memory.agent.utils.redis_base import (
get_current_timestamp
)
logger = get_logger(__name__)
class RedisWriteStore:
@@ -67,10 +66,10 @@ class RedisWriteStore:
})
result = pipe.execute()
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
logger.error(f"[save_session_write] 保存会话失败: {e}")
print(f"[save_session_write] 保存会话失败: {e}")
raise e
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
@@ -113,10 +112,10 @@ class RedisWriteStore:
if not results:
return False
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
return results
except Exception as e:
logger.error(f"[get_session_by_userid] 查询失败: {e}")
print(f"[get_session_by_userid] 查询失败: {e}")
return False
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
@@ -145,7 +144,7 @@ class RedisWriteStore:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
return False
# 批量获取数据
@@ -176,16 +175,18 @@ class RedisWriteStore:
results.append(session_info)
if not results:
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
return False
# 按时间排序(最新的在前)
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
return results
except Exception as e:
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
import traceback
traceback.print_exc()
return False
def find_user_recent_sessions(self, userid: str,
@@ -206,7 +207,7 @@ class RedisWriteStore:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
@@ -233,10 +234,11 @@ class RedisWriteStore:
# 根据时间范围过滤
filtered_items = filter_by_time_range(matched_items, minutes)
# 排序并移除时间字段
result_items = sort_and_limit_results(filtered_items)
result_items = sort_and_limit_results(filtered_items, limit=None)
print(result_items)
elapsed_time = time.time() - start_time
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
@@ -276,7 +278,7 @@ class RedisCountStore:
decode_responses=True,
encoding='utf-8'
)
self.uuid = session_id
self.uudi = session_id
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
"""
@@ -296,7 +298,7 @@ class RedisCountStore:
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uuid,
"id": self.uudi,
"end_user_id": end_user_id,
"count": int(count),
"messages": serialize_messages(messages),
@@ -309,10 +311,10 @@ class RedisCountStore:
result = pipe.execute()
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
return session_id
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
"""
通过 end_user_id 查询访问次数统计
@@ -333,7 +335,7 @@ class RedisCountStore:
self.r.delete(index_key)
return False
except Exception as type_error:
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
@@ -353,20 +355,15 @@ class RedisCountStore:
messages_str = data.get('messages')
if count is not None:
messages: list[dict] = deserialize_messages(messages_str)
return int(count), messages
messages = deserialize_messages(messages_str)
return [int(count), messages]
return False
except Exception as e:
logger.error(f"[get_sessions_count] 查询失败: {e}")
print(f"[get_sessions_count] 查询失败: {e}")
return False
def update_sessions_count(
self,
end_user_id: str,
new_count: int,
messages: Any
) -> bool:
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool:
"""
通过 end_user_id 修改访问次数统计(优化版:使用索引)
@@ -387,17 +384,17 @@ class RedisCountStore:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
# 索引键类型错误,删除并返回 False
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key)
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as type_error:
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
# 直接更新数据
@@ -405,15 +402,15 @@ class RedisCountStore:
messages_str = serialize_messages(messages)
pipe = self.r.pipeline()
pipe.hset(key, 'count', str(new_count))
pipe.hset(key, 'count', int(new_count))
pipe.hset(key, 'messages', messages_str)
result = pipe.execute()
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
except Exception as e:
logger.debug(f"[update_sessions_count] 更新失败: {e}")
print(f"[update_sessions_count] 更新失败: {e}")
return False
def delete_all_count_sessions(self) -> int:
@@ -486,10 +483,10 @@ class RedisSessionStore:
})
result = pipe.execute()
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
logger.error(f"[save_session] 保存会话失败: {e}")
print(f"[save_session] 保存会话失败: {e}")
raise e
# ==================== 读取操作 ====================
@@ -541,7 +538,7 @@ class RedisSessionStore:
keys = self.r.keys('session:*')
if not keys:
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
@@ -568,7 +565,7 @@ class RedisSessionStore:
result_items = sort_and_limit_results(matched_items, limit=6)
elapsed_time = time.time() - start_time
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
@@ -635,7 +632,7 @@ class RedisSessionStore:
keys = self.r.keys('session:*')
if not keys:
logger.debug("[delete_duplicate_sessions] 没有会话数据")
print("[delete_duplicate_sessions] 没有会话数据")
return 0
# 批量获取所有数据
@@ -681,7 +678,7 @@ class RedisSessionStore:
deleted_count += len(batch)
elapsed_time = time.time() - start_time
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
return deleted_count

View File

@@ -14,7 +14,6 @@ from dotenv import load_dotenv
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
memory_summary_generation
@@ -153,24 +152,6 @@ async def write(
# Step 3: Save all data to Neo4j database
step_start = time.time()
# Neo4j 写入前:清洗用户/AI助手实体之间的别名交叉污染
# 从 Neo4j 查询已有的 AI 助手别名,与本轮实体中的 AI 助手别名合并,
# 确保用户实体的 aliases 不包含 AI 助手的名字
try:
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
clean_cross_role_aliases,
fetch_neo4j_assistant_aliases,
)
neo4j_assistant_aliases = set()
if all_entity_nodes:
_eu_id = all_entity_nodes[0].end_user_id
if _eu_id:
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, _eu_id)
clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases)
logger.info(f"Neo4j 写入前别名清洗完成AI助手别名排除集大小: {len(neo4j_assistant_aliases)}")
except Exception as e:
logger.warning(f"Neo4j 写入前别名清洗失败(不影响主流程): {e}")
# 添加死锁重试机制
max_retries = 3
retry_delay = 1 # 秒
@@ -192,37 +173,15 @@ async def write(
if success:
logger.info("Successfully saved all data to Neo4j")
if all_entity_nodes:
end_user_id = all_entity_nodes[0].end_user_id
# Neo4j 写入完成后,用 PgSQL 权威 aliases 覆盖 Neo4j 用户实体
try:
from app.repositories.end_user_info_repository import EndUserInfoRepository
if end_user_id:
with get_db_context() as db_session:
info = EndUserInfoRepository(db_session).get_by_end_user_id(uuid.UUID(end_user_id))
pg_aliases = info.aliases if info and info.aliases else []
if info is not None:
# 将 Python 侧占位名集合作为参数传入,避免 Cypher 硬编码
placeholder_names = list(_USER_PLACEHOLDER_NAMES)
await neo4j_connector.execute_query(
"""
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $placeholder_names
SET e.aliases = $aliases
""",
end_user_id=end_user_id, aliases=pg_aliases,
placeholder_names=placeholder_names,
)
logger.info(f"[AliasSync] Neo4j 用户实体 aliases 已用 PgSQL 权威源覆盖: {pg_aliases}")
except Exception as sync_err:
logger.warning(f"[AliasSync] PgSQL→Neo4j aliases 同步失败(不影响主流程): {sync_err}")
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
if all_entity_nodes:
try:
from app.tasks import run_incremental_clustering
end_user_id = all_entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in all_entity_nodes]
# 异步提交 Celery 任务
task = run_incremental_clustering.apply_async(
kwargs={
"end_user_id": end_user_id,
@@ -230,6 +189,7 @@ async def write(
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
},
# 设置任务优先级(低优先级,不影响主业务)
priority=3,
)
logger.info(
@@ -237,6 +197,7 @@ async def write(
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
)
except Exception as e:
# 聚类任务提交失败不影响主流程
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
break

View File

@@ -56,7 +56,7 @@ class LLMClient(ABC):
self.max_retries = self.config.max_retries
self.timeout = self.config.timeout
logger.debug(
logger.info(
f"初始化 LLM 客户端: provider={self.provider}, "
f"model={self.model_name}, max_retries={self.max_retries}"
)

View File

@@ -58,14 +58,6 @@ from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
)
# User metadata models
from app.core.memory.models.metadata_models import (
UserMetadata,
UserMetadataProfile,
MetadataExtractionResponse,
MetadataFieldChange,
)
# Ontology scenario models (LLM extracted from scenarios)
from app.core.memory.models.ontology_scenario_models import (
OntologyClass,
@@ -132,10 +124,6 @@ __all__ = [
"Entity",
"Triplet",
"TripletExtractionResponse",
"UserMetadata",
"UserMetadataProfile",
"MetadataExtractionResponse",
"MetadataFieldChange",
# Ontology models
"OntologyClass",
"OntologyExtractionResponse",

View File

@@ -364,14 +364,12 @@ class ChunkNode(Node):
Attributes:
dialog_id: ID of the parent dialog
content: The text content of the chunk
speaker: Speaker identifier ('user' or 'assistant')
chunk_embedding: Optional embedding vector for the chunk
sequence_number: Order of this chunk within the dialog
metadata: Additional chunk metadata as key-value pairs
"""
dialog_id: str = Field(..., description="ID of the parent dialog")
content: str = Field(..., description="The text content of the chunk")
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
metadata: dict = Field(default_factory=dict, description="Additional chunk metadata")

View File

@@ -1,63 +0,0 @@
"""Models for user metadata extraction.
Independent from triplet_models.py - these models are used by the
standalone metadata extraction pipeline (post-dedup async Celery task).
"""
from typing import List, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field
class UserMetadataProfile(BaseModel):
"""用户画像信息"""
model_config = ConfigDict(extra="ignore")
role: List[str] = Field(default_factory=list, description="用户职业或角色")
domain: List[str] = Field(default_factory=list, description="用户所在领域")
expertise: List[str] = Field(
default_factory=list, description="用户擅长的技能或工具"
)
interests: List[str] = Field(
default_factory=list, description="用户关注的话题或领域标签"
)
class UserMetadata(BaseModel):
"""用户元数据顶层结构"""
model_config = ConfigDict(extra="ignore")
profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile)
class MetadataFieldChange(BaseModel):
"""单个元数据字段的变更操作"""
model_config = ConfigDict(extra="ignore")
field_path: str = Field(
description="字段路径,用点号分隔,如 'profile.role''profile.expertise'"
)
action: Literal["set", "remove"] = Field(
description="操作类型:'set' 表示新增或修改,'remove' 表示移除"
)
value: Optional[str] = Field(
default=None,
description="字段的新值action='set' 时必填)。标量字段直接填值,列表字段填单个要新增的元素"
)
class MetadataExtractionResponse(BaseModel):
"""元数据提取 LLM 响应结构(增量模式)"""
model_config = ConfigDict(extra="ignore")
metadata_changes: List[MetadataFieldChange] = Field(
default_factory=list,
description="元数据的增量变更列表,每项描述一个字段的新增、修改或移除操作",
)
aliases_to_add: List[str] = Field(
default_factory=list,
description="本次新发现的用户别名(用户自我介绍或他人对用户的称呼)",
)
aliases_to_remove: List[str] = Field(
default_factory=list, description="用户明确否认的别名(如'我不叫XX了'"
)

View File

@@ -1,3 +1,4 @@
import argparse
import asyncio
import json
import math
@@ -5,6 +6,7 @@ import os
import time
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
@@ -21,7 +23,7 @@ from app.core.memory.utils.config.config_utils import (
)
from app.core.memory.utils.data.text_utils import extract_plain_query
from app.core.memory.utils.data.time_utils import normalize_date_safe
# from app.core.memory.utils.llm.llm_utils import get_reranker_client
from app.core.memory.utils.llm.llm_utils import get_reranker_client
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import (
@@ -41,7 +43,6 @@ load_dotenv()
logger = get_memory_logger(__name__)
def _parse_datetime(value: Any) -> Optional[datetime]:
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
if value is None:
@@ -131,6 +132,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
return results
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Remove duplicate items from search results based on content.
@@ -194,7 +196,6 @@ def rerank_with_activation(
forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8,
now: datetime | None = None,
content_score_threshold: float = 0.5,
) -> Dict[str, List[Dict[str, Any]]]:
"""
两阶段排序:先按内容相关性筛选,再按激活值排序。
@@ -221,8 +222,6 @@ def rerank_with_activation(
forgetting_config: 遗忘引擎配置(当前未使用)
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
now: 当前时间(用于遗忘计算)
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score
低于此阈值的结果会被过滤。默认 0.5。
返回:
带评分元数据的重排序结果,按 final_score 排序
@@ -392,19 +391,7 @@ def rerank_with_activation(
# 无激活值:使用内容相关性分数
item["final_score"] = item.get("base_score", 0)
if content_score_threshold > 0:
before_count = len(sorted_items)
sorted_items = [
item for item in sorted_items
if float(item.get("content_score", 0) or 0) >= content_score_threshold
]
filtered_count = before_count - len(sorted_items)
if filtered_count > 0:
logger.info(
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
f"items below content_score_threshold={content_score_threshold}"
)
# 最终去重确保没有重复项
sorted_items = _deduplicate_results(sorted_items)
reranked[category] = sorted_items
@@ -412,8 +399,7 @@ def rerank_with_activation(
return reranked
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
log_file: str = None):
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
"""Log search query information using the logger.
Args:
@@ -746,10 +732,11 @@ async def run_hybrid_search(
if search_type in ["keyword", "hybrid"]:
# Keyword-based search
logger.info("[PERF] Starting keyword search...")
keyword_start = time.time()
keyword_task = asyncio.create_task(
search_graph(
connector=connector,
query=query_text,
q=query_text,
end_user_id=end_user_id,
limit=limit,
include=include
@@ -759,6 +746,7 @@ async def run_hybrid_search(
if search_type in ["embedding", "hybrid"]:
# Embedding-based search
logger.info("[PERF] Starting embedding search...")
embedding_start = time.time()
# 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig
config_load_start = time.time()
@@ -770,7 +758,8 @@ async def run_hybrid_search(
model_name=embedder_config_dict["model_name"],
provider=embedder_config_dict["provider"],
api_key=embedder_config_dict["api_key"],
base_url=embedder_config_dict["base_url"]
base_url=embedder_config_dict["base_url"],
type="llm"
)
config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
@@ -800,7 +789,7 @@ async def run_hybrid_search(
if keyword_task:
keyword_results = await keyword_task
keyword_latency = time.time() - search_start_time
keyword_latency = time.time() - keyword_start
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword":
@@ -810,7 +799,7 @@ async def run_hybrid_search(
if embedding_task:
embedding_results = await embedding_task
embedding_latency = time.time() - search_start_time
embedding_latency = time.time() - embedding_start
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding":
@@ -822,8 +811,7 @@ async def run_hybrid_search(
if search_type == "hybrid":
results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"search_query": query_text,
"search_timestamp": datetime.now().isoformat()
}
@@ -879,8 +867,7 @@ async def run_hybrid_search(
results["reranked_results"] = reranked_results
results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
"search_query": query_text,
"search_timestamp": datetime.now().isoformat(),
@@ -900,10 +887,10 @@ async def run_hybrid_search(
else:
results["latency_metrics"] = latency_metrics
logger.info("[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
logger.info("[PERF] =========================================")
logger.info(f"[PERF] =========================================")
# Sanitize results: drop large/unused fields
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs
@@ -922,10 +909,8 @@ async def run_hybrid_search(
# Log search completion with result count
if search_type == "hybrid":
result_counts = {
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
embedding_results.items()}
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
}
else:
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
@@ -1042,3 +1027,4 @@ async def search_chunk_by_chunk_id(
limit=limit
)
return {"chunks": chunks}

View File

@@ -4,7 +4,6 @@
import asyncio
import difflib # 提供字符串相似度计算工具
import importlib
import logging
import os
import re
from datetime import datetime
@@ -17,8 +16,6 @@ from app.core.memory.models.graph_models import (
)
from app.core.memory.models.variate_config import DedupConfig
logger = logging.getLogger(__name__)
# 模块级类型统一工具函数
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
@@ -82,37 +79,50 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
canonical.connect_strength = next(iter(pair))
# 别名合并(去重保序,使用标准化工具)
# 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,去重合并时不修改
try:
canonical_name = (getattr(canonical, "name", "") or "").strip()
if canonical_name.lower() not in _USER_PLACEHOLDER_NAMES:
incoming_name = (getattr(ent, "name", "") or "").strip()
# 收集所有需要合并的别名,过滤掉用户占位名避免污染非用户实体
all_aliases = list(getattr(canonical, "aliases", []) or [])
if incoming_name and incoming_name != canonical_name and incoming_name.lower() not in _USER_PLACEHOLDER_NAMES:
all_aliases.append(incoming_name)
all_aliases.extend(
a for a in (getattr(ent, "aliases", []) or [])
if a and a.strip().lower() not in _USER_PLACEHOLDER_NAMES
)
# 收集所有需要合并的别名
all_aliases = []
# 1. 添加canonical现有的别名
existing = getattr(canonical, "aliases", []) or []
all_aliases.extend(existing)
# 2. 添加incoming实体的名称如果不同于canonical的名称
if incoming_name and incoming_name != canonical_name:
all_aliases.append(incoming_name)
# 3. 添加incoming实体的所有别名
incoming = getattr(ent, "aliases", []) or []
all_aliases.extend(incoming)
# 4. 标准化并去重优先使用alias_utils工具函数
try:
from app.core.memory.utils.alias_utils import normalize_aliases
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
except Exception:
# 如果导入失败,使用增强的去重逻辑
seen_normalized = set()
unique_aliases = []
for alias in all_aliases:
if not alias:
continue
alias_stripped = str(alias).strip()
if not alias_stripped or alias_stripped == canonical_name:
continue
# 标准化:转小写用于去重判断
alias_normalized = alias_stripped.lower()
if alias_normalized not in seen_normalized:
seen_normalized.add(alias_normalized)
unique_aliases.append(alias_stripped)
# 排序并赋值
canonical.aliases = sorted(unique_aliases)
except Exception:
pass
@@ -188,161 +198,6 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
except Exception:
pass
# 用户和AI助手的占位名称集合用于名称标准化
_USER_PLACEHOLDER_NAMES = {"用户", "", "user", "i"}
_ASSISTANT_PLACEHOLDER_NAMES = {"ai助手", "助手", "人工智能助手", "智能助手", "智能体", "ai assistant", "assistant"}
# 标准化后的规范名称和类型
_CANONICAL_USER_NAME = "用户"
_CANONICAL_USER_TYPE = "用户"
_CANONICAL_ASSISTANT_NAME = "AI助手"
_CANONICAL_ASSISTANT_TYPE = "Agent"
# 用户和AI助手的所有可能名称用于判断实体是否为特殊角色实体
_ALL_USER_NAMES = _USER_PLACEHOLDER_NAMES
_ALL_ASSISTANT_NAMES = _ASSISTANT_PLACEHOLDER_NAMES
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
"""判断实体是否为用户实体name 或 entity_type 匹配)"""
name = (getattr(ent, "name", "") or "").strip().lower()
etype = (getattr(ent, "entity_type", "") or "").strip()
return name in _ALL_USER_NAMES or etype == _CANONICAL_USER_TYPE
def _is_assistant_entity(ent: ExtractedEntityNode) -> bool:
"""判断实体是否为AI助手实体name 或 entity_type 匹配)"""
name = (getattr(ent, "name", "") or "").strip().lower()
etype = (getattr(ent, "entity_type", "") or "").strip()
return name in _ALL_ASSISTANT_NAMES or etype == _CANONICAL_ASSISTANT_TYPE
def _would_merge_cross_role(a: ExtractedEntityNode, b: ExtractedEntityNode) -> bool:
"""判断两个实体的合并是否会跨越用户/AI助手角色边界。
用户实体和AI助手实体永远不应该被合并在一起。
如果一方是用户实体、另一方是AI助手实体返回 True阻止合并
"""
return (
(_is_user_entity(a) and _is_assistant_entity(b))
or (_is_assistant_entity(a) and _is_user_entity(b))
)
def _normalize_special_entity_names(
entity_nodes: List[ExtractedEntityNode],
) -> None:
"""标准化用户和AI助手实体的名称和类型。
多轮对话中LLM 对同一角色可能使用不同的名称变体(如"用户"/""/"User"
"AI助手"/"助手"/"Assistant"),导致精确匹配无法合并。
此函数在去重前将这些变体统一为规范名称,并强制绑定 entity_type确保
- name="用户" 的实体 entity_type 一定为 "用户"
- name="AI助手" 的实体 entity_type 一定为 "Agent"
Args:
entity_nodes: 实体节点列表(原地修改)
"""
for ent in entity_nodes:
name = (getattr(ent, "name", "") or "").strip()
name_lower = name.lower()
if name_lower in _USER_PLACEHOLDER_NAMES:
ent.name = _CANONICAL_USER_NAME
ent.entity_type = _CANONICAL_USER_TYPE
elif name_lower in _ASSISTANT_PLACEHOLDER_NAMES:
ent.name = _CANONICAL_ASSISTANT_NAME
ent.entity_type = _CANONICAL_ASSISTANT_TYPE
# 第二步:清洗用户/AI助手之间的别名交叉污染复用 clean_cross_role_aliases
clean_cross_role_aliases(entity_nodes)
async def fetch_neo4j_assistant_aliases(neo4j_connector, end_user_id: str) -> set:
"""从 Neo4j 查询 AI 助手实体的所有别名(小写归一化)。
这是助手别名查询的唯一入口,供 write_tools 和 extraction_orchestrator 共用,
避免多处维护相同的 Cypher 和名称列表。
Args:
neo4j_connector: Neo4j 连接器实例(需提供 execute_query 方法)
end_user_id: 终端用户 ID
Returns:
小写归一化后的助手别名集合
"""
# 查询名称列表:规范名称 + 常见变体(与 _normalize_special_entity_names 标准化后一致)
query_names = [_CANONICAL_ASSISTANT_NAME, *_ASSISTANT_PLACEHOLDER_NAMES]
# 去重保序
query_names = list(dict.fromkeys(query_names))
cypher = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND e.name IN $names
RETURN e.aliases AS aliases
"""
try:
result = await neo4j_connector.execute_query(
cypher, end_user_id=end_user_id, names=query_names
)
assistant_aliases: set = set()
for record in (result or []):
for alias in (record.get("aliases") or []):
assistant_aliases.add(alias.strip().lower())
if assistant_aliases:
logger.debug(f"Neo4j 中 AI 助手别名: {assistant_aliases}")
return assistant_aliases
except Exception as e:
logger.warning(f"查询 Neo4j AI 助手别名失败: {e}")
return set()
def clean_cross_role_aliases(
entity_nodes: List[ExtractedEntityNode],
external_assistant_aliases: set = None,
) -> None:
"""清洗用户实体和AI助手实体之间的别名交叉污染。
在 Neo4j 写入前调用,确保:
- 用户实体的 aliases 不包含 AI 助手的别名
- AI 助手实体的 aliases 不包含用户的别名
Args:
entity_nodes: 实体节点列表(原地修改)
external_assistant_aliases: 外部传入的 AI 助手别名集合(如从 Neo4j 查询),
与本轮实体中的 AI 助手别名合并使用
"""
# 收集本轮 AI 助手实体的所有别名
assistant_aliases = set(external_assistant_aliases or set())
user_aliases = set()
for ent in entity_nodes:
if _is_assistant_entity(ent):
for alias in (getattr(ent, "aliases", []) or []):
assistant_aliases.add(alias.strip().lower())
elif _is_user_entity(ent):
for alias in (getattr(ent, "aliases", []) or []):
user_aliases.add(alias.strip().lower())
# 从用户实体的 aliases 中移除 AI 助手别名
if assistant_aliases:
for ent in entity_nodes:
if _is_user_entity(ent):
original = getattr(ent, "aliases", []) or []
cleaned = [a for a in original if a.strip().lower() not in assistant_aliases]
if len(cleaned) < len(original):
ent.aliases = cleaned
# 从 AI 助手实体的 aliases 中移除用户别名
if user_aliases:
for ent in entity_nodes:
if _is_assistant_entity(ent):
original = getattr(ent, "aliases", []) or []
cleaned = [a for a in original if a.strip().lower() not in user_aliases]
if len(cleaned) < len(original):
ent.aliases = cleaned
def accurate_match(
entity_nodes: List[ExtractedEntityNode]
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
@@ -406,10 +261,6 @@ def accurate_match(
canonical = alias_index.get((ent_uid, ent_name))
# 确保不是自身
if canonical is not None and canonical.id != ent.id:
# 保护禁止跨角色合并用户实体和AI助手实体不能互相合并
if _would_merge_cross_role(canonical, ent):
i += 1
continue
_merge_attribute(canonical, ent)
id_redirect[ent.id] = canonical.id
for k, v in list(id_redirect.items()):
@@ -720,37 +571,66 @@ def fuzzy_match(
def _merge_entities_with_aliases(canonical: ExtractedEntityNode, losing: ExtractedEntityNode):
"""模糊匹配中的实体合并(别名部分)
""" 模糊匹配中的实体合并。
用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,跳过合并。
合并策略:
1. 保留canonical的主名称不变
2. 将losing的主名称添加为alias如果不同
3. 合并两个实体的所有aliases
4. 自动去重case-insensitive并排序
Args:
canonical: 规范实体(保留)
losing: 被合并实体(删除)
Note:
使用alias_utils.normalize_aliases进行标准化去重
"""
# 获取规范实体的名称
canonical_name = (getattr(canonical, "name", "") or "").strip()
if canonical_name.lower() in _USER_PLACEHOLDER_NAMES:
return
losing_name = (getattr(losing, "name", "") or "").strip()
all_aliases = list(getattr(canonical, "aliases", []) or [])
# 收集所有需要合并的别名
all_aliases = []
# 1. 添加canonical现有的别名
current_aliases = getattr(canonical, "aliases", []) or []
all_aliases.extend(current_aliases)
# 2. 添加losing实体的名称如果不同于canonical的名称
if losing_name and losing_name != canonical_name:
all_aliases.append(losing_name)
all_aliases.extend(getattr(losing, "aliases", []) or [])
# 3. 添加losing实体的所有别名
losing_aliases = getattr(losing, "aliases", []) or []
all_aliases.extend(losing_aliases)
# 4. 标准化并去重(使用标准化后的字符串进行去重)
try:
from app.core.memory.utils.alias_utils import normalize_aliases
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
except Exception:
# 如果导入失败,使用增强的去重逻辑
# 使用标准化后的字符串作为key进行去重
seen_normalized = set()
unique_aliases = []
for alias in all_aliases:
if not alias:
continue
alias_stripped = str(alias).strip()
if not alias_stripped or alias_stripped == canonical_name:
continue
# 标准化:转小写用于去重判断
alias_normalized = alias_stripped.lower()
if alias_normalized not in seen_normalized:
seen_normalized.add(alias_normalized)
unique_aliases.append(alias_stripped)
# 排序并赋值
canonical.aliases = sorted(unique_aliases)
# ========== 主循环:遍历所有实体对进行模糊匹配 ==========
@@ -824,11 +704,6 @@ def fuzzy_match(
# 条件A快速通道alias_match_merge = True
# 条件B标准通道s_name ≥ tn AND s_type ≥ type_threshold AND overall ≥ tover
if alias_match_merge or (s_name >= tn and s_type >= type_threshold and overall >= tover):
# 保护禁止跨角色合并用户实体和AI助手实体不能互相合并
if _would_merge_cross_role(a, b):
j += 1
continue
# ========== 第六步:执行实体合并 ==========
# 6.1 合并别名
@@ -938,12 +813,6 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
b = entity_by_id.get(losing_id)
if not a or not b: # 若不存在 a 或 b可能已在精确或模糊阶段合并在之前阶段合并之后不会再处理但是处于审计的目的会记录
continue
# 保护禁止跨角色合并用户实体和AI助手实体不能互相合并
if _would_merge_cross_role(a, b):
llm_records.append(
f"[LLM阻断] 跨角色合并被阻止: {a.id} ({a.name}) 与 {b.id} ({b.name})"
)
continue
_merge_attribute(a, b)
# ID 重定向
try:
@@ -1065,9 +934,6 @@ async def deduplicate_entities_and_edges(
返回:去重后的实体、语句→实体边、实体↔实体边。
"""
local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化保留为了之后对于LLM决策追溯
# 0) 标准化用户和AI助手实体名称确保多轮对话中的变体名称统一
_normalize_special_entity_names(entity_nodes)
# 1) 精确匹配
deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes)

View File

@@ -15,7 +15,6 @@ from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
clean_cross_role_aliases,
)
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
second_layer_dedup_and_merge_with_neo4j,
@@ -101,10 +100,6 @@ async def dedup_layers_and_merge_and_return(
except Exception as e:
print(f"Second-layer dedup failed: {e}")
# 第二层去重后,清洗用户/AI助手之间的别名交叉污染
# 第二层从 Neo4j 合并了旧实体,可能带入历史脏数据
clean_cross_role_aliases(fused_entity_nodes)
return (
dialogue_nodes,
chunk_nodes,

View File

@@ -44,10 +44,6 @@ from app.core.memory.models.variate_config import (
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return,
)
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
_USER_PLACEHOLDER_NAMES,
fetch_neo4j_assistant_aliases,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation,
generate_entity_embeddings_from_triplets,
@@ -311,53 +307,10 @@ class ExtractionOrchestrator:
dialog_data_list,
)
# 步骤 7: 触发异步元数据和别名提取(仅正式模式)
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
if not is_pilot_run:
try:
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import (
MetadataExtractor,
)
metadata_extractor = MetadataExtractor(
llm_client=self.llm_client, language=self.language
)
user_statements = (
metadata_extractor.collect_user_related_statements(
entity_nodes, statement_nodes, statement_entity_edges
)
)
if user_statements:
end_user_id = (
dialog_data_list[0].end_user_id
if dialog_data_list
else None
)
config_id = (
dialog_data_list[0].config_id
if dialog_data_list
and hasattr(dialog_data_list[0], "config_id")
else None
)
if end_user_id:
from app.tasks import extract_user_metadata_task
extract_user_metadata_task.delay(
end_user_id=str(end_user_id),
statements=user_statements,
config_id=str(config_id) if config_id else None,
language=self.language,
)
logger.info(
f"已触发异步元数据提取任务,共 {len(user_statements)} 条用户相关 statement"
)
else:
logger.info("未找到用户相关 statement跳过元数据提取")
except Exception as e:
logger.error(
f"触发元数据提取任务失败(不影响主流程): {e}", exc_info=True
)
# 别名同步已迁移到 Celery 元数据提取任务中,不再在此处执行
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
logger.info(f"知识提取流水线运行完成({mode_str}")
return (
@@ -1150,7 +1103,6 @@ class ExtractionOrchestrator:
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=chunk.content,
speaker=getattr(chunk, 'speaker', None),
chunk_embedding=chunk.chunk_embedding,
sequence_number=chunk_idx, # 添加必需的 sequence_number 字段
created_at=dialog_data.created_at,
@@ -1386,23 +1338,17 @@ class ExtractionOrchestrator:
async def _update_end_user_other_name(
self,
entity_nodes: List[ExtractedEntityNode],
dialog_data_list: List[DialogData],
dialog_data_list: List[DialogData]
) -> None:
"""
将本轮提取的用户别名同步到 end_user 和 end_user_info 表
从 Neo4j 读取用户实体的最终 aliases同步到 end_user 和 end_user_info 表
PgSQL end_user_info.aliases 是用户别名的唯一权威源。
此方法仅将本轮 LLM 从对话中新提取的别名增量追加到 PgSQL
不再从 Neo4j 二层去重合并历史别名,避免脏数据反向污染 PgSQL。
策略:
1. 从本轮对话原始发言中提取用户别名current_aliases
2. 从 PgSQL end_user_info 读取已有的 aliasesdb_aliases
3. 合并 db_aliases + current_aliases去重保序
4. 写回 PgSQL
注意:
1. other_name 使用本次对话提取的第一个别名(保持时间顺序)
2. aliases 从 Neo4j 读取(保持完整性)
Args:
entity_nodes: 去重后的实体节点列表(内存中)
entity_nodes: 实体节点列表
dialog_data_list: 对话数据列表
"""
try:
@@ -1415,28 +1361,23 @@ class ExtractionOrchestrator:
logger.warning("end_user_id 为空,跳过用户别名同步")
return
# 1. 提取本对话的用户别名(保持 LLM 提取的原始顺序,不排序)
current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list)
# 1. 提取本对话的用户别名(保持 LLM 提取的原始顺序,不排序)
current_aliases = self._extract_current_aliases(entity_nodes)
# 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源
# (防止 LLM 未提取出 AI 助手实体时AI 别名泄漏到用户别名中)
neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id)
if neo4j_assistant_aliases:
before_count = len(current_aliases)
current_aliases = [
a for a in current_aliases
if a.strip().lower() not in neo4j_assistant_aliases
]
if len(current_aliases) < before_count:
logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名")
# 2. 从 Neo4j 获取完整 aliases权威数据源
neo4j_aliases = await self._fetch_neo4j_user_aliases(end_user_id)
if not current_aliases:
logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}")
if not neo4j_aliases:
# Neo4j 中没有别名,使用本次对话提取的别名
neo4j_aliases = current_aliases
if not neo4j_aliases:
logger.debug(f"aliases 为空,跳过同步: end_user_id={end_user_id}")
return
logger.info(f"对话提取的 aliases: {current_aliases}")
logger.info(f"对话提取的 aliases: {current_aliases}")
logger.info(f"Neo4j 中的完整 aliases: {neo4j_aliases}")
# 2. 同步到数据库
# 3. 同步到数据库
end_user_uuid = uuid.UUID(end_user_id)
with get_db_context() as db:
# 更新 end_user 表
@@ -1445,32 +1386,7 @@ class ExtractionOrchestrator:
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
return
# 3. 从 PgSQL 读取已有 aliases 并与本轮新增合并
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
db_aliases = (info.aliases if info and info.aliases else [])
# 过滤掉占位名称
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
# 合并PgSQL 已有 + 本轮新增,去重保序(不再合并 Neo4j 历史别名)
merged_aliases = list(db_aliases)
seen_lower = {a.strip().lower() for a in merged_aliases}
for alias in current_aliases:
if alias.strip().lower() not in seen_lower:
merged_aliases.append(alias)
seen_lower.add(alias.strip().lower())
# 最终过滤:从合并结果中排除 AI 助手别名(清理历史脏数据)
if neo4j_assistant_aliases:
merged_aliases = [
a for a in merged_aliases
if a.strip().lower() not in neo4j_assistant_aliases
]
logger.info(f"PgSQL 已有 aliases: {db_aliases}")
logger.info(f"合并后 aliases: {merged_aliases}")
# 更新 end_user 表 other_name
new_name = self._resolve_other_name(end_user.other_name, current_aliases, merged_aliases)
new_name = self._resolve_other_name(end_user.other_name, current_aliases, neo4j_aliases)
if new_name is not None:
end_user.other_name = new_name
logger.info(f"更新 end_user 表 other_name → {new_name}")
@@ -1478,105 +1394,78 @@ class ExtractionOrchestrator:
logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}")
# 更新或创建 end_user_info 记录
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
if info:
new_name_info = self._resolve_other_name(info.other_name, current_aliases, merged_aliases)
new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases)
if new_name_info is not None:
info.other_name = new_name_info
logger.info(f"更新 end_user_info 表 other_name → {new_name_info}")
if info.aliases != merged_aliases:
info.aliases = merged_aliases
logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}")
if info.aliases != neo4j_aliases:
info.aliases = neo4j_aliases
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
else:
first_alias = current_aliases[0].strip() if current_aliases else ""
# 确保 first_alias 不是占位名称
if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
db.add(EndUserInfo(
end_user_id=end_user_uuid,
other_name=first_alias,
aliases=merged_aliases,
aliases=neo4j_aliases,
meta_data={}
))
logger.info(f"创建 end_user_info 记录other_name={first_alias}, aliases={merged_aliases}")
logger.info(f"创建 end_user_info 记录other_name={first_alias}, aliases={neo4j_aliases}")
db.commit()
except Exception as e:
logger.error(f"更新 end_user other_name 失败: {e}", exc_info=True)
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
# 复用 deduped_and_disamb 模块级常量,避免重复维护
USER_PLACEHOLDER_NAMES = _USER_PLACEHOLDER_NAMES
USER_PLACEHOLDER_NAMES = {'用户', '', 'User', 'I'}
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode], dialog_data_list=None) -> List[str]:
"""用户发言的原始实体中提取本轮新增别名(绕过去重污染
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
"""实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序
策略:
仅从 dialog_data_list 中找到 speaker="user" 的 statement
从这些 statement 的 triplet_extraction_info 中提取用户实体的 aliases。
这样拿到的是 LLM 对用户原话的提取结果,不受去重合并的影响。
注意:不再使用去重后 entity_nodes 作为兜底,因为二层去重会将 Neo4j 历史别名
合并进来,导致历史别名被误认为"本轮提取"。历史别名的同步由
_extract_deduped_entity_aliases 负责。
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户""""User""I")。
第一个别名将被用作 other_name。
Args:
entity_nodes: 去重后的实体节点列表(未使用,保留参数兼容性)
dialog_data_list: 对话数据列表
entity_nodes: 实体节点列表
Returns:
别名列表(保持原始顺序,已过滤)
"""
if not dialog_data_list:
return []
all_user_aliases = []
seen_lower = set()
for dialog in dialog_data_list:
for chunk in dialog.chunks:
speaker = getattr(chunk, 'speaker', None)
for statement in chunk.statements:
stmt_speaker = getattr(statement, 'speaker', None) or speaker
if stmt_speaker != "user":
continue
triplet_info = getattr(statement, 'triplet_extraction_info', None)
if not triplet_info:
continue
for entity in (triplet_info.entities or []):
ent_name = getattr(entity, 'name', '').strip()
if ent_name.lower() in self.USER_PLACEHOLDER_NAMES:
for alias in (getattr(entity, 'aliases', []) or []):
a = alias.strip()
if a and a.lower() not in self.USER_PLACEHOLDER_NAMES and a.lower() not in seen_lower:
all_user_aliases.append(a)
seen_lower.add(a.lower())
if all_user_aliases:
logger.debug(f"从用户原始发言提取到别名: {all_user_aliases}")
return all_user_aliases
def _extract_deduped_entity_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
"""从去重后的用户实体中提取完整别名列表。
二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 的用户实体中,
因此这里提取到的别名包含了历史积累的所有别名,可用于同步到 PgSQL。
Args:
entity_nodes: 去重后的实体节点列表(含二层去重合并结果)
Returns:
别名列表(已过滤占位名称,去重保序)
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称
"""
for entity in entity_nodes:
if getattr(entity, 'name', '').strip().lower() in self.USER_PLACEHOLDER_NAMES:
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
aliases = getattr(entity, 'aliases', []) or []
filtered = [
a for a in aliases
if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES
]
if filtered:
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
return filtered
return []
async def _fetch_neo4j_assistant_aliases(self, end_user_id: str) -> set:
"""从 Neo4j 查询 AI 助手实体的所有别名(用于从用户别名中排除)"""
return await fetch_neo4j_assistant_aliases(self.connector, end_user_id)
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
cypher = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '', 'User', 'I']
RETURN e.aliases AS aliases
LIMIT 1
"""
result = await Neo4jConnector().execute_query(cypher, end_user_id=end_user_id)
if not result:
logger.debug(f"Neo4j 中未找到用户实体: end_user_id={end_user_id}")
return []
aliases = result[0].get('aliases') or []
if not aliases:
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
return []
# 过滤掉占位名称,防止历史脏数据传播
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
return filtered
def _resolve_other_name(
self,
@@ -1595,18 +1484,19 @@ class ExtractionOrchestrator:
注意:返回值不允许是占位名称("用户""""User""I"
"""
# 当前值为空或为占位名称时,需要更新
if not current or not current.strip() or current.strip().lower() in self.USER_PLACEHOLDER_NAMES:
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
candidate = current_aliases[0].strip() if current_aliases else None
# 确保候选值不是占位名称
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
return None
return candidate
if current not in neo4j_aliases:
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
# 确保候选值不是占位名称
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
return None
return candidate
return None
async def _run_dedup_and_write_summary(

View File

@@ -1,176 +0,0 @@
"""
Metadata extractor module.
Collects user-related statements from post-dedup graph data and
extracts user metadata via an independent LLM call.
"""
import logging
from typing import List, Optional
from app.core.memory.models.graph_models import (
ExtractedEntityNode,
StatementEntityEdge,
StatementNode,
)
logger = logging.getLogger(__name__)
# Reuse the same user-entity detection logic from dedup module
_USER_NAMES = {"用户", "", "user", "i"}
_CANONICAL_USER_TYPE = "用户"
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
"""判断实体是否为用户实体"""
name = (getattr(ent, "name", "") or "").strip().lower()
etype = (getattr(ent, "entity_type", "") or "").strip()
return name in _USER_NAMES or etype == _CANONICAL_USER_TYPE
class MetadataExtractor:
"""Extracts user metadata from post-dedup graph data via independent LLM call."""
def __init__(self, llm_client, language: Optional[str] = None):
self.llm_client = llm_client
self.language = language
@staticmethod
def detect_language(statements: List[str]) -> str:
"""根据 statement 文本内容检测语言。
如果文本中包含中文字符则返回 "zh",否则返回 "en"
"""
import re
combined = " ".join(statements)
if re.search(r"[\u4e00-\u9fff]", combined):
return "zh"
return "en"
def collect_user_related_statements(
self,
entity_nodes: List[ExtractedEntityNode],
statement_nodes: List[StatementNode],
statement_entity_edges: List[StatementEntityEdge],
) -> List[str]:
"""
从去重后的数据中筛选与用户直接相关且由用户发言的 statement 文本。
筛选逻辑:
1. 用户实体 → StatementEntityEdge → statement直接关联
2. 只保留 speaker="user" 的 statement过滤 assistant 回复的噪声)
Returns:
用户发言的 statement 文本列表
"""
# Find user entity IDs
user_entity_ids = set()
for ent in entity_nodes:
if _is_user_entity(ent):
user_entity_ids.add(ent.id)
if not user_entity_ids:
logger.debug("未找到用户实体节点,跳过 statement 收集")
return []
# 用户实体 → StatementEntityEdge → statement
target_stmt_ids = set()
for edge in statement_entity_edges:
if edge.target in user_entity_ids:
target_stmt_ids.add(edge.source)
# Collect: only speaker="user" statements, preserving order
result = []
seen = set()
total_associated = 0
skipped_non_user = 0
for stmt_node in statement_nodes:
if stmt_node.id in target_stmt_ids and stmt_node.id not in seen:
total_associated += 1
speaker = getattr(stmt_node, "speaker", None) or "unknown"
if speaker == "user":
text = (stmt_node.statement or "").strip()
if text:
result.append(text)
else:
skipped_non_user += 1
seen.add(stmt_node.id)
logger.info(
f"收集到 {len(result)} 条用户发言 statement "
f"(直接关联: {total_associated}, speaker=user: {len(result)}, "
f"跳过非user: {skipped_non_user})"
)
if result:
for i, text in enumerate(result):
logger.info(f" [user statement {i + 1}] {text}")
if total_associated > 0 and len(result) == 0:
logger.warning(
f"{total_associated} 条直接关联 statement 但全部被 speaker 过滤,"
f"可能本次写入不包含 user 消息"
)
return result
async def extract_metadata(
self,
statements: List[str],
existing_metadata: Optional[dict] = None,
existing_aliases: Optional[List[str]] = None,
) -> Optional[tuple]:
"""
对筛选后的 statement 列表调用 LLM 提取元数据增量变更和用户别名。
Args:
statements: 用户发言的 statement 文本列表
existing_metadata: 数据库已有的元数据(可选)
existing_aliases: 数据库已有的用户别名列表(可选)
Returns:
(List[MetadataFieldChange], List[str], List[str]) tuple:
(metadata_changes, aliases_to_add, aliases_to_remove) on success, None on failure
"""
if not statements:
return None
try:
from app.core.memory.utils.prompt.prompt_utils import prompt_env
if self.language:
detected_language = self.language
logger.info(f"元数据提取使用显式指定语言: {detected_language}")
else:
detected_language = self.detect_language(statements)
logger.info(f"元数据提取语言自动检测结果: {detected_language}")
template = prompt_env.get_template("extract_user_metadata.jinja2")
prompt = template.render(
statements=statements,
language=detected_language,
existing_metadata=existing_metadata,
existing_aliases=existing_aliases,
json_schema="",
)
from app.core.memory.models.metadata_models import (
MetadataExtractionResponse,
)
response = await self.llm_client.response_structured(
messages=[{"role": "user", "content": prompt}],
response_model=MetadataExtractionResponse,
)
if response:
changes = response.metadata_changes if response.metadata_changes else []
to_add = response.aliases_to_add if response.aliases_to_add else []
to_remove = (
response.aliases_to_remove if response.aliases_to_remove else []
)
return changes, to_add, to_remove
logger.warning("LLM 返回的响应为空")
return None
except Exception as e:
logger.error(f"元数据提取 LLM 调用失败: {e}", exc_info=True)
return None

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
@@ -81,7 +82,6 @@ class StatementExtractor:
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
return None
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
"""Process a single chunk and return extracted statements
@@ -94,7 +94,6 @@ class StatementExtractor:
List of ExtractedStatement objects extracted from the chunk
"""
chunk_content = chunk.content
chunk_speaker = self._get_speaker_from_chunk(chunk)
if not chunk_content or len(chunk_content.strip()) < 5:
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
@@ -151,6 +150,8 @@ class StatementExtractor:
except (KeyError, ValueError):
relevence_info = RelevenceInfo.RELEVANT
chunk_speaker = self._get_speaker_from_chunk(chunk)
chunk_statement = Statement(
statement=extracted_stmt.statement,
stmt_type=stmt_type,

View File

@@ -1,3 +1,4 @@
import os
import asyncio
from typing import List, Dict, Optional
@@ -60,7 +61,6 @@ class TripletExtractor:
predicate_instructions=PREDICATE_DEFINITIONS,
language=self._get_language(),
ontology_types=self.ontology_types,
speaker=getattr(statement, 'speaker', None),
)
# Create messages for LLM

View File

@@ -42,21 +42,22 @@ class AccessHistoryManager:
- access_count: 访问次数
特性:
- 原子性更新:使用 APOC 原子操作确保并发安全
- 批次内合并:同一批次中对同一节点的多次访问合并为一次更新
- 原子性更新:使用Neo4j事务确保所有字段同时更新或回滚
- 并发安全:使用乐观锁机制防止并发冲突
- 一致性保证:提供一致性检查和自动修复功能
- 智能修剪:自动修剪过长的访问历史
Attributes:
connector: Neo4j连接器实例
actr_calculator: ACT-R激活值计算器实例
max_retries: 并发冲突时的最大重试次数
"""
def __init__(
self,
connector: Neo4jConnector,
actr_calculator: ACTRCalculator,
max_retries: int = 5
max_retries: int = 3
):
"""
初始化访问历史管理器
@@ -64,35 +65,47 @@ class AccessHistoryManager:
Args:
connector: Neo4j连接器实例
actr_calculator: ACT-R激活值计算器实例
max_retries: 已废弃保留参数兼容性APOC 原子操作无需重试
max_retries: 并发冲突时的最大重试次数默认3次
"""
self.connector = connector
self.actr_calculator = actr_calculator
self.max_retries = max_retries
async def record_access(
self,
node_id: str,
node_label: str,
end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None,
access_times: int = 1
current_time: Optional[datetime] = None
) -> Dict[str, Any]:
"""
记录节点访问并原子性更新所有相关字段
这是核心方法,实现了:
1. 首次访问初始化access_history计算初始激活值
2. 后续访问:追加访问历史,重新计算激活值
3. 历史修剪:当历史过长时自动修剪
4. 原子性:所有字段在单个事务中更新
5. 并发安全:使用乐观锁重试机制
Args:
node_id: 节点ID
node_label: 节点标签Statement, ExtractedEntity, MemorySummary
end_user_id: 组ID可选用于过滤
current_time: 当前时间(可选,默认使用系统时间)
access_times: 本次访问次数默认1批量合并时可能大于1
Returns:
Dict[str, Any]: 更新后的节点数据
Dict[str, Any]: 更新后的节点数据,包含:
- id: 节点ID
- activation_value: 更新后的激活值
- access_history: 更新后的访问历史
- last_access_time: 最后访问时间
- access_count: 访问次数
- importance_score: 重要性分数
Raises:
ValueError: 如果节点不存在或节点标签无效
RuntimeError: 如果更新失败
RuntimeError: 如果重试次数耗尽仍然失败
"""
if current_time is None:
current_time = datetime.now()
@@ -106,6 +119,8 @@ class AccessHistoryManager:
f"Invalid node_label: {node_label}. Must be one of {valid_labels}"
)
# 使用乐观锁重试机制处理并发冲突
for attempt in range(self.max_retries):
try:
# 步骤1读取当前节点状态
node_data = await self._fetch_node(node_id, node_label, end_user_id)
@@ -119,11 +134,10 @@ class AccessHistoryManager:
update_data = await self._calculate_update(
node_data=node_data,
current_time=current_time,
current_time_iso=current_time_iso,
access_times=access_times
current_time_iso=current_time_iso
)
# 步骤3使用 APOC 原子操作更新节点(无需重试
# 步骤3原子性更新节点(使用事务
updated_node = await self._atomic_update(
node_id=node_id,
node_label=node_label,
@@ -135,18 +149,24 @@ class AccessHistoryManager:
f"成功记录访问: {node_label}[{node_id}], "
f"activation={update_data['activation_value']:.4f}, "
f"access_count={update_data['access_count']}"
f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}"
)
return updated_node
except Exception as e:
if attempt < self.max_retries - 1:
logger.warning(
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}: {str(e)}"
)
continue
else:
logger.error(
f"访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], "
f"错误: {str(e)}"
)
raise RuntimeError(
f"Failed to record access: {str(e)}"
) from e
f"Failed to record access after {self.max_retries} attempts: {str(e)}"
)
async def record_batch_access(
self,
@@ -158,10 +178,11 @@ class AccessHistoryManager:
"""
批量记录多个节点的访问
对同一个节点的多次访问会先在内存中合并,只发起一次更新
为提高性能,批量更新多个节点的访问历史
每个节点独立更新,失败的节点不影响其他节点。
Args:
node_ids: 节点ID列表可包含重复ID
node_ids: 节点ID列表
node_label: 节点标签(所有节点必须是同一类型)
end_user_id: 组ID可选
current_time: 当前时间(可选)
@@ -175,38 +196,25 @@ class AccessHistoryManager:
if current_time is None:
current_time = datetime.now()
# 合并同一节点的访问次数,避免对同一节点并发写入
access_count_map: Dict[str, int] = {}
for node_id in node_ids:
access_count_map[node_id] = access_count_map.get(node_id, 0) + 1
merged_count = len(node_ids) - len(access_count_map)
if merged_count > 0:
logger.info(
f"批量访问合并: 原始={len(node_ids)}, "
f"去重后={len(access_count_map)}, 合并={merged_count}"
)
# 对去重后的节点并行发起更新
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
tasks = []
for node_id, access_times in access_count_map.items():
for node_id in node_ids:
task = self.record_access(
node_id=node_id,
node_label=node_label,
end_user_id=end_user_id,
current_time=current_time,
access_times=access_times
current_time=current_time
)
tasks.append((node_id, task))
tasks.append(task)
task_results = await asyncio.gather(
*[t for _, t in tasks], return_exceptions=True
)
# Execute all tasks in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# Collect successful results and count failures
results = []
failed_count = 0
for (node_id, _), result in zip(tasks, task_results):
for node_id, result in zip(node_ids, task_results):
if isinstance(result, Exception):
failed_count += 1
logger.warning(
@@ -217,7 +225,7 @@ class AccessHistoryManager:
batch_duration = time.time() - batch_start
logger.info(
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(access_count_map)}, "
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
)
@@ -231,6 +239,22 @@ class AccessHistoryManager:
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
"""
检查节点数据的一致性
验证以下一致性规则:
1. access_history[-1] == last_access_time
2. len(access_history) == access_count
3. 如果有访问历史,必须有激活值
4. 激活值必须在有效范围内 [offset, 1.0]
Args:
node_id: 节点ID
node_label: 节点标签
end_user_id: 组ID可选
Returns:
Tuple[ConsistencyCheckResult, Optional[str]]:
- 一致性检查结果枚举
- 错误描述(如果不一致)
"""
node_data = await self._fetch_node(node_id, node_label, end_user_id)
@@ -242,6 +266,7 @@ class AccessHistoryManager:
access_count = node_data.get('access_count', 0)
activation_value = node_data.get('activation_value')
# 检查1access_history[-1] == last_access_time
if access_history and last_access_time:
if access_history[-1] != last_access_time:
return (
@@ -250,6 +275,7 @@ class AccessHistoryManager:
f"last_access_time={last_access_time}"
)
# 检查2len(access_history) == access_count
if len(access_history) != access_count:
return (
ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT,
@@ -257,12 +283,14 @@ class AccessHistoryManager:
f"access_count={access_count}"
)
# 检查3有访问历史必须有激活值
if access_history and activation_value is None:
return (
ConsistencyCheckResult.MISSING_ACTIVATION,
"Node has access_history but activation_value is None"
)
# 检查4激活值范围
if activation_value is not None:
offset = self.actr_calculator.offset
if not (offset <= activation_value <= 1.0):
@@ -280,7 +308,23 @@ class AccessHistoryManager:
end_user_id: Optional[str] = None,
limit: int = 1000
) -> Dict[str, Any]:
"""批量检查多个节点的一致性"""
"""
批量检查多个节点的一致性
Args:
node_label: 节点标签
end_user_id: 组ID可选
limit: 检查的最大节点数
Returns:
Dict[str, Any]: 一致性检查报告,包含:
- total_checked: 检查的节点总数
- consistent_count: 一致的节点数
- inconsistent_count: 不一致的节点数
- inconsistencies: 不一致节点的详细信息列表
- consistency_rate: 一致性率0-1
"""
# 查询所有相关节点
query = f"""
MATCH (n:{node_label})
WHERE n.access_history IS NOT NULL
@@ -299,6 +343,7 @@ class AccessHistoryManager:
results = await self.connector.execute_query(query, **params)
node_ids = [r['id'] for r in results]
# 检查每个节点
inconsistencies = []
consistent_count = 0
@@ -344,8 +389,25 @@ class AccessHistoryManager:
node_label: str,
end_user_id: Optional[str] = None
) -> bool:
"""自动修复节点的数据不一致问题"""
"""
自动修复节点的数据不一致问题
修复策略:
1. 如果access_history[-1] != last_access_time使用access_history[-1]
2. 如果len(access_history) != access_count使用len(access_history)
3. 如果有历史但无激活值:重新计算激活值
4. 如果激活值超出范围:重新计算激活值
Args:
node_id: 节点ID
node_label: 节点标签
end_user_id: 组ID可选
Returns:
bool: 修复成功返回True否则返回False
"""
try:
# 检查一致性
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
@@ -356,6 +418,7 @@ class AccessHistoryManager:
logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]")
return True
# 获取节点数据
node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
@@ -364,13 +427,17 @@ class AccessHistoryManager:
access_history = node_data.get('access_history') or []
importance_score = node_data.get('importance_score', 0.5)
# 准备修复数据
repair_data = {}
# 修复last_access_time
if access_history:
repair_data['last_access_time'] = access_history[-1]
# 修复access_count
repair_data['access_count'] = len(access_history)
# 修复activation_value
if access_history:
current_time = datetime.now()
last_access_dt = datetime.fromisoformat(access_history[-1])
@@ -386,6 +453,7 @@ class AccessHistoryManager:
)
repair_data['activation_value'] = activation_value
# 执行修复
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
@@ -425,7 +493,17 @@ class AccessHistoryManager:
node_label: str,
end_user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""获取节点数据"""
"""
获取节点数据
Args:
node_id: 节点ID
node_label: 节点标签
end_user_id: 组ID可选
Returns:
Optional[Dict[str, Any]]: 节点数据如果不存在返回None
"""
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
@@ -454,8 +532,7 @@ class AccessHistoryManager:
self,
node_data: Dict[str, Any],
current_time: datetime,
current_time_iso: str,
access_times: int = 1
current_time_iso: str
) -> Dict[str, Any]:
"""
计算更新数据
@@ -464,38 +541,43 @@ class AccessHistoryManager:
node_data: 当前节点数据
current_time: 当前时间datetime对象
current_time_iso: 当前时间ISO格式字符串
access_times: 本次访问次数合并后可能大于1
Returns:
Dict[str, Any]: 更新数据
Dict[str, Any]: 更新数据,包含所有需要更新的字段
"""
access_history = node_data.get('access_history') or []
# Handle None importance_score - default to 0.5
importance_score = node_data.get('importance_score')
if importance_score is None:
importance_score = 0.5
# 本次新增的时间
new_timestamps = [current_time_iso] * access_times
# 追加新的访问时间
new_access_history = access_history + [current_time_iso]
# 仅用本次新增的访问记录计算激活值
new_history_dt = [current_time] * access_times
# 修剪访问历史(如果过长)
access_history_dt = [
datetime.fromisoformat(ts) for ts in new_access_history
]
trimmed_history_dt = self.actr_calculator.trim_access_history(
access_history=new_history_dt,
access_history=access_history_dt,
current_time=current_time
)
trimmed_history = [ts.isoformat() for ts in trimmed_history_dt]
# 计算新的激活值
activation_value = self.actr_calculator.calculate_memory_activation(
access_history=trimmed_history_dt,
current_time=current_time,
last_access_time=current_time,
last_access_time=current_time, # 最后访问时间就是当前时间
importance_score=importance_score
)
# 返回所有需要更新的字段
return {
'activation_value': activation_value,
'new_timestamps': new_timestamps,
'access_count_delta': access_times,
'access_count': len(trimmed_history_dt),
'access_history': trimmed_history,
'last_access_time': current_time_iso,
'access_count': len(trimmed_history)
}
async def _atomic_update(
@@ -506,10 +588,10 @@ class AccessHistoryManager:
end_user_id: Optional[str] = None
) -> Dict[str, Any]:
"""
原子性更新节点(使用 APOC 原子操作
原子性更新节点(使用乐观锁
使用 apoc.atomic.add 和 apoc.atomic.insert 保证并发安全,
无需 version 字段和乐观锁,数据库层面保证原子性
使用Neo4j事务和版本号确保所有字段同时更新或回滚。
实现乐观锁机制防止并发冲突
Args:
node_id: 节点ID
@@ -521,15 +603,49 @@ class AccessHistoryManager:
Dict[str, Any]: 更新后的节点数据
Raises:
RuntimeError: 如果更新失败
RuntimeError: 如果更新失败或发生版本冲突
"""
# 定义事务函数
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
# 步骤1读取当前节点并获取版本号
read_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if end_user_id:
read_query += " WHERE n.end_user_id = $end_user_id"
read_query += """
RETURN n.id as id,
n.version as version,
n.activation_value as activation_value,
n.access_history as access_history,
n.last_access_time as last_access_time,
n.access_count as access_count,
n.importance_score as importance_score
"""
read_params = {'node_id': node_id}
if end_user_id:
read_params['end_user_id'] = end_user_id
read_result = await tx.run(read_query, **read_params)
current_node = await read_result.single()
if not current_node:
raise RuntimeError(f"Node not found: {node_label}[{node_id}]")
# 获取当前版本号如果不存在则为0
current_version = current_node.get('version', 0) or 0
new_version = current_version + 1
# 步骤2使用乐观锁更新节点
# 根据节点类型构建完整的查询语句
content_field_map = {
'Statement': 'n.statement as statement',
'MemorySummary': 'n.content as content',
'ExtractedEntity': 'null as content_placeholder',
'Community': 'n.summary as summary'
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤
}
# 显式检查节点类型,不支持的类型抛出错误
if node_label not in content_field_map:
raise ValueError(
f"Unsupported node_label: {node_label}. "
@@ -538,51 +654,75 @@ class AccessHistoryManager:
content_field = content_field_map[node_label]
where_clause = ""
# 构建 WHERE 子句
where_conditions = []
if end_user_id:
where_clause = " AND n.end_user_id = $end_user_id"
where_conditions.append("n.end_user_id = $end_user_id")
query = f"""
# 添加版本检查
if current_version > 0:
where_conditions.append("n.version = $current_version")
else:
where_conditions.append("(n.version IS NULL OR n.version = 0)")
where_clause = " AND ".join(where_conditions) if where_conditions else "true"
# 构建完整的更新查询
update_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
WHERE true{where_clause}
CALL apoc.atomic.add(n, 'access_count', $access_count_delta, 5) YIELD oldValue AS old_count
WITH n
CALL (n) {{
UNWIND $new_timestamps AS ts
CALL apoc.atomic.insert(n, 'access_history', size(n.access_history), ts, 5) YIELD oldValue
RETURN count(*) AS inserted
}}
WHERE {where_clause}
SET n.activation_value = $activation_value,
n.last_access_time = $last_access_time
n.access_history = $access_history,
n.last_access_time = $last_access_time,
n.access_count = $access_count,
n.version = $new_version
RETURN n.id as id,
n.activation_value as activation_value,
n.access_history as access_history,
n.last_access_time as last_access_time,
n.access_count as access_count,
n.importance_score as importance_score,
n.version as version,
{content_field}
"""
params = {
update_params = {
'node_id': node_id,
'access_count_delta': update_data['access_count_delta'],
'new_timestamps': update_data['new_timestamps'],
'current_version': current_version,
'new_version': new_version,
'activation_value': update_data['activation_value'],
'access_history': update_data['access_history'],
'last_access_time': update_data['last_access_time'],
'access_count': update_data['access_count']
}
if end_user_id:
params['end_user_id'] = end_user_id
update_params['end_user_id'] = end_user_id
try:
results = await self.connector.execute_query(query, **params)
update_result = await tx.run(update_query, **update_params)
updated_node = await update_result.single()
if not results:
raise RuntimeError(f"Node not found: {node_label}[{node_id}]")
if not updated_node:
raise RuntimeError(
f"Version conflict detected for {node_label}[{node_id}]. "
f"Expected version {current_version}, but node was modified by another transaction."
)
result_dict = dict(results[0])
# 转换为字典并移除占位符字段
result_dict = dict(updated_node)
result_dict.pop('content_placeholder', None)
return result_dict
# 执行事务
try:
result = await self.connector.execute_write_transaction(
update_transaction,
node_id=node_id,
node_label=node_label,
update_data=update_data,
end_user_id=end_user_id
)
return result
except Exception as e:
logger.error(
f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}"

View File

@@ -5,7 +5,7 @@
使用Neo4j的全文索引进行高效的文本匹配。
"""
from typing import List, Optional
from typing import List, Dict, Any, Optional
from app.core.logging_config import get_memory_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
@@ -74,7 +74,7 @@ class KeywordSearchStrategy(SearchStrategy):
# 调用底层的关键词搜索函数
results_dict = await search_graph(
connector=self.connector,
query=query_text,
q=query_text,
end_user_id=end_user_id,
limit=limit,
include=include_list

View File

@@ -22,9 +22,7 @@ def escape_lucene_query(query: str) -> str:
s = s.replace("\r", " ").replace("\n", " ").strip()
# Lucene reserved tokens/special characters
# NOTE: '/' is the regex delimiter in Lucene — must be escaped to prevent
# TokenMgrError when the query contains unmatched slashes.
specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/']
specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':']
# Replace longer tokens first to avoid partial double-escaping
for token in sorted(specials, key=len, reverse=True):
s = s.replace(token, f"\\{token}")

View File

@@ -1,6 +1,6 @@
import os
from jinja2 import Environment, FileSystemLoader
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering
# Setup Jinja2 environment
@@ -205,7 +205,6 @@ async def render_triplet_extraction_prompt(
predicate_instructions: dict = None,
language: str = "zh",
ontology_types: "OntologyTypeList | None" = None,
speaker: str = None,
) -> str:
"""
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
@@ -217,7 +216,6 @@ async def render_triplet_extraction_prompt(
predicate_instructions: Optional predicate instructions
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification
speaker: Speaker role ("user" or "assistant") for the current statement
Returns:
Rendered prompt content as string
@@ -225,7 +223,7 @@ async def render_triplet_extraction_prompt(
template = prompt_env.get_template("extract_triplet.jinja2")
# 准备本体类型数据
ontology_type_section = None
ontology_type_section = ""
ontology_type_names = []
type_hierarchy_hints = []
if ontology_types and ontology_types.types:
@@ -242,7 +240,6 @@ async def render_triplet_extraction_prompt(
ontology_types=ontology_type_section,
ontology_type_names=ontology_type_names,
type_hierarchy_hints=type_hierarchy_hints,
speaker=speaker,
)
# 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('triplet extraction', rendered_prompt)

View File

@@ -43,9 +43,8 @@ Each statement must be labeled as per the criteria mentioned below.
对话上下文和共指消解:
- 将每个陈述句归属于说出它的参与者。
- **对于用户的发言:必须使用"用户"作为主语**,禁止将"用户"或"我"替换为用户的真实姓名或别名。例如,用户说"我叫张三"应提取为"用户叫张三",而不是"张三叫张三"
- 对于 AI 助手的发言:使用"助手"或"AI助手"作为主语
- 将所有代词解析为对话上下文中的具体人物或实体,但"我"必须解析为"用户"。
- 如果参与者列表为说话者提供了名称(例如,"李雪(用户)"),请在提取的陈述句中使用具体名称("李雪"),而不是通用角色("用户"
- 将所有代词解析为对话上下文中的具体人物或实体
- 识别并将抽象引用解析为其具体名称(如果提到)。
- 将缩写和首字母缩略词扩展为其完整形式。
{% else %}
@@ -69,9 +68,8 @@ Context Resolution Requirements:
Conversational Context & Co-reference Resolution:
- Attribute every statement to the participant who uttered it.
- **For user's statements: always use "用户" (User) as the subject**. Do NOT replace "用户" or "I" with the user's real name or alias. For example, if the user says "I'm John", extract as "用户 is John", not "John is John".
- For AI assistant's statements: use "助手" or "AI助手" as the subject.
- Resolve all pronouns to the specific person or entity from the conversation's context, but "I"/"我" must always resolve to "用户".
- If the participant list provides a name for a speaker (e.g., "李雪 (用户)"), use the specific name ("李雪") in the extracted statement, not the generic role ("用户").
- Resolve all pronouns to the specific person or entity from the conversation's context.
- Identify and resolve abstract references to their specific names if mentioned.
- Expand abbreviations and acronyms to their full form.
{% endif %}
@@ -141,13 +139,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合
示例输出: {
"statements": [
{
"statement": "用户最近一直在尝试水彩画。",
"statement": "Sarah Chen 最近一直在尝试水彩画。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "用户画了一些花朵。",
"statement": "Sarah Chen 画了一些花朵。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
@@ -159,13 +157,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合
"relevance": "IRRELEVANT"
},
{
"statement": "用户认为她的水彩画中的色彩组合可以改进。",
"statement": "Sarah Chen 认为她的水彩画中的色彩组合可以改进。",
"statement_type": "OPINION",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
},
{
"statement": "用户真的很喜欢玫瑰和百合。",
"statement": "Sarah Chen 真的很喜欢玫瑰和百合。",
"statement_type": "FACT",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
@@ -188,13 +186,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
示例输出: {
"statements": [
{
"statement": "用户最近在尝试水彩画。",
"statement": "张曼婷最近在尝试水彩画。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "用户画了一些花朵。",
"statement": "张曼婷画了一些花朵。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
@@ -206,13 +204,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
"relevance": "IRRELEVANT"
},
{
"statement": "用户觉得水彩画的色彩搭配还有提升的空间。",
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
"statement_type": "OPINION",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
},
{
"statement": "用户很喜欢玫瑰和百合。",
"statement": "张曼婷很喜欢玫瑰和百合。",
"statement_type": "FACT",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
@@ -235,13 +233,13 @@ User: "I think the color combinations could use some improvement, but I really l
Example Output: {
"statements": [
{
"statement": "用户 has been trying watercolor painting recently.",
"statement": "Sarah Chen has been trying watercolor painting recently.",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "用户 painted some flowers.",
"statement": "Sarah Chen painted some flowers.",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
@@ -253,13 +251,13 @@ Example Output: {
"relevance": "IRRELEVANT"
},
{
"statement": "用户 thinks the color combinations in her watercolor paintings could use some improvement.",
"statement": "Sarah Chen thinks the color combinations in her watercolor paintings could use some improvement.",
"statement_type": "OPINION",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
},
{
"statement": "用户 really likes roses and lilies.",
"statement": "Sarah Chen really likes roses and lilies.",
"statement_type": "FACT",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
@@ -282,13 +280,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
Example Output: {
"statements": [
{
"statement": "用户最近在尝试水彩画。",
"statement": "张曼婷最近在尝试水彩画。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "用户画了一些花朵。",
"statement": "张曼婷画了一些花朵。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
@@ -300,13 +298,13 @@ Example Output: {
"relevance": "IRRELEVANT"
},
{
"statement": "用户觉得水彩画的色彩搭配还有提升的空间。",
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
"statement_type": "OPINION",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
},
{
"statement": "用户很喜欢玫瑰和百合。",
"statement": "张曼婷很喜欢玫瑰和百合。",
"statement_type": "FACT",
"temporal_type": "STATIC",
"relevance": "RELEVANT"

View File

@@ -23,16 +23,6 @@ Extract entities and knowledge triplets from the given statement.
===Inputs===
**Chunk Content:** "{{ chunk_content }}"
**Statement:** "{{ statement }}"
{% if speaker %}
**Speaker:** {{ speaker }}
{% if speaker == "assistant" %}
{% if language == "zh" %}
⚠️ 当前陈述句来自 **AI助手的回复**。AI助手在回复中用来称呼用户的名字是**用户的别名**,不是 AI 助手的别名。但只能提取原文中逐字出现的名字,严禁推测或创造原文中不存在的别名变体。
{% else %}
⚠️ This statement is from the **AI assistant's reply**. Names the AI uses to address the user are **user's aliases**, NOT the AI assistant's aliases. But only extract names that appear VERBATIM in the text — never infer or fabricate alias variants.
{% endif %}
{% endif %}
{% endif %}
{% if ontology_types %}
===Ontology Type Guidance===
@@ -97,17 +87,7 @@ Extract entities and knowledge triplets from the given statement.
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name
* "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name
- 空值:如果没有别名,使用 `[]`
- **🚨🚨🚨 严禁幻觉:只提取对话原文中逐字出现的别名,绝对不能推测、衍生或创造任何未在原文中出现的名字。例如,看到"陈思远"不能自行添加"思远大人""远哥""小远"等变体。如果原文没有这些字,就不能出现在 aliases 中。**
- **🚨 归属区分:必须严格区分名称的归属对象。默认情况下,用户提到的名字归属用户实体。只有出现明确的第二人称命名表达(如"叫你""给你取名")时,才将名字归属 AI/助手实体。**
- **🚨 说话人视角:当 speaker 为 assistant 时AI 助手用来称呼用户的名字是用户的别名,必须归入用户实体的 aliases绝对不能归入 AI 助手实体。但同样只能提取原文中逐字出现的称呼,不能推测。**
* "我叫陈思远我给AI取名为远仔" → 用户 aliases=["陈思远"]AI助手 aliases=["远仔"]
* "我叫vv" → 用户 aliases=["vv"]没有给AI取名的表达名字归用户
* [speaker=assistant] "好的VV" → 用户 aliases=["VV"]AI 在称呼用户,原文中出现了"VV"
* [speaker=assistant] "我叫陈仔" → AI助手 aliases=["陈仔"]AI 在自我介绍,这是 AI 的别名)
* ❌ 错误:将"远仔"放入用户的 aliases"远仔"是给AI取的名字不是用户的名字
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases
* ❌ 错误AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
- 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字
{% else %}
- Include: nicknames, full names, abbreviations, alternative names
- Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST**
@@ -116,17 +96,7 @@ Extract entities and knowledge triplets from the given statement.
* "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name)
* "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
- Empty: If no aliases, use `[]`
- **🚨🚨🚨 NO HALLUCINATION: Only extract aliases that appear VERBATIM in the original text. NEVER infer, derive, or fabricate names not present in the text. For example, seeing "John Smith" does NOT allow adding "Johnny", "Smithy", "Mr. Smith" unless those exact strings appear in the conversation.**
- **🚨 Ownership distinction: By default, all names mentioned by the user belong to the user entity. Only assign a name to the AI/assistant entity when an explicit second-person naming expression (e.g., "I'll call you", "your name is") is present.**
- **🚨 Speaker perspective: When speaker is "assistant", names the AI uses to address the user are the USER's aliases and MUST go into the user entity's aliases, NEVER into the AI assistant entity's aliases. But only extract names that appear verbatim in the text, never infer.**
* "I'm Alex, I'll call you Buddy" → User aliases=["Alex"], AI assistant aliases=["Buddy"]
* "I'm vv" → User aliases=["vv"] (no AI-naming expression, name belongs to user)
* [speaker=assistant] "Sure thing, VV" → User aliases=["VV"] (AI addressing the user, "VV" appears in text)
* [speaker=assistant] "I'm Jarvis" → AI assistant aliases=["Jarvis"] (AI self-introduction, this is AI's alias)
* ❌ Wrong: putting "Buddy" in user's aliases ("Buddy" is a name for the AI, not the user)
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
- Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names
{% endif %}
@@ -152,60 +122,7 @@ Extract entities and knowledge triplets from the given statement.
4. **AI/ASSISTANT ENTITY SPECIAL HANDLING:**
{% if language == "zh" %}
- **🚨 默认规则:如果对话中没有出现明确指向 AI/助手的命名表达,则所有名字都归属于用户实体。不要猜测或推断某个名字是给 AI 取的。**
- 只有当用户**明确**对 AI/助手进行命名时,才创建 AI/助手实体并将对应名字放入其 aliases
- AI/助手实体的 name 字段:使用 "AI助手"
- 用户给 AI 取的名字:放入 AI/助手实体的 aliases
- **🚨 禁止将用户给 AI 取的名字放入用户实体的 aliases 中**
- **必须出现以下明确的命名表达才能判定为给 AI 取名:**「给你取名」「叫你」「称呼你为」「给AI取名」「你的名字是」「以后叫你」「你就叫」「你不叫X了」「你现在叫」等**第二人称(你)或明确指向 AI 的命名句式**
- **🚨 "你不叫X了"/"你不叫X你叫Y" 句式X 和 Y 都是 AI 的名字(旧名和新名),绝对不是用户的名字。因为句子主语是"你"AI。**
- **以下情况名字归属用户,不是给 AI 取名:**「我叫」「我的名字是」「叫我」「我是」「大家叫我」「我的英文名是」「我的昵称是」等**第一人称(我)的自我介绍句式**
- **🚨 speaker=assistant 时的特殊规则:**
* AI 用来称呼用户的名字 → 归入**用户**实体的 aliases但必须是原文中逐字出现的称呼不能推测
* AI 自称的名字(如"我叫陈仔""我是你的助手")→ 归入**AI助手**实体的 aliases
* 判断依据AI 说"你叫X"或用 X 称呼用户 → X 是用户别名AI 说"我叫X"或"我是X" → X 是 AI 别名
- 示例:
* "我叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
* "我的英文名叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
* "我叫陈思远我给AI取名为远仔" → 用户实体: name="用户", aliases=["陈思远"]AI实体: name="AI助手", aliases=["远仔"]
* "叫你小助,我自己叫老王" → 用户实体: name="用户", aliases=["老王"]AI实体: name="AI助手", aliases=["小助"]
* "你不叫远仔了,你现在叫陈仔" → AI实体: name="AI助手", aliases=["陈仔"]"远仔"是AI旧名"陈仔"是AI新名都归AI。不要把"远仔"或"陈仔"放入用户的aliases
* [speaker=assistant] "好的VV今天想干点啥" → 用户实体: name="用户", aliases=["VV"]AI 在称呼用户,原文中出现了"VV"
* [speaker=assistant] "你叫陈思远,我叫陈仔" → 用户实体: name="用户", aliases=["陈思远"]AI实体: name="AI助手", aliases=["陈仔"]
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases没有任何给 AI 取名的表达)
* ❌ 错误AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
* ❌ 错误aliases=["陈思远", "远仔"]"远仔"是给AI取的名字不是用户的名字
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
{% else %}
- **🚨 Default rule: If there is NO explicit AI/assistant naming expression in the conversation, ALL names belong to the user entity. Do NOT guess or infer that a name is for the AI.**
- Only create an AI/assistant entity when the user **explicitly** names the AI/assistant
- AI/assistant entity name field: use "AI Assistant"
- Names the user gives to the AI: put in the AI/assistant entity's aliases
- **🚨 NEVER put names given to the AI into the user entity's aliases**
- **An AI-naming expression MUST be present to assign a name to the AI:** "I'll call you", "your name is", "I name you", "let me call you", "you'll be called", "you're not called X anymore", "your new name is", etc. — **second-person ("you") or explicit AI-directed naming patterns**
- **🚨 "You're not called X anymore" / "You're not X, you're Y" pattern: BOTH X and Y are AI's names (old and new). They are NOT user's names. The subject is "you" (the AI).**
- **These patterns mean the name belongs to the USER, NOT the AI:** "I'm", "my name is", "call me", "I am", "people call me", "my English name is", "my nickname is", etc. — **first-person ("I"/"me") self-introduction patterns**
- **🚨 Special rules when speaker=assistant:**
* Names the AI uses to address the user → belong to the **user** entity's aliases (but only extract names that appear verbatim in the text, never infer)
* Names the AI uses for itself (e.g., "I'm Jarvis", "I am your assistant") → belong to the **AI assistant** entity's aliases
* Rule: AI says "you are X" or calls user X → X is user's alias; AI says "I'm X" or "I am X" → X is AI's alias
- Examples:
* "I'm vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
* "My English name is vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
* "I'm Alex, I'll call you Buddy" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Buddy"]
* "Call yourself Jarvis, my name is Tony" → User entity: name="User", aliases=["Tony"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
* "You're not called Jarvis anymore, your new name is Friday" → AI entity: name="AI Assistant", aliases=["Friday"] (both "Jarvis" and "Friday" are AI names, NOT user names)
* [speaker=assistant] "Sure thing, VV" → User entity: name="User", aliases=["VV"] (AI addressing the user, "VV" appears in text)
* [speaker=assistant] "You're Alex, and I'm Jarvis" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases (no AI-naming expression exists)
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
* ❌ Wrong: aliases=["Alex", "Buddy"] ("Buddy" is a name for the AI, not the user)
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
{% endif %}
5. **ALIASES ORDER:**
4. **ALIASES ORDER:**
{% if language == "zh" %}
- 顺序优先级:按出现顺序,先出现的在前
{% else %}
@@ -285,19 +202,8 @@ Output:
{"entity_idx": 0, "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", "example": "", "aliases": ["Camera Tripod"], "is_explicit_memory": false}
]
}
**Example 4 (User vs AI alias distinction - English output):** "I'm Alex, and I'll call you Buddy"
Output:
{
"triplets": [
{"subject_name": "User", "subject_id": 0, "predicate": "NAMED", "object_name": "AI Assistant", "object_id": 1, "value": "Buddy"}
],
"entities": [
{"entity_idx": 0, "name": "User", "type": "Person", "description": "The user", "example": "", "aliases": ["Alex"], "is_explicit_memory": false},
{"entity_idx": 1, "name": "AI Assistant", "type": "Person", "description": "The user's AI assistant", "example": "", "aliases": ["Buddy"], "is_explicit_memory": false}
]
}
{% else %}
**Example 1 (English input → Chinese output):** "I plan to travel to Paris next week and visit the Louvre."
Output:
{
"triplets": [
@@ -352,39 +258,6 @@ Output:
]
}
**Example 6 (用户与AI别名区分 - Chinese):** "我称呼自己为陈思远我给AI取名为远仔"
Output:
{
"triplets": [
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "远仔"}
],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远"], "is_explicit_memory": false},
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["远仔"], "is_explicit_memory": false}
]
}
**Example 7 (纯用户自我介绍无AI命名 - Chinese):** "我叫vv"
Output:
{
"triplets": [],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["vv"], "is_explicit_memory": false}
]
}
**Example 8 (给AI改名 - Chinese):** "你不叫远仔了,你现在叫陈仔"
Output:
{
"triplets": [
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "陈仔"}
],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "is_explicit_memory": false},
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["陈仔"], "is_explicit_memory": false}
]
}
{% endif %}
===End of Examples===
@@ -406,12 +279,4 @@ Output:
- **⚠️ ALIASES ORDER: preserve temporal order of appearance**
- **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []**
**Output JSON structure:**
```json
{
"triplets": [...],
"entities": [...]
}
```
{{ json_schema }}

View File

@@ -1,140 +0,0 @@
===Task===
Extract user metadata changes from the following conversation statements spoken by the user.
{% if language == "zh" %}
**"三度原则"判断标准:**
- 复用度:该信息是否会被多个功能模块使用?
- 约束度:该信息是否会影响系统行为?
- 时效性:该信息是长期稳定的还是临时的?仅提取长期稳定信息。
**提取规则:**
- **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息
- 仅提取文本中明确提到的信息,不要推测
- **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值)
**增量模式(重要):**
你只需要输出**本次对话引起的变更操作**,不要输出完整的元数据。每个变更是一个对象,包含:
- `field_path`:字段路径,用点号分隔(如 `profile.role`、`profile.expertise`
- `action`:操作类型
* `set`:新增或修改一个字段的值
* `remove`:移除一个字段的值
- `value`:字段的新值(`action="set"` 时必填,`action="remove"` 时填要移除的元素值)
* 所有字段均为列表类型,每个元素一条变更记录
**判断规则:**
- 用户提到新信息 → `action="set"`,填入新值
- 用户明确否定已有信息(如"我不再做老师了"、"我已经不学Python了")→ `action="remove"``value` 填要移除的元素值
- 如果本次对话没有任何可提取的变更,返回空的 `metadata_changes` 数组 `[]`
- **不要为未被提及的字段生成任何变更操作**
{% if existing_metadata %}
**已有元数据(仅供参考,用于判断是否需要变更):**
请对比已有数据和用户最新发言,只输出差异部分的变更操作。
- 如果用户说的信息和已有数据一致,不需要输出变更
- 如果用户否定了已有数据中的某个值,输出 `remove` 操作
- 如果用户提到了新信息,输出 `set` 操作
{% endif %}
**字段说明:**
- profile.role用户的职业或角色列表如 教师、医生、后端工程师,一个人可以有多个角色
- profile.domain用户所在领域列表如 教育、医疗、软件开发,一个人可以涉及多个领域
- profile.expertise用户擅长的技能或工具列表如 Python、心理咨询、高中物理
- profile.interests用户主动表达兴趣的话题或领域标签列表
**用户别名变更(增量模式):**
- **aliases_to_add**:本次新发现的用户别名,包括:
* 用户主动自我介绍:如"我叫张三"、"我的名字是XX"、"我的网名是XX"
* 他人对用户的称呼:如"同事叫我陈哥"、"大家叫我小张"、"领导叫我老陈"
* 只提取原文中逐字出现的名字,严禁推测或创造
* 禁止提取:用户给 AI 取的名字、第三方人物自身的名字、"用户"/"我" 等占位词
* 如果没有新别名,返回空数组 `[]`
- **aliases_to_remove**:用户明确否认的别名,包括:
* 用户说"我不叫XX了"、"别叫我XX"、"我改名了不叫XX" → 将 XX 放入此数组
* **严格限制**:只将用户原文中**逐字提到**的被否认名字放入,不要推断关联的其他别名
* 如果没有要移除的别名,返回空数组 `[]`
{% if existing_aliases %}
- 已有别名:{{ existing_aliases | tojson }}(仅供参考,不需要在输出中重复)
{% endif %}
{% else %}
**"Three-Degree Principle" criteria:**
- Reusability: Will this information be used by multiple functional modules?
- Constraint: Will this information affect system behavior?
- Timeliness: Is this information long-term stable or temporary? Only extract long-term stable information.
**Extraction rules:**
- **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user
- Only extract information explicitly mentioned in the text, do not speculate
- **Output language must match the input text language**
**Incremental mode (important):**
You should only output **the change operations caused by this conversation**, not the complete metadata. Each change is an object containing:
- `field_path`: Field path separated by dots (e.g. `profile.role`, `profile.expertise`)
- `action`: Operation type
* `set`: Add or update a field value
* `remove`: Remove a field value
- `value`: The new value for the field (required when `action="set"`, for `action="remove"` fill in the element value to remove)
* All fields are list types, one change record per element
**Decision rules:**
- User mentions new information → `action="set"`, fill in the new value
- User explicitly negates existing info (e.g. "I'm no longer a teacher", "I stopped learning Python") → `action="remove"`, `value` is the element to remove
- If this conversation has no extractable changes, return an empty `metadata_changes` array `[]`
- **Do NOT generate any change operations for fields not mentioned in the conversation**
{% if existing_metadata %}
**Existing metadata (for reference only, to determine if changes are needed):**
Compare existing data with the user's latest statements, and only output change operations for the differences.
- If the user's statement matches existing data, no change is needed
- If the user negates a value in existing data, output a `remove` operation
- If the user mentions new information, output a `set` operation
{% endif %}
**Field descriptions:**
- profile.role: User's occupation or role (list), e.g. teacher, doctor, software engineer. A person can have multiple roles
- profile.domain: User's domain (list), e.g. education, healthcare, software development. A person can span multiple domains
- profile.expertise: User's skills or tools (list), e.g. Python, counseling, physics
- profile.interests: Topics or domain tags the user actively expressed interest in (list)
**User alias changes (incremental mode):**
- **aliases_to_add**: Newly discovered user aliases from this conversation, including:
* User self-introductions: e.g. "I'm John", "My name is XX", "My username is XX"
* How others address the user: e.g. "My colleagues call me Johnny", "People call me Mike"
* Only extract names that appear VERBATIM in the text — never infer or fabricate
* Do NOT extract: names the user gives to the AI, third-party people's own names, placeholder words like "User"/"I"
* If no new aliases, return empty array `[]`
- **aliases_to_remove**: Aliases the user explicitly denies, including:
* User says "Don't call me XX anymore", "I'm not called XX", "I changed my name from XX" → put XX in this array
* **Strict rule**: Only include the exact name the user **verbatim mentions** as denied. Do NOT infer or remove related aliases
* If no aliases to remove, return empty array `[]`
{% if existing_aliases %}
- Existing aliases: {{ existing_aliases | tojson }} (for reference only, do not repeat in output)
{% endif %}
{% endif %}
===User Statements===
{% for stmt in statements %}
- {{ stmt }}
{% endfor %}
{% if existing_metadata %}
===Existing User Metadata===
```json
{{ existing_metadata | tojson }}
```
{% endif %}
===Output Format===
Return a JSON object with the following structure:
```json
{
"metadata_changes": [
{"field_path": "profile.role", "action": "set", "value": "后端工程师"},
{"field_path": "profile.expertise", "action": "set", "value": "Python"},
{"field_path": "profile.expertise", "action": "remove", "value": "Java"}
],
"aliases_to_add": [],
"aliases_to_remove": []
}
```
{{ json_schema }}

View File

@@ -14,7 +14,6 @@ from pydantic import BaseModel, Field
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.models.models_model import ModelProvider, ModelType
from app.core.models.volcano_chat import VolcanoChatOpenAI
T = TypeVar("T")
@@ -26,9 +25,6 @@ class RedBearModelConfig(BaseModel):
api_key: str
base_url: Optional[str] = None
is_omni: bool = False # 是否为 Omni 模型
deep_thinking: bool = False # 是否启用深度思考模式
thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算
support_thinking: bool = False # 模型是否支持 enable_thinking 参数capability 含 thinking
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用可通过环境变量 LLM_TIMEOUT 配置
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
# 最大重试次数 - 默认2次以避免过长等待可通过环境变量 LLM_MAX_RETRIES 配置
@@ -48,7 +44,7 @@ class RedBearModelFactory:
# 打印供应商信息用于调试
from app.core.logging_config import get_business_logger
logger = get_business_logger()
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}, deep_thinking: {config.deep_thinking}")
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
@@ -62,7 +58,7 @@ class RedBearModelFactory:
write=60.0,
pool=10.0,
)
params: Dict[str, Any] = {
return {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
@@ -70,23 +66,6 @@ class RedBearModelFactory:
"max_retries": config.max_retries,
**config.extra_params
}
# 流式模式下启用 stream_usage 以获取 token 统计
is_streaming = bool(config.extra_params.get("streaming"))
if is_streaming:
params["stream_usage"] = True
# 只有支持 thinking 的模型才传 enable_thinking
if config.support_thinking:
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
if is_streaming:
model_kwargs["enable_thinking"] = config.deep_thinking
if config.deep_thinking:
model_kwargs["incremental_output"] = True
if config.thinking_budget_tokens:
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
else:
model_kwargs["enable_thinking"] = False
params["model_kwargs"] = model_kwargs
return params
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
# 使用 httpx.Timeout 对象来设置详细的超时配置
@@ -99,7 +78,7 @@ class RedBearModelFactory:
write=60.0, # 写入超时60秒
pool=10.0, # 连接池超时10秒
)
params: Dict[str, Any] = {
return {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
@@ -107,50 +86,16 @@ class RedBearModelFactory:
"max_retries": config.max_retries,
**config.extra_params
}
# 流式模式下启用 stream_usage 以获取 token 统计
if config.extra_params.get("streaming"):
params["stream_usage"] = True
# 深度思考模式
is_streaming = bool(config.extra_params.get("streaming"))
if config.support_thinking:
if is_streaming and not config.is_omni:
if provider == ModelProvider.VOLCANO:
# 火山引擎深度思考仅流式调用支持,非流式时不传 thinking 参数
thinking_config: Dict[str, Any] = {
"type": "enabled" if config.deep_thinking else "disabled"
}
if config.deep_thinking and config.thinking_budget_tokens:
thinking_config["budget_tokens"] = config.thinking_budget_tokens
params["extra_body"] = {"thinking": thinking_config}
else:
# 始终显式传递 enable_thinking不支持该参数的模型如 DeepSeek-R1会直接忽略
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
model_kwargs["enable_thinking"] = config.deep_thinking
if config.deep_thinking and config.thinking_budget_tokens:
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
params["model_kwargs"] = model_kwargs
return params
elif provider == ModelProvider.DASHSCOPE:
params = {
# DashScope (通义千问) 使用自己的参数格式
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
# 只支持: model, dashscope_api_key, max_retries, client
return {
"model": config.model_name,
"dashscope_api_key": config.api_key,
"max_retries": config.max_retries,
**config.extra_params
}
# 只有支持 thinking 的模型才传 enable_thinking
if config.support_thinking:
is_streaming = bool(config.extra_params.get("streaming"))
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
if is_streaming:
model_kwargs["enable_thinking"] = config.deep_thinking
if config.deep_thinking:
model_kwargs["incremental_output"] = True
if config.thinking_budget_tokens:
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
else:
model_kwargs["enable_thinking"] = False
params["model_kwargs"] = model_kwargs
return params
elif provider == ModelProvider.BEDROCK:
# Bedrock 使用 AWS 凭证
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
@@ -189,13 +134,6 @@ class RedBearModelFactory:
elif "region_name" not in params:
params["region_name"] = "us-east-1" # 默认区域
# 深度思考模式Claude 3.7 Sonnet 等支持思考的模型
# 通过 additional_model_request_fields 传递 thinking 块关闭时不传Bedrock 无 disabled 选项)
if config.deep_thinking:
budget = config.thinking_budget_tokens or 10000
params["additional_model_request_fields"] = {
"thinking": {"type": "enabled", "budget_tokens": budget}
}
return params
else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
@@ -207,15 +145,10 @@ class RedBearModelFactory:
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
return {
"model": config.model_name,
# "base_url": config.base_url,
"jina_api_key": config.api_key,
**config.extra_params
}
elif provider == ModelProvider.DASHSCOPE:
return {
"model": config.model_name,
"dashscope_api_key": config.api_key,
**config.extra_params
}
else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
@@ -227,9 +160,7 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
return ChatOpenAI
if provider == ModelProvider.VOLCANO:
return VolcanoChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
if type == ModelType.LLM:
return OpenAI
elif type == ModelType.CHAT:
@@ -271,9 +202,6 @@ def get_provider_rerank_class(provider: str):
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
from langchain_community.document_compressors import JinaRerank
return JinaRerank
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank
return DashScopeRerank
# elif provider == ModelProvider.OLLAMA:
# from langchain_ollama import OllamaEmbeddings
# return OllamaEmbeddings

View File

@@ -1,5 +1,5 @@
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
from langchain_core.embeddings import Embeddings
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
@@ -22,38 +22,11 @@ class RedBearEmbeddings(Embeddings):
self._model = self._create_model(config)
self._client = None
@staticmethod
def _create_model(config: RedBearModelConfig) -> Embeddings:
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
"""根据配置创建 LangChain 模型"""
embedding_class = get_provider_embedding_class(config.provider)
provider = config.provider.lower()
# Embedding models only need connection params, never LLM-specific ones
# (e.g. enable_thinking, model_kwargs) — build params directly.
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
import httpx
params = {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
"timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
"max_retries": config.max_retries
}
elif provider == ModelProvider.DASHSCOPE:
params = {
"model": config.model_name,
"dashscope_api_key": config.api_key,
"max_retries": config.max_retries,
}
elif provider == ModelProvider.OLLAMA:
params = {
"model": config.model_name,
"base_url": config.base_url,
}
elif provider == ModelProvider.BEDROCK:
params = RedBearModelFactory.get_model_params(config)
else:
params = RedBearModelFactory.get_model_params(config)
return embedding_class(**params)
model_params = RedBearModelFactory.get_model_params(config)
return embedding_class(**model_params)
def _create_volcano_client(self, config: RedBearModelConfig):
"""创建火山引擎客户端"""

View File

@@ -76,9 +76,5 @@ class RedBearRerank(BaseDocumentCompressor):
from langchain_community.document_compressors import JinaRerank
model_instance: JinaRerank = self._model
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank
model_instance: DashScopeRerank = self._model
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
else:
raise ValueError(f"不支持的模型提供商: {provider}")

View File

@@ -11,7 +11,6 @@ models:
tags:
- 大语言模型
logo: bedrock
- name: amazon nova
type: llm
provider: bedrock
@@ -28,7 +27,6 @@ models:
- stream-tool-call
- vision
logo: bedrock
- name: anthropic claude
type: llm
provider: bedrock
@@ -37,7 +35,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -47,7 +44,6 @@ models:
- stream-tool-call
- document
logo: bedrock
- name: cohere
type: llm
provider: bedrock
@@ -62,7 +58,6 @@ models:
- tool-call
- stream-tool-call
logo: bedrock
- name: deepseek
type: llm
provider: bedrock
@@ -71,7 +66,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -80,7 +74,6 @@ models:
- tool-call
- stream-tool-call
logo: bedrock
- name: meta
type: llm
provider: bedrock
@@ -94,7 +87,6 @@ models:
- agent-thought
- tool-call
logo: bedrock
- name: mistral
type: llm
provider: bedrock
@@ -108,7 +100,6 @@ models:
- agent-thought
- tool-call
logo: bedrock
- name: openai
type: llm
provider: bedrock
@@ -123,7 +114,6 @@ models:
- tool-call
- stream-tool-call
logo: bedrock
- name: qwen
type: llm
provider: bedrock
@@ -138,7 +128,6 @@ models:
- tool-call
- stream-tool-call
logo: bedrock
- name: amazon.rerank-v1:0
type: rerank
provider: bedrock
@@ -150,7 +139,6 @@ models:
tags:
- 重排序模型
logo: bedrock
- name: cohere.rerank-v3-5:0
type: rerank
provider: bedrock
@@ -162,7 +150,6 @@ models:
tags:
- 重排序模型
logo: bedrock
- name: amazon.nova-2-multimodal-embeddings-v1:0
type: embedding
provider: bedrock
@@ -176,7 +163,6 @@ models:
- 文本嵌入模型
- vision
logo: bedrock
- name: amazon.titan-embed-text-v1
type: embedding
provider: bedrock
@@ -188,7 +174,6 @@ models:
tags:
- 文本嵌入模型
logo: bedrock
- name: amazon.titan-embed-text-v2:0
type: embedding
provider: bedrock
@@ -200,7 +185,6 @@ models:
tags:
- 文本嵌入模型
logo: bedrock
- name: cohere.embed-english-v3
type: embedding
provider: bedrock
@@ -212,7 +196,6 @@ models:
tags:
- 文本嵌入模型
logo: bedrock
- name: cohere.embed-multilingual-v3
type: embedding
provider: bedrock

View File

@@ -6,42 +6,36 @@ models:
description: DeepSeek-R1-Distill-Qwen-14B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-r1-distill-qwen-32b
type: llm
provider: dashscope
description: DeepSeek-R1-Distill-Qwen-32B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-r1
type: llm
provider: dashscope
description: DeepSeek-R1大语言模型支持智能体思考131072超大上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3.1
type: llm
provider: dashscope
@@ -54,7 +48,6 @@ models:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3.2-exp
type: llm
provider: dashscope
@@ -67,7 +60,6 @@ models:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3.2
type: llm
provider: dashscope
@@ -80,7 +72,6 @@ models:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3
type: llm
provider: dashscope
@@ -93,7 +84,6 @@ models:
- 大语言模型
- agent-thought
logo: dashscope
- name: farui-plus
type: llm
provider: dashscope
@@ -108,7 +98,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: glm-4.7
type: llm
provider: dashscope
@@ -123,7 +112,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qvq-max-latest
type: llm
provider: dashscope
@@ -132,7 +120,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -140,7 +127,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qvq-max
type: llm
provider: dashscope
@@ -149,7 +135,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -157,7 +142,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-coder-turbo-0919
type: llm
provider: dashscope
@@ -171,15 +155,13 @@ models:
- 代码模型
- agent-thought
logo: dashscope
- name: qwen-max-latest
type: llm
provider: dashscope
description: qwen-max-latest大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -187,7 +169,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-max-longcontext
type: llm
provider: dashscope
@@ -202,15 +183,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-max
type: llm
provider: dashscope
description: qwen-max大语言模型支持多工具调用、智能体思考、流式工具调用32768上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -218,7 +197,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-mt-plus
type: llm
provider: dashscope
@@ -232,7 +210,6 @@ models:
- 翻译模型
- agent-thought
logo: dashscope
- name: qwen-mt-turbo
type: llm
provider: dashscope
@@ -246,7 +223,6 @@ models:
- 翻译模型
- agent-thought
logo: dashscope
- name: qwen-plus-0112
type: llm
provider: dashscope
@@ -261,7 +237,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0125
type: llm
provider: dashscope
@@ -276,7 +251,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0723
type: llm
provider: dashscope
@@ -291,7 +265,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0806
type: llm
provider: dashscope
@@ -306,7 +279,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0919
type: llm
provider: dashscope
@@ -321,7 +293,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-1125
type: llm
provider: dashscope
@@ -336,7 +307,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-1127
type: llm
provider: dashscope
@@ -351,7 +321,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-1220
type: llm
provider: dashscope
@@ -366,7 +335,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-vl-max
type: chat
provider: dashscope
@@ -384,7 +352,6 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-0809
type: chat
provider: dashscope
@@ -402,7 +369,6 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-2025-01-02
type: chat
provider: dashscope
@@ -420,7 +386,6 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-2025-01-25
type: chat
provider: dashscope
@@ -438,7 +403,6 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-latest
type: chat
provider: dashscope
@@ -456,7 +420,6 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus
type: chat
provider: dashscope
@@ -474,7 +437,6 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen2.5-0.5b-instruct
type: llm
provider: dashscope
@@ -489,15 +451,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-14b
type: llm
provider: dashscope
description: qwen3-14b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -505,15 +465,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-235b-a22b-instruct-2507
type: llm
provider: dashscope
description: qwen3-235b-a22b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -521,15 +479,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-235b-a22b-thinking-2507
type: llm
provider: dashscope
description: qwen3-235b-a22b-thinking-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -537,15 +493,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-235b-a22b
type: llm
provider: dashscope
description: qwen3-235b-a22b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -553,15 +507,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-30b-a3b-instruct-2507
type: llm
provider: dashscope
description: qwen3-30b-a3b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -569,15 +521,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-30b-a3b
type: llm
provider: dashscope
description: qwen3-30b-a3b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -585,15 +535,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-32b
type: llm
provider: dashscope
description: qwen3-32b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -601,15 +549,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-4b
type: llm
provider: dashscope
description: qwen3-4b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -617,15 +563,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-8b
type: llm
provider: dashscope
description: qwen3-8b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -633,75 +577,65 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-coder-30b-a3b-instruct
type: llm
provider: dashscope
description: qwen3-coder-30b-a3b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-coder-480b-a35b-instruct
type: llm
provider: dashscope
description: qwen3-coder-480b-a35b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-coder-plus-2025-09-23
type: llm
provider: dashscope
description: qwen3-coder-plus-2025-09-23大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-coder-plus
type: llm
provider: dashscope
description: qwen3-coder-plus大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-max-2025-09-23
type: llm
provider: dashscope
description: qwen3-max-2025-09-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -710,15 +644,13 @@ models:
- stream-tool-call
- 联网搜索
logo: dashscope
- name: qwen3-max-2026-01-23
type: llm
provider: dashscope
description: qwen3-max-2026-01-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -727,15 +659,13 @@ models:
- stream-tool-call
- 联网搜索
logo: dashscope
- name: qwen3-max-preview
type: llm
provider: dashscope
description: qwen3-max-preview大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -743,15 +673,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-max
type: llm
provider: dashscope
description: qwen3-max大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -760,15 +688,13 @@ models:
- stream-tool-call
- 联网搜索
logo: dashscope
- name: qwen3-next-80b-a3b-instruct
type: llm
provider: dashscope
description: qwen3-next-80b-a3b-instruct大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -776,15 +702,13 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-next-80b-a3b-thinking
type: llm
provider: dashscope
description: qwen3-next-80b-a3b-thinking大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -792,7 +716,6 @@ models:
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-omni-flash-2025-12-01
type: llm
provider: dashscope
@@ -812,7 +735,6 @@ models:
- video
- audio
logo: dashscope
- name: qwen3-vl-235b-a22b-instruct
type: chat
provider: dashscope
@@ -822,7 +744,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -833,7 +754,6 @@ models:
- vision
- video
logo: dashscope
- name: qwen3-vl-235b-a22b-thinking
type: chat
provider: dashscope
@@ -843,7 +763,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -854,7 +773,6 @@ models:
- vision
- video
logo: dashscope
- name: qwen3-vl-30b-a3b-instruct
type: chat
provider: dashscope
@@ -864,7 +782,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -875,7 +792,6 @@ models:
- vision
- video
logo: dashscope
- name: qwen3-vl-30b-a3b-thinking
type: chat
provider: dashscope
@@ -885,7 +801,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -896,7 +811,6 @@ models:
- vision
- video
logo: dashscope
- name: qwen3-vl-flash
type: chat
provider: dashscope
@@ -906,7 +820,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -917,7 +830,6 @@ models:
- vision
- video
logo: dashscope
- name: qwen3-vl-plus-2025-09-23
type: chat
provider: dashscope
@@ -927,7 +839,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -936,7 +847,6 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwen3-vl-plus
type: chat
provider: dashscope
@@ -946,7 +856,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -955,52 +864,45 @@ models:
- agent-thought
- video
logo: dashscope
- name: qwq-32b
type: llm
provider: dashscope
description: qwq-32b大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwq-plus-0305
type: llm
provider: dashscope
description: qwq-plus-0305大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwq-plus
type: llm
provider: dashscope
description: qwq-plus大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
- stream-tool-call
logo: dashscope
- name: gte-rerank-v2
type: rerank
provider: dashscope
@@ -1012,7 +914,6 @@ models:
tags:
- 重排序模型
logo: dashscope
- name: gte-rerank
type: rerank
provider: dashscope
@@ -1024,7 +925,6 @@ models:
tags:
- 重排序模型
logo: dashscope
- name: multimodal-embedding-v1
type: embedding
provider: dashscope
@@ -1039,7 +939,6 @@ models:
- 多模态模型
- vision
logo: dashscope
- name: text-embedding-v1
type: embedding
provider: dashscope
@@ -1052,7 +951,6 @@ models:
- 嵌入模型
- 文本嵌入
logo: dashscope
- name: text-embedding-v2
type: embedding
provider: dashscope
@@ -1065,7 +963,6 @@ models:
- 嵌入模型
- 文本嵌入
logo: dashscope
- name: text-embedding-v3
type: embedding
provider: dashscope
@@ -1078,7 +975,6 @@ models:
- 嵌入模型
- 文本嵌入
logo: dashscope
- name: text-embedding-v4
type: embedding
provider: dashscope

View File

@@ -20,7 +20,6 @@ models:
- audio
- video
logo: openai
- name: gpt-3.5-turbo-0125
type: llm
provider: openai
@@ -35,7 +34,6 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-3.5-turbo-1106
type: llm
provider: openai
@@ -50,7 +48,6 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-3.5-turbo-16k
type: llm
provider: openai
@@ -65,7 +62,6 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-3.5-turbo-instruct
type: llm
provider: openai
@@ -77,7 +73,6 @@ models:
tags:
- 大语言模型
logo: openai
- name: gpt-3.5-turbo
type: llm
provider: openai
@@ -92,7 +87,6 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-0125-preview
type: llm
provider: openai
@@ -107,7 +101,6 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-1106-preview
type: llm
provider: openai
@@ -122,7 +115,6 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-turbo-2024-04-09
type: llm
provider: openai
@@ -139,7 +131,6 @@ models:
- stream-tool-call
- vision
logo: openai
- name: gpt-4-turbo-preview
type: llm
provider: openai
@@ -154,7 +145,6 @@ models:
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-turbo
type: llm
provider: openai
@@ -171,7 +161,6 @@ models:
- stream-tool-call
- vision
logo: openai
- name: o1-preview
type: llm
provider: openai
@@ -184,7 +173,6 @@ models:
- 大语言模型
- agent-thought
logo: openai
- name: o1
type: llm
provider: openai
@@ -193,7 +181,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -203,7 +190,6 @@ models:
- vision
- structured-output
logo: openai
- name: o3-2025-04-16
type: llm
provider: openai
@@ -212,7 +198,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -222,15 +207,13 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: o3-mini-2025-01-31
type: llm
provider: openai
description: o3-mini-2025-01-31大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -239,15 +222,13 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: o3-mini
type: llm
provider: openai
description: o3-mini大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- thinking
capability: []
is_omni: false
tags:
- 大语言模型
@@ -256,7 +237,6 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: o3-pro-2025-06-10
type: llm
provider: openai
@@ -265,7 +245,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -274,7 +253,6 @@ models:
- vision
- structured-output
logo: openai
- name: o3-pro
type: llm
provider: openai
@@ -283,7 +261,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -292,7 +269,6 @@ models:
- vision
- structured-output
logo: openai
- name: o3
type: llm
provider: openai
@@ -301,7 +277,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -311,7 +286,6 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: o4-mini-2025-04-16
type: llm
provider: openai
@@ -320,7 +294,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -330,7 +303,6 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: o4-mini
type: llm
provider: openai
@@ -339,7 +311,6 @@ models:
is_official: true
capability:
- vision
- thinking
is_omni: false
tags:
- 大语言模型
@@ -349,7 +320,6 @@ models:
- stream-tool-call
- structured-output
logo: openai
- name: text-embedding-3-large
type: embedding
provider: openai
@@ -361,7 +331,6 @@ models:
tags:
- 文本向量模型
logo: openai
- name: text-embedding-3-small
type: embedding
provider: openai
@@ -373,7 +342,6 @@ models:
tags:
- 文本向量模型
logo: openai
- name: text-embedding-ada-002
type: embedding
provider: openai

View File

@@ -10,7 +10,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -25,7 +24,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -40,7 +38,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -55,7 +52,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -86,7 +82,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -101,7 +96,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -116,7 +110,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -131,7 +124,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型
@@ -147,7 +139,6 @@ models:
capability:
- vision
- video
- thinking
is_omni: false
tags:
- 大语言模型

View File

@@ -1,52 +0,0 @@
"""
火山引擎 ChatOpenAI 扩展
ChatOpenAI 在解析流式 SSE 时只取 delta.content会丢弃 delta.reasoning_content。
此类仅重写 _convert_chunk_to_generation_chunk将 reasoning_content 补入 additional_kwargs。
"""
from __future__ import annotations
from typing import Any, Optional, Union
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
class VolcanoChatOpenAI(ChatOpenAI):
"""火山引擎 Chat 模型支持深度思考内容reasoning_content的流式和非流式透传。"""
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
result = super()._create_chat_result(response, generation_info)
# 将非流式响应中的 reasoning_content 补入 additional_kwargs
choices = response.choices if hasattr(response, "choices") else response.get("choices", [])
if choices:
message = choices[0].message if hasattr(choices[0], "message") else choices[0].get("message", {})
reasoning = (
getattr(message, "reasoning_content", None)
or (message.get("reasoning_content") if isinstance(message, dict) else None)
)
if reasoning and result.generations:
result.generations[0].message.additional_kwargs["reasoning_content"] = reasoning
return result
def _convert_chunk_to_generation_chunk(
self,
chunk: dict,
default_chunk_class: type,
base_generation_info: Optional[dict],
) -> Optional[ChatGenerationChunk]:
gen_chunk = super()._convert_chunk_to_generation_chunk(
chunk, default_chunk_class, base_generation_info
)
if gen_chunk is None:
return None
# 从原始 chunk 中提取 reasoning_content
choices = chunk.get("choices") or chunk.get("chunk", {}).get("choices", [])
if choices:
delta = choices[0].get("delta") or {}
reasoning: Any = delta.get("reasoning_content")
if reasoning:
gen_chunk.message.additional_kwargs["reasoning_content"] = reasoning
return gen_chunk

View File

@@ -672,15 +672,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
excel_parser = ExcelParser()
if parser_config.get("html4excel") and parser_config.get("html4excel").lower() == "true":
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
parser_config["chunk_token_num"] = 0
else:
sections = [(_, "") for _ in excel_parser(binary) if _]
callback(0.8, "Finish parsing.")
# Excel 每行直接作为一个 chunk不经过 naive_merge 避免被 delimiter 拆分
chunks = [s for s, _ in sections]
res.extend(tokenize_chunks(chunks, doc, is_english, None))
res.extend(embed_res)
res.extend(url_res)
return res
parser_config["chunk_token_num"] = 12800
elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")

View File

@@ -232,14 +232,14 @@ class RAGExcelParser:
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
fields.append(t)
line = "\n".join(fields)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
line += "\n——" + sheetname
line += " ——" + sheetname
res.append(line)
else:
# 只有表头的情况
if header_fields:
line = "\n".join(header_fields)
line = "; ".join(header_fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)

View File

@@ -292,7 +292,6 @@ class MinerUParser(RAGPdfParser):
self.page_from = page_from
self.page_to = page_to
try:
with sys.modules[LOCK_KEY_pdfplumber]: # ← 加这一行,获取全局锁
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
self.pdf = pdf
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])]

View File

@@ -50,9 +50,7 @@ class OpenAIEmbed(Base):
def encode(self, texts: list):
# OpenAI requires batch size <=16
batch_size = 16
# Use 8000 instead of 8191 to leave safety margin for tokenizer differences
# between cl100k_base (used by truncate) and the actual embedding model
texts = [truncate(t, 8000) for t in texts]
texts = [truncate(t, 8191) for t in texts]
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
@@ -65,7 +63,7 @@ class OpenAIEmbed(Base):
return np.array(ress), total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
return np.array(res.data[0].embedding), self.total_token_count(res)
@@ -81,7 +79,6 @@ class LocalAIEmbed(Base):
def encode(self, texts: list):
batch_size = 16
texts = [truncate(t, 8000) for t in texts]
ress = []
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
@@ -176,7 +173,6 @@ class XinferenceEmbed(Base):
def encode(self, texts: list):
batch_size = 16
texts = [truncate(t, 8000) for t in texts]
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
@@ -192,7 +188,7 @@ class XinferenceEmbed(Base):
def encode_queries(self, text):
res = None
try:
res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name)
res = self.client.embeddings.create(input=[text], model=self.model_name)
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
log_exception(_e, res)

View File

@@ -28,7 +28,6 @@ from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.services.model_service import ModelApiKeyService
import logging
logger = logging.getLogger(__name__)
@@ -115,8 +114,9 @@ def knowledge_retrieval(
# Use the specified reranker for re-ranking
if reranker_id:
try:
all_results = rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
except Exception as rerank_error:
# If reranker fails, log warning and continue with original results
logger.warning(
"Reranker failed, falling back to original results",
extra={
@@ -132,10 +132,7 @@ def knowledge_retrieval(
from app.core.rag.common.settings import kg_retriever
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
all_results.insert(0, DocumentChunk(
page_content=doc.get("page_content", ""),
metadata=doc.get("metadata", {})
))
all_results.insert(0, doc)
except Exception as graph_error:
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
@@ -201,18 +198,16 @@ def _retrieve_for_knowledge(
workspace_ids.append(str(db_knowledge.workspace_id))
if not chat_model:
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
chat_model = Base(
key=llm_key.api_key,
model_name=llm_key.model_name,
base_url=llm_key.api_base,
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base,
)
if not embedding_model:
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
embedding_model = OpenAIEmbed(
key=emb_key.api_key,
model_name=emb_key.model_name,
base_url=emb_key.api_base,
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base,
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
@@ -253,29 +248,6 @@ def _retrieve_for_knowledge(
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = unique_rs
if unique_rs:
rs = vector_service.rerank(
query=kb_config["query"],
docs=unique_rs,
top_k=kb_config["top_k"]
)
if kb_config["retrieve_type"] == "graph":
try:
from app.core.rag.common.settings import kg_retriever
graph_doc = kg_retriever.retrieval(
question=kb_config["query"],
workspace_ids=[str(db_knowledge.workspace_id)],
kb_ids=[str(db_knowledge.id)],
emb_mdl=embedding_model,
llm=chat_model,
)
if graph_doc:
rs.insert(0, DocumentChunk(
page_content=graph_doc.get("page_content", ""),
metadata=graph_doc.get("metadata", {})
))
except Exception as graph_error:
logger.warning(f"Graph retrieval failed for kb {db_knowledge.id}: {graph_error}")
results.extend(rs)
return results, chat_model, embedding_model

View File

@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["format", "convert_timezone", "timestamp_to_datetime", "now", "datetime_to_timestamp"]
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
),
ToolParameter(
name="input_value",
@@ -230,7 +230,7 @@ class DateTimeTool(BuiltinTool):
@staticmethod
def _datetime_to_timestamp(kwargs) -> dict:
"""日期时间转时间戳"""
input_value = kwargs.get("input_value").strip()
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
timezone_str = kwargs.get("from_timezone", "Asia/Shanghai")
@@ -253,9 +253,9 @@ class DateTimeTool(BuiltinTool):
return {
"datetime": input_value,
"timezone": timezone_str,
"timestamp": int(dt.timestamp()) * 1000,
"timestamp": int(dt.timestamp()),
"iso_format": dt.isoformat(),
"result_data": int(dt.timestamp()) * 1000
"result_data": int(dt.timestamp())
}
def _calculate_datetime(self, kwargs) -> dict:

View File

@@ -1,300 +0,0 @@
"""OpenClaw 远程 Agent 内置工具"""
import time
import base64
from io import BytesIO
from typing import List, Dict, Any, Optional
import aiohttp
from app.core.tools.builtin.base import BuiltinTool
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class OpenClawTool(BuiltinTool):
"""OpenClaw 远程 Agent 工具 — 支持文本和图片多模态输入"""
def __init__(self, tool_id: str, config: Dict[str, Any]):
super().__init__(tool_id, config)
params = self.parameters_config
# 用户配置项(前端表单填写)
self._server_url = params.get("server_url", "")
self._api_key = params.get("api_key", "")
self._agent_id = params.get("agent_id", "main")
# 内部默认值
self._model = "openclaw"
self._session_strategy = "by_user"
self._timeout = 120
# 运行时上下文(通过 set_runtime_context 注入)
self._user_id = "anonymous"
self._conversation_id = None
self._uploaded_files = []
@property
def name(self) -> str:
return "openclaw_tool"
@property
def description(self) -> str:
return (
"OpenClaw 远程 Agent将任务委托给远程 OpenClaw Agent。"
"具备 3D 模型生成与打印控制、设备管理、文件处理、浏览器自动化、"
"Shell 命令执行、网络搜索等能力。支持文本和图片多模态交互。"
)
def get_required_config_parameters(self) -> List[str]:
return ["server_url", "api_key"]
@property
def parameters(self) -> List[ToolParameter]:
return [
ToolParameter(
name="operation",
type=ParameterType.STRING,
description="任务类型",
required=True,
enum= ["print_task", "device_query", "image_understand", "general"]
),
ToolParameter(
name="message",
type=ParameterType.STRING,
description="发送给 OpenClaw Agent 的文本请求内容",
required=True
),
ToolParameter(
name="image_url",
type=ParameterType.STRING,
description="可选,附带的图片 URL 或 base64 data URIOpenClaw 支持图片输入)",
required=False
)
]
# ---------- 运行时上下文注入 ----------
def set_runtime_context(
self,
user_id: str = "anonymous",
conversation_id: Optional[str] = None,
uploaded_files: Optional[list] = None
):
"""注入运行时上下文(由 chat service 调用)"""
self._user_id = user_id
self._conversation_id = conversation_id
self._uploaded_files = uploaded_files or []
# ---------- 连接测试 ----------
async def test_connection(self) -> Dict[str, Any]:
"""测试 OpenClaw Gateway 连接"""
if not self._server_url:
return {"success": False, "message": "未配置 server_url"}
if not self._api_key:
return {"success": False, "message": "未配置 api_key"}
url = f"{self._server_url.rstrip('/')}/v1/responses"
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
"x-openclaw-agent-id": self._agent_id
}
body = {
"model": self._model,
"user": "connection-test",
"input": "hi",
"stream": False
}
try:
timeout_cfg = aiohttp.ClientTimeout(total=30)
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.post(url, json=body, headers=headers) as resp:
if resp.status < 400:
return {"success": True, "message": "OpenClaw 连接成功"}
error_text = await resp.text()
return {
"success": False,
"message": f"OpenClaw HTTP {resp.status}: {error_text[:200]}"
}
except Exception as e:
return {"success": False, "message": f"OpenClaw 连接失败: {str(e)}"}
# ---------- 执行 ----------
async def execute(self, **kwargs) -> ToolResult:
"""执行 OpenClaw 调用"""
start_time = time.time()
try:
message = kwargs.get("message", "")
if not message:
return ToolResult.error_result(
error="message 参数不能为空",
error_code="OPENCLAW_INVALID_INPUT",
execution_time=time.time() - start_time
)
# 提取图片优先从用户上传文件中获取LLM 传的 image_url 作为兜底
image_url = self._extract_image_from_uploads()
if not image_url:
image_url = kwargs.get("image_url")
if image_url and not image_url.startswith("data:"):
image_url = await self._download_and_encode_image(image_url)
# 构建请求
url = f"{self._server_url.rstrip('/')}/v1/responses"
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
"x-openclaw-agent-id": self._agent_id
}
user_field = (
f"conv-{self._conversation_id}"
if self._session_strategy == "by_conversation" and self._conversation_id
else f"user-{self._user_id}"
)
input_field = self._build_input(message, image_url)
body = {
"model": self._model,
"user": user_field,
"input": input_field,
"stream": False
}
timeout_cfg = aiohttp.ClientTimeout(total=self._timeout)
# 打印请求日志(截断 base64 避免日志过大)
log_body = {**body}
if isinstance(log_body.get("input"), list):
log_body["input"] = "[multimodal input, truncated]"
elif isinstance(log_body.get("input"), str) and len(log_body["input"]) > 500:
log_body["input"] = log_body["input"][:500] + "..."
logger.info(
f"OpenClaw 请求: url={url}, agent_id={self._agent_id}, "
f"has_image={bool(image_url)}, body={log_body}"
)
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.post(url, json=body, headers=headers) as resp:
execution_time = time.time() - start_time
if resp.status >= 400:
error_text = await resp.text()
return ToolResult.error_result(
error=f"OpenClaw HTTP {resp.status}: {error_text[:500]}",
error_code="OPENCLAW_HTTP_ERROR",
execution_time=execution_time
)
data = await resp.json()
text = self._extract_response(data)
display_text = self._format_result(text)
return ToolResult.success_result(
data=display_text,
execution_time=execution_time
)
except aiohttp.ClientError as e:
return ToolResult.error_result(
error=f"OpenClaw 网络连接失败: {str(e)}",
error_code="OPENCLAW_NETWORK_ERROR",
execution_time=time.time() - start_time
)
except Exception as e:
return ToolResult.error_result(
error=f"OpenClaw 调用失败: {str(e)}",
error_code="OPENCLAW_EXECUTION_ERROR",
execution_time=time.time() - start_time
)
# ---------- 私有方法 ----------
def _extract_image_from_uploads(self) -> Optional[str]:
"""从用户上传文件中提取图片 URL"""
for f in self._uploaded_files:
f_type = f.get("type", "")
if f_type == "image":
source = f.get("source", {})
if source.get("type") == "base64":
media_type = source.get("media_type", "image/jpeg")
data = source.get("data", "")
return f"data:{media_type};base64,{data}"
elif f.get("image"):
return f.get("image")
elif f.get("url"):
return f.get("url")
elif f_type == "image_url":
return f.get("image_url", {}).get("url", "")
return None
async def _download_and_encode_image(self, image_url: str) -> str:
"""下载图片并转为 base64 data URI"""
try:
from PIL import Image
MAX_RAW_SIZE = 4 * 1024 * 1024
async with aiohttp.ClientSession() as session:
async with session.get(
image_url, allow_redirects=True,
timeout=aiohttp.ClientTimeout(total=30)
) as resp:
if resp.status != 200:
return image_url
content_type = resp.headers.get("Content-Type", "image/jpeg")
if not content_type.startswith("image/"):
return image_url
img_bytes = await resp.read()
if len(img_bytes) > MAX_RAW_SIZE:
img = Image.open(BytesIO(img_bytes))
if img.mode in ("RGBA", "P", "LA"):
img = img.convert("RGB")
if max(img.size) > 2048:
img.thumbnail((2048, 2048), Image.LANCZOS)
buf = BytesIO()
img.save(buf, format="JPEG", quality=75, optimize=True)
img_bytes = buf.getvalue()
content_type = "image/jpeg"
b64 = base64.b64encode(img_bytes).decode("utf-8")
return f"data:{content_type};base64,{b64}"
except Exception as e:
logger.warning(f"OpenClaw 下载图片失败,使用原始 URL: {e}")
return image_url
def _build_input(self, message: str, image_url: Optional[str] = None):
"""构造请求 input 字段:有图片则构造多模态结构,否则纯文本"""
if not image_url:
return message
content_parts = [{"type": "input_text", "text": message}]
if image_url.startswith("data:"):
try:
header, data = image_url.split(",", 1)
media_type = header.split(":")[1].split(";")[0]
content_parts.append({
"type": "input_image",
"source": {"type": "base64", "media_type": media_type, "data": data}
})
except (ValueError, IndexError):
return message
else:
content_parts.append({
"type": "input_image",
"source": {"type": "url", "url": image_url}
})
return [{"type": "message", "role": "user", "content": content_parts}]
def _extract_response(self, response_data: Dict[str, Any]) -> str:
"""从 OpenClaw 响应中提取文本内容
OpenClaw /v1/responses 只返回 output_text 类型的内容。
图片信息(如有)由 OpenClaw Skill 以 Markdown 链接形式嵌入文本中返回。
"""
output = response_data.get("output", [])
texts = []
for item in output:
if item.get("type") == "message":
for content in item.get("content", []):
if content.get("type") == "output_text" and content.get("text"):
texts.append(content["text"])
return "\n".join(texts) if texts else str(response_data)
@staticmethod
def _format_result(text: str) -> str:
"""格式化结果为 LLM 可读字符串"""
return text or "OpenClaw 返回了空内容)"

View File

@@ -12,11 +12,6 @@ class OperationTool(BaseTool):
self.operation = operation
super().__init__(base_tool.tool_id, base_tool.config)
def set_runtime_context(self, **kwargs):
"""转发运行时上下文到 base_tool"""
if hasattr(self.base_tool, 'set_runtime_context'):
self.base_tool.set_runtime_context(**kwargs)
@property
def name(self) -> str:
return f"{self.base_tool.name}_{self.operation}"
@@ -37,8 +32,6 @@ class OperationTool(BaseTool):
return self._get_datetime_params()
elif self.base_tool.name == 'json_tool':
return self._get_json_params()
elif self.base_tool.name == 'openclaw_tool':
return self._get_openclaw_params()
else:
# 默认返回除operation外的所有参数
return [p for p in self.base_tool.parameters if p.name != "operation"]
@@ -145,29 +138,6 @@ class OperationTool(BaseTool):
default="Asia/Shanghai"
)
]
elif self.operation == "datetime_to_timestamp":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值时间字符串2026-04-07 10:30:25",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="from_timezone",
type=ParameterType.STRING,
description="源时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
else:
return []
@@ -239,64 +209,6 @@ class OperationTool(BaseTool):
else:
return base_params
def _get_openclaw_params(self) -> List[ToolParameter]:
"""获取 openclaw_tool 特定操作的参数"""
if self.operation == "print_task":
return [
ToolParameter(
name="message",
type=ParameterType.STRING,
description="发送给 OpenClaw 的打印任务描述,将用户的原始消息原封不动地传递给 OpenClaw禁止改写、补充或润色用户的原文",
required=True
),
ToolParameter(
name="image_url",
type=ParameterType.STRING,
description="可选附带的设计图片或参考图OpenClaw 可据此生成 3D 模型",
required=False
)
]
elif self.operation == "device_query":
return [
ToolParameter(
name="message",
type=ParameterType.STRING,
description="发送给 OpenClaw 的设备查询指令",
required=True
)
]
elif self.operation == "image_understand":
return [
ToolParameter(
name="message",
type=ParameterType.STRING,
description="发送给 OpenClaw 的图片理解任务,应描述需要对图片做什么(如描述内容、提取文字、分析信息)",
required=True
),
ToolParameter(
name="image_url",
type=ParameterType.STRING,
description="要分析的图片 URL 或 base64 data URI",
required=False
)
]
else:
# general 及其他
return [
ToolParameter(
name="message",
type=ParameterType.STRING,
description="发送给 OpenClaw Agent 的任务描述,应包含完整的任务需求",
required=True
),
ToolParameter(
name="image_url",
type=ParameterType.STRING,
description="可选,附带的图片 URL 或 base64 data URI",
required=False
)
]
async def execute(self, **kwargs) -> ToolResult:
"""执行特定操作"""
# 添加operation参数

View File

@@ -1,15 +0,0 @@
{
"name": "openclaw_tool",
"description": "调用OpenClaw Agent远程服务",
"tool_class": "OpenClawTool",
"category": "agent",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"server_url": "",
"api_key": "",
"agent_id": "main"
},
"tags": ["agent", "openclaw", "multimodal", "3d-printing", "builtin"]
}

View File

@@ -30,18 +30,5 @@
"parameters": {
"api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true}
}
},
"openclaw": {
"name": "OpenClaw远程Agent",
"description": "OpenClaw Agent远程服务",
"tool_class": "OpenClawTool",
"category": "agent",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"server_url": {"type": "string", "description": "OpenClaw Gateway 地址", "required": true},
"api_key": {"type": "string", "description": "OpenClaw API Key", "sensitive": true, "required": true}
}
}
}

View File

@@ -131,7 +131,7 @@ class LangchainAdapter:
def _tool_supports_operations(tool: BaseTool) -> bool:
"""检查工具是否支持多操作"""
# 内置工具中支持操作的工具
builtin_operation_tools = ['datetime_tool', 'json_tool', 'openclaw_tool']
builtin_operation_tools = ['datetime_tool', 'json_tool']
# 检查内置工具
if tool.tool_type.value == "builtin" and tool.name in builtin_operation_tools:

View File

@@ -99,7 +99,7 @@ class SimpleMCPClient:
# 建立 SSE 连接
response = await self._session.get(self.server_url)
if not (200 <= response.status < 300):
if response.status not in (200, 202):
error_text = await response.text()
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
@@ -190,7 +190,9 @@ class SimpleMCPClient:
try:
async with self._session.post(self._endpoint_url, json=request) as response:
if not (200 <= response.status < 300):
# MCP SSE 协议POST 请求返回 200 或 202 均为正常
# 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回
if response.status not in (200, 202):
error_text = await response.text()
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
@@ -205,7 +207,7 @@ class SimpleMCPClient:
raise MCPConnectionError("endpoint URL 未初始化")
async with self._session.post(self._endpoint_url, json=notification) as response:
if not (200 <= response.status < 300):
if response.status not in (200, 202):
logger.warning(f"通知发送失败: {response.status}")
async def _initialize_modelscope_session(self):
@@ -223,7 +225,7 @@ class SimpleMCPClient:
try:
async with self._session.post(self.server_url, json=init_request) as response:
if not (200 <= response.status < 300):
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")

View File

@@ -40,7 +40,6 @@ class WorkflowParserResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list)
features: dict[str, Any] = Field(default_factory=dict)
warnings: list[ExceptionDefinition] = Field(default_factory=list)
errors: list[ExceptionDefinition] = Field(default_factory=list)
@@ -52,7 +51,6 @@ class WorkflowImportResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list)
features: dict[str, Any] = Field(default_factory=dict)
warnings: list[ExceptionDefinition] = Field(default_factory=list)
errors: list[ExceptionDefinition] = Field(default_factory=list)

View File

@@ -15,7 +15,7 @@ from app.core.workflow.adapters.errors import (
ExceptionType
)
from app.core.workflow.nodes.assigner.config import AssignmentItem
from app.core.workflow.nodes.base_config import VariableDefinition as NodeVariableDefinition, BaseNodeConfig
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
from app.core.workflow.nodes.configs import (
StartNodeConfig,
@@ -32,17 +32,13 @@ from app.core.workflow.nodes.configs import (
NoteNodeConfig,
ParameterExtractorNodeConfig,
QuestionClassifierNodeConfig,
VariableAggregatorNodeConfig,
ListOperatorNodeConfig,
DocExtractorNodeConfig,
VariableAggregatorNodeConfig
)
from app.schemas.workflow_schema import VariableDefinition as SchemaVariableDefinition
from app.core.workflow.nodes.cycle_graph.config import (
ConditionDetail as LoopConditionDetail,
ConditionsConfig,
CycleVariable
)
from app.core.workflow.nodes.list_operator.config import FilterCondition
from app.core.workflow.nodes.enums import (
ValueInputType,
ComparisonOperator,
@@ -94,12 +90,9 @@ class DifyConverter(BaseConverter):
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
NodeType.TOOL: self.convert_tool_node_config,
NodeType.NOTES: self.convert_notes_config,
NodeType.LIST_OPERATOR: self.convert_list_operator_node_config,
NodeType.DOCUMENT_EXTRACTOR: self.convert_document_extractor_node_config,
NodeType.CYCLE_START: lambda x: {},
NodeType.BREAK: lambda x: {},
}
self._file_vars_to_conv: list[SchemaVariableDefinition] = []
def get_node_convert(self, node_type):
func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {})
@@ -133,7 +126,7 @@ class DifyConverter(BaseConverter):
selector = var_selector.split('.')
if len(selector) not in [2, 3] and var_selector != "context":
raise Exception(f"invalid variable selector: {var_selector}")
if len(selector) == 3 and selector[0] in ("conversation", "sys"):
if len(selector) == 3:
selector = selector[1:]
if selector[0] == "conversation":
selector[0] = "conv"
@@ -220,9 +213,7 @@ class DifyConverter(BaseConverter):
"end with": ComparisonOperator.END_WITH,
"not contains": ComparisonOperator.NOT_CONTAINS,
"exists": ComparisonOperator.NOT_EMPTY,
"not exists": ComparisonOperator.EMPTY,
"in": ComparisonOperator.IN,
"not in": ComparisonOperator.NOT_IN,
"not exists": ComparisonOperator.EMPTY
}
return operator_map.get(operator, operator)
@@ -288,25 +279,19 @@ class DifyConverter(BaseConverter):
)
continue
if var_type in [VariableType.FILE, VariableType.ARRAY_FILE]:
# 开始节点不支持文件变量,转为会话变量
self._file_vars_to_conv.append(SchemaVariableDefinition(
name=var["variable"],
type=var_type.value,
required=var.get("required", False),
default=None,
description=var.get("label", ""),
))
self.warnings.append(ExceptionDefinition(
if var_type in ["file", "array[file]"]:
self.errors.append(
ExceptionDefinition(
type=ExceptionType.VARIABLE,
node_id=node["id"],
node_name=node_data["title"],
name=var["variable"],
detail=f"File variable '{var['variable']}' is not supported in start node, moved to conversation variables"
))
detail=f"Unsupported Variable type for start node: {var_type}"
)
)
continue
var_def = NodeVariableDefinition(
var_def = VariableDefinition(
name=var["variable"],
type=var_type,
required=var["required"],
@@ -491,11 +476,11 @@ class DifyConverter(BaseConverter):
node_data = node["data"]
result = IterationNodeConfig.model_construct(
input=self._process_list_variable_literal(node_data["iterator_selector"]),
parallel=node_data.get("is_parallel", False),
parallel_count=node_data.get("parallel_nums", 4),
parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"],
output=self._process_list_variable_literal(node_data["output_selector"]),
output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data.get("flatten_output", False),
flatten=node_data["flatten_output"],
).model_dump()
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
@@ -504,23 +489,7 @@ class DifyConverter(BaseConverter):
def convert_assigner_node_config(self, node: dict) -> dict:
node_data = node["data"]
assignments = []
# Support both formats:
# 1. New format: node_data["items"] list
# 2. Flat format: assigned_variable_selector + input_variable_selector + write_mode
if "items" in node_data:
raw_items = node_data["items"]
elif "assigned_variable_selector" in node_data and "input_variable_selector" in node_data:
raw_items = [{
"variable_selector": node_data["assigned_variable_selector"],
"value": node_data["input_variable_selector"],
"input_type": ValueInputType.VARIABLE,
"operation": node_data.get("write_mode", "over-write"),
}]
else:
raw_items = []
for assignment in raw_items:
for assignment in node_data["items"]:
if assignment.get("operation") is None or assignment.get("value") is None:
continue
assignments.append(
@@ -802,119 +771,3 @@ class DifyConverter(BaseConverter):
show_author=node_data.get("showAuthor", True)
).model_dump()
return result
def convert_list_operator_node_config(self, node: dict) -> dict:
"""Dify list-operator — convert variable path array to {{ }} selector format."""
node_data = node["data"]
variable_path = node_data.get("variable", [])
input_list = self._process_list_variable_literal(variable_path) or ""
filter_by = node_data.get("filter_by", {"enabled": False, "conditions": []})
# Convert each condition's comparison_operator from Dify format to native
if filter_by.get("conditions"):
converted_conditions = []
for cond in filter_by["conditions"]:
converted_conditions.append({
**cond,
"comparison_operator": self.convert_compare_operator(
cond.get("comparison_operator", "")
)
})
filter_by = {**filter_by, "conditions": converted_conditions}
result = {
"input_list": input_list,
"filter_by": filter_by,
"order_by": node_data.get("order_by", {"enabled": False, "key": "", "value": "asc"}),
"limit": node_data.get("limit", {"enabled": False, "size": -1}),
"extract_by": node_data.get("extract_by", {"enabled": False, "serial": "1"}),
}
self.config_validate(node["id"], node["data"]["title"], ListOperatorNodeConfig, result)
return result
def convert_document_extractor_node_config(self, node: dict) -> dict:
"""Convert Dify document-extractor node to MemoryBear DocExtractorNodeConfig.
Dify document-extractor data fields:
variable_selector: list[str] - file variable path
"""
node_data = node["data"]
file_selector = self._process_list_variable_literal(
node_data.get("variable_selector", [])
) or ""
result = DocExtractorNodeConfig.model_construct(
file_selector=file_selector,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], DocExtractorNodeConfig, result)
return result
@staticmethod
def convert_features(features: dict) -> dict:
"""Convert Dify features to MemoryBear FeaturesConfigForm format."""
if not features:
return {}
result: dict = {}
# opening_statement
opening = features.get("opening_statement", "")
suggested = features.get("suggested_questions", [])
result["opening_statement"] = {
"enabled": bool(opening),
"statement": opening or None,
"suggested_questions": suggested,
}
# citation (对应 Dify retriever_resource)
retriever = features.get("retriever_resource", {})
result["citation"] = {
"enabled": retriever.get("enabled", False) if isinstance(retriever, dict) else False,
}
# file_upload: Dify allowed_file_types 数组 -> 前端扁平字段
file_upload = features.get("file_upload", {})
allowed_types = file_upload.get("allowed_file_types", []) if file_upload else []
allowed_methods = file_upload.get("allowed_file_upload_methods", ["local_file", "remote_url"])
if isinstance(allowed_methods, list):
if len(allowed_methods) >= 2:
transfer_method = "both"
elif allowed_methods:
transfer_method = allowed_methods[0]
else:
transfer_method = "both"
else:
transfer_method = allowed_methods or "both"
file_config = file_upload.get("fileUploadConfig", {})
result["file_upload"] = {
"enabled": file_upload.get("enabled", False) if file_upload else False,
"image_enabled": "image" in allowed_types,
"image_max_size_mb": file_config.get("image_file_size_limit", 10) if file_config else 10,
"image_allowed_extensions": ["png", "jpg", "jpeg"],
"audio_enabled": "audio" in allowed_types,
"audio_max_size_mb": file_config.get("audio_file_size_limit", 50) if file_config else 50,
"audio_allowed_extensions": ["mp3", "wav", "m4a"],
"document_enabled": "document" in allowed_types,
"document_max_size_mb": file_config.get("file_size_limit", 100) if file_config else 100,
"document_allowed_extensions": ["pdf", "docx", "doc", "xlsx", "xls", "txt", "csv", "json", "md"],
"video_enabled": "video" in allowed_types,
"video_max_size_mb": file_config.get("video_file_size_limit", 100) if file_config else 100,
"video_allowed_extensions": ["mp4", "mov"],
"max_file_count": file_upload.get("number_limits", 1) if file_upload else 1,
"allowed_transfer_methods": transfer_method,
}
# text_to_speech
tts = features.get("text_to_speech", {})
result["text_to_speech"] = {
"enabled": tts.get("enabled", False) if isinstance(tts, dict) else False,
"voice": tts.get("voice") if isinstance(tts, dict) else None,
"language": tts.get("language") if isinstance(tts, dict) else None,
"autoplay": False,
}
# suggested_questions_after_answer
sqa = features.get("suggested_questions_after_answer", {})
result["suggested_questions_after_answer"] = {
"enabled": sqa.get("enabled", False) if isinstance(sqa, dict) else False,
}
return result

View File

@@ -45,8 +45,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"question-classifier": NodeType.QUESTION_CLASSIFIER,
"variable-aggregator": NodeType.VAR_AGGREGATOR,
"tool": NodeType.TOOL,
"list-operator": NodeType.LIST_OPERATOR,
"document-extractor": NodeType.DOCUMENT_EXTRACTOR,
"": NodeType.NOTES
}
@@ -119,12 +117,9 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
if variable:
self.conv_variables.append(con_var)
# 开始节点的文件变量合并到会话变量
self.conv_variables.extend(self._file_vars_to_conv)
features = self.convert_features(
self.config.get("workflow", {}).get("features", {})
)
# for variables in config.get("workflow").get("environment_variables"):
# variable = self._convert_variable(variables)
# conv_variables.append(variable)
trigger = self._convert_trigger({})
execution_config = self._convert_execution({})
@@ -138,7 +133,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
edges=self.edges,
nodes=self.nodes,
variables=self.conv_variables,
features=features,
warnings=self.warnings,
errors=self.errors
)

View File

@@ -22,8 +22,6 @@ from app.core.workflow.nodes.configs import (
MemoryReadNodeConfig,
MemoryWriteNodeConfig,
NoteNodeConfig,
ListOperatorNodeConfig,
DocExtractorNodeConfig,
)
from app.core.workflow.nodes.enums import NodeType
@@ -53,8 +51,6 @@ class MemoryBearConverter(BaseConverter):
NodeType.MEMORY_READ: MemoryReadNodeConfig,
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
NodeType.NOTES: NoteNodeConfig,
NodeType.LIST_OPERATOR: ListOperatorNodeConfig,
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNodeConfig,
}
@staticmethod

View File

@@ -31,9 +31,9 @@ logger = logging.getLogger(__name__)
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{]+|{')
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
# Strict variable format: {{ node_id.field_name }}
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)?\s*}}')
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
class GraphBuilder:

View File

@@ -59,9 +59,6 @@ class WorkflowResultBuilder:
conversation_vars = variable_pool.get_all_conversation_vars()
sys_vars = variable_pool.get_all_system_vars()
# 汇总所有 knowledge 节点的 citations
citations = self.aggregate_citations(node_outputs)
return {
"status": "completed" if success else "failed",
"output": final_output,
@@ -74,25 +71,9 @@ class WorkflowResultBuilder:
"conversation_id": execution_context.conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"citations": citations,
"error": result.get("error"),
}
@staticmethod
def aggregate_citations(node_outputs: dict) -> list:
"""从所有 knowledge 节点的输出中汇总 citations去重"""
seen = set()
citations = []
for node_output in node_outputs.values():
if not isinstance(node_output, dict):
continue
for c in node_output.get("citations", []):
key = c.get("document_id")
if key and key not in seen:
seen.add(key)
citations.append(c)
return citations
@staticmethod
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
"""

View File

@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
def merge_activate_state(x, y):
merged = dict(x)
for k, v in y.items():
merged[k] = merged.get(k, False) or v
return merged
return {
k: x.get(k, False) or y.get(k, False)
for k in set(x) | set(y)
}
def merge_looping_state(x, y):

View File

@@ -14,7 +14,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
logger = get_logger(__name__)
SCOPE_PATTERN = re.compile(
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)?\s*}}"
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}"
)

View File

@@ -17,54 +17,6 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
logger = logging.getLogger(__name__)
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
class LazyVariableDict:
def __init__(self, source, literal):
self._source: dict[str, VariableStruct[Any]] = source
self._literal: bool = literal
self._cache = {}
def keys(self):
return self._source.keys()
def _resolve(self, key):
if key in self._cache:
return self._cache[key]
var_struct = self._source.get(key)
if var_struct is None:
return None
raw = var_struct.instance.get_value()
# literal 模式下 dict/list 保留结构,让 Jinja2 能继续访问子字段(如 .type
value = raw if (not self._literal or isinstance(raw, (dict, list))) else var_struct.instance.to_literal()
self._cache[key] = value
return value
def get(self, key, default=None):
value = self._resolve(key)
return default if value is None else value
def __getitem__(self, key):
value = self._resolve(key)
if value is None:
raise KeyError(key)
return value
def __getattr__(self, key):
if key.startswith('_'):
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
return self._resolve(key)
def __contains__(self, key):
return key in self._source
def __iter__(self):
return iter(self._source)
def __len__(self):
return len(self._source)
class VariableSelector:
"""变量选择器
@@ -165,9 +117,10 @@ class VariablePool:
@staticmethod
def transform_selector(selector):
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if len(selector) not in (2, 3):
if len(selector) != 2:
raise ValueError(f"Selector not valid - {selector}")
return selector
@@ -199,16 +152,6 @@ class VariablePool:
return None
return var_instance
@staticmethod
def _extract_field(struct: "VariableStruct", field: str | None) -> Any:
"""If field is given, drill into a dict/object variable's value."""
if field is None:
return struct.instance.get_value()
value = struct.instance.get_value()
if not isinstance(value, dict):
raise KeyError(f"Variable is not an object, cannot access field '{field}'")
return value.get(field)
def get_instance(
self,
selector: str,
@@ -263,14 +206,12 @@ class VariablePool:
Raises:
KeyError: If strict is True and the variable does not exist.
"""
path = self.transform_selector(selector)
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
if len(path) == 3:
return self._extract_field(variable_struct, path[2])
return variable_struct.instance.get_value()
def get_literal(
@@ -297,15 +238,12 @@ class VariablePool:
Raises:
KeyError: If strict is True and the variable does not exist.
"""
path = self.transform_selector(selector)
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
if len(path) == 3:
value = self._extract_field(variable_struct, path[2])
return str(value) if value is not None else ""
return variable_struct.instance.to_literal()
async def set(
@@ -336,7 +274,7 @@ class VariablePool:
namespace: str,
key: str,
value: Any,
var_type: VariableType | None,
var_type: VariableType,
mut: bool
):
if self.has(f"{namespace}.{key}"):
@@ -363,24 +301,7 @@ class VariablePool:
Returns:
变量是否存在
"""
path = self.transform_selector(selector)
struct = self._get_variable_struct(selector)
if struct is None:
return False
if len(path) == 3:
value = struct.instance.get_value()
return isinstance(value, dict) and path[2] in value
return True
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
return LazyVariableDict(self.variables.get(namespace, {}), literal)
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
return {
ns: LazyVariableDict(vars_dict, literal)
for ns, vars_dict in self.variables.items()
if ns not in ("sys", "conv")
}
return self._get_variable_struct(selector) is not None
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量
@@ -518,23 +439,6 @@ class VariablePoolInitializer:
var_value = var_default
else:
var_value = DEFAULT_VALUE(var_type)
# Convert FileInput-format dicts to full FileObject dicts
if var_type == VariableType.FILE:
if not var_value:
continue
var_value = await self._resolve_file_default(var_value)
if not var_value:
continue
elif var_type == VariableType.ARRAY_FILE:
if not var_value:
var_value = []
else:
resolved = []
for item in var_value:
f = await self._resolve_file_default(item)
if f:
resolved.append(f)
var_value = resolved
await variable_pool.new(
namespace="conv",
key=var_name,
@@ -543,17 +447,6 @@ class VariablePoolInitializer:
mut=True
)
@staticmethod
async def _resolve_file_default(file_def: dict) -> dict | None:
"""Accept only already-resolved FileObject dicts (is_file=True).
FileInput-format dicts are converted at save time by WorkflowService._resolve_variables_file_defaults.
"""
if not isinstance(file_def, dict):
return None
if file_def.get("is_file"):
return file_def
return None
@staticmethod
async def _init_system_vars(
variable_pool: VariablePool,
@@ -586,3 +479,5 @@ class VariablePoolInitializer:
var_type=var_type,
mut=False
)

View File

@@ -395,8 +395,7 @@ class BaseNode(ABC):
"output": output,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": None,
**self._extract_extra_fields(business_result),
"error": None
}
final_output = {
"node_outputs": {self.node_id: node_output},
@@ -499,13 +498,6 @@ class BaseNode(ABC):
# Default implementation returns the business result directly
return business_result
def _extract_extra_fields(self, business_result: Any) -> dict:
"""Extracts extra fields to merge into node_output (e.g. citations).
Subclasses may override to inject additional metadata.
"""
return {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""Extracts token usage information from the business result.
@@ -560,9 +552,9 @@ class BaseNode(ABC):
return render_template(
template=template,
conv_vars=variable_pool.lazy_namespace("conv", literal=True),
node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
system_vars=variable_pool.lazy_namespace("sys", literal=True),
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
node_outputs=variable_pool.get_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(literal=True),
strict=strict
)
@@ -587,9 +579,9 @@ class BaseNode(ABC):
return evaluate_condition(
expression=expression,
conv_var=variable_pool.lazy_namespace("conv"),
node_outputs=variable_pool.lazy_all_node_outputs(),
system_vars=variable_pool.lazy_namespace("sys")
conv_var=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars()
)
@staticmethod

View File

@@ -13,7 +13,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import BaseNode
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.variable.base_variable import VariableType
logger = logging.getLogger(__name__)
@@ -70,8 +70,7 @@ class CodeNode(BaseNode):
for output in self.typed_config.output_variables:
value = exec_result.get(output.name)
if value is None:
result[output.name] = DEFAULT_VALUE(output.type)
continue
raise RuntimeError(f"Return value {output.name} does not exist")
match output.type:
case VariableType.STRING:
if not isinstance(value, str):

View File

@@ -24,8 +24,6 @@ from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.notes.config import NoteNodeConfig
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
__all__ = [
# 基础类
@@ -51,7 +49,5 @@ __all__ = [
"MemoryReadNodeConfig",
"MemoryWriteNodeConfig",
"CodeNodeConfig",
"NoteNodeConfig",
"ListOperatorNodeConfig",
"DocExtractorNodeConfig",
"NoteNodeConfig"
]

View File

@@ -28,135 +28,86 @@ class IterationRuntime:
def __init__(
self,
start_id: str,
stream: bool,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
state: WorkflowState,
variable_pool: VariablePool,
cycle_nodes: list,
cycle_edges: list,
child_variable_pool: VariablePool,
):
"""
Initialize the iteration runtime.
Args:
stream: Whether to run in streaming mode. When True, each iteration
uses graph.astream and emits cycle_item events in real time.
When False, graph.ainvoke is used instead.
node_id: The unique identifier of the iteration node in the workflow.
Also used as the variable namespace for item/index inside
the subgraph (e.g. {{ node_id.item }}).
config: Raw configuration dict for the iteration node, parsed into
IterationNodeConfig. Controls input/output variable selectors,
parallel execution settings, and output flattening.
state: The parent workflow state at the point the iteration node is
entered. Each task receives a copy of this state as its
starting point.
variable_pool: The parent VariablePool containing all variables available
at the time the iteration node executes, including sys.*,
conv.*, and outputs from upstream nodes. Used as the source
for deep-copying into each task's independent child pool.
cycle_nodes: List of node config dicts belonging to this iteration's
subgraph (i.e. nodes whose cycle field equals node_id).
Passed to GraphBuilder when constructing each task's subgraph.
cycle_edges: List of edge config dicts connecting nodes within the subgraph.
Passed to GraphBuilder alongside cycle_nodes.
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing iteration node configuration.
state: Current workflow state at the point of iteration.
"""
self.start_id = start_id
self.stream = stream
self.graph = graph
self.state = state
self.node_id = node_id
self.typed_config = IterationNodeConfig(**config)
self.looping = True
self.variable_pool = variable_pool
self.cycle_nodes = cycle_nodes
self.cycle_edges = cycle_edges
self.child_variable_pool = child_variable_pool
self.event_write = get_stream_writer()
self.checkpoint = RunnableConfig(
configurable={
"thread_id": uuid.uuid4()
}
)
self.output_value = None
self.result: list = []
def _build_child_graph(self) -> tuple[CompiledStateGraph, VariablePool, str]:
async def _init_iteration_state(self, item, idx):
"""
Build an independent compiled subgraph for a single iteration task.
Each call creates a brand-new VariablePool by deep-copying the parent pool,
then passes it to GraphBuilder. GraphBuilder binds this pool to every node's
execution closure at build time, so the pool and the subgraph always reference
the same object. This is the key design invariant: item/index written into the
pool after build will be visible to all nodes inside the subgraph.
Returns:
graph: The compiled LangGraph subgraph ready for invocation.
child_pool: The VariablePool bound to this subgraph's node closures.
Callers must write item/index into this pool before invoking
the graph, and read output from it after invocation.
start_node_id: The ID of the CYCLE_START node inside the subgraph,
used to set the initial activation signal in workflow state.
"""
from app.core.workflow.engine.graph_builder import GraphBuilder
child_pool = VariablePool()
child_pool.copy(self.variable_pool)
builder = GraphBuilder(
{"nodes": self.cycle_nodes, "edges": self.cycle_edges},
stream=self.stream,
variable_pool=child_pool,
cycle=self.node_id,
)
graph = builder.build()
return graph, builder.variable_pool, builder.start_node_id
async def _init_iteration_state(self, item, idx, child_pool: VariablePool, start_id: str):
"""
Initialize the workflow state for a single iteration.
Writes the current item and its index into child_pool under the iteration
node's namespace (e.g. iteration_xxx.item, iteration_xxx.index), making them
accessible to downstream nodes inside the subgraph via variable selectors.
Also prepares a copy of the parent workflow state with:
- node_outputs[node_id] set to {item, index} so the state snapshot is consistent
with the pool values.
- looping flag set to 1 (active) to signal the subgraph is inside a cycle.
- activate[start_id] set to True to trigger the CYCLE_START node.
Initialize a per-iteration copy of the workflow state.
Args:
item: The current element from the input array.
idx: The zero-based index of this element in the input array.
child_pool: The VariablePool bound to this iteration's subgraph.
Must be the same object returned by _build_child_graph.
start_id: The ID of the CYCLE_START node inside the subgraph.
item: Current element from the input array for this iteration.
idx: Index of the element in the input array.
Returns:
A WorkflowState instance ready to be passed to graph.ainvoke or graph.astream.
A copy of the workflow state with iteration-specific variables set.
"""
loopstate = WorkflowState(**self.state)
await child_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
await child_pool.new(self.node_id, "index", idx, VariableType.type_map(idx), mut=True)
loopstate["node_outputs"][self.node_id] = {"item": item, "index": idx}
loopstate = WorkflowState(
**self.state
)
self.child_variable_pool.copy(self.variable_pool)
await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True)
loopstate["node_outputs"][self.node_id] = {
"item": item,
"index": idx,
}
loopstate["looping"] = 1
loopstate["activate"][start_id] = True
loopstate["activate"][self.start_id] = True
return loopstate
def _merge_conv_vars(self, child_pool: VariablePool):
self.variable_pool.variables["conv"].update(child_pool.variables["conv"])
def merge_conv_vars(self):
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables["conv"]
)
async def run_task(self, item, idx):
"""
Execute a single iteration asynchronously.
Each task builds its own subgraph so the variable pool closure is independent.
Returns:
Tuple of (idx, output, result, child_pool, stopped)
Args:
item: The input element for this iteration.
idx: The index of this iteration.
"""
graph, child_pool, start_id = self._build_child_graph()
checkpoint = RunnableConfig(configurable={"thread_id": uuid.uuid4()})
init_state = await self._init_iteration_state(item, idx, child_pool, start_id)
if self.stream:
async for event in graph.astream(
init_state,
async for event in self.graph.astream(
await self._init_iteration_state(item, idx),
stream_mode=["debug"],
config=checkpoint
config=self.checkpoint
):
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
@@ -166,6 +117,7 @@ class IterationRuntime:
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if node_name and node_name.startswith("nop"):
continue
if event_type == "task_result":
@@ -188,13 +140,17 @@ class IterationRuntime:
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
}
})
result = graph.get_state(config=checkpoint).values
result = self.graph.get_state(config=self.checkpoint).values
else:
result = await graph.ainvoke(init_state)
output = child_pool.get_value(self.output_value)
stopped = result["looping"] == 2
return idx, output, result, child_pool, stopped
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
output = self.child_variable_pool.get_value(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
if result["looping"] == 2:
self.looping = False
return result
def _create_iteration_tasks(self, array_obj, idx):
"""
@@ -240,32 +196,16 @@ class IterationRuntime:
tasks = self._create_iteration_tasks(array_obj, idx)
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
idx += self.typed_config.parallel_count
batch = await asyncio.gather(*tasks)
# Sort by idx to preserve order, then collect results
batch_sorted = sorted(batch, key=lambda x: x[0])
for _, output, result, child_pool, stopped in batch_sorted:
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
child_state.append(result)
self._merge_conv_vars(child_pool)
if stopped:
self.looping = False
child_state.extend(await asyncio.gather(*tasks))
self.merge_conv_vars()
else:
# Execute iterations sequentially
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
_, output, result, child_pool, stopped = await self.run_task(item, idx)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
self._merge_conv_vars(child_pool)
result = await self.run_task(item, idx)
self.merge_conv_vars()
child_state.append(result)
if stopped:
self.looping = False
idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed")
return {

View File

@@ -11,6 +11,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.utils.expression_evaluator import evaluate_expression
logger = logging.getLogger(__name__)
@@ -84,7 +85,12 @@ class LoopRuntime:
for variable in self.typed_config.cycle_vars:
if variable.input_type == ValueInputType.VARIABLE:
value = self.variable_pool.get_value(variable.value)
value = evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
else:
value = TypeTransformer.transform(variable.value, variable.type)
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
@@ -92,7 +98,12 @@ class LoopRuntime:
**self.state
)
loopstate["node_outputs"][self.node_id] = {
variable.name: self.variable_pool.get_value(variable.value)
variable.name: evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars

View File

@@ -55,9 +55,9 @@ class CycleGraphNode(BaseNode):
if config.output_type in [
VariableType.ARRAY_FILE,
VariableType.ARRAY_STRING,
VariableType.ARRAY_NUMBER,
VariableType.NUMBER,
VariableType.ARRAY_OBJECT,
VariableType.ARRAY_BOOLEAN
VariableType.BOOLEAN
]:
if config.flatten:
outputs['output'] = config.output_type
@@ -123,7 +123,7 @@ class CycleGraphNode(BaseNode):
return cycle_nodes, cycle_edges
def build_graph(self, variable_pool: VariablePool):
def build_graph(self):
"""
Build and compile the internal subgraph for this cycle node.
@@ -135,7 +135,6 @@ class CycleGraphNode(BaseNode):
from app.core.workflow.engine.graph_builder import GraphBuilder
self.child_variable_pool = VariablePool()
self.child_variable_pool.copy(variable_pool)
builder = GraphBuilder(
{
"nodes": self.cycle_nodes,
@@ -166,8 +165,8 @@ class CycleGraphNode(BaseNode):
Raises:
RuntimeError: If the node type is unsupported.
"""
self.build_graph()
if self.node_type == NodeType.LOOP:
self.build_graph(variable_pool)
return await LoopRuntime(
start_id=self.start_node_id,
stream=False,
@@ -180,19 +179,20 @@ class CycleGraphNode(BaseNode):
).run()
if self.node_type == NodeType.ITERATION:
return await IterationRuntime(
start_id=self.start_node_id,
stream=False,
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
cycle_nodes=self.cycle_nodes,
cycle_edges=self.cycle_edges,
child_variable_pool=self.child_variable_pool
).run()
raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
self.build_graph()
if self.node_type == NodeType.LOOP:
self.build_graph(variable_pool)
yield {
"__final__": True,
"result": await LoopRuntime(
@@ -211,13 +211,14 @@ class CycleGraphNode(BaseNode):
yield {
"__final__": True,
"result": await IterationRuntime(
start_id=self.start_node_id,
stream=True,
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
cycle_nodes=self.cycle_nodes,
cycle_edges=self.cycle_edges,
child_variable_pool=self.child_variable_pool
).run()
}
return

View File

@@ -14,22 +14,12 @@ logger = logging.getLogger(__name__)
def _file_object_to_file_input(f: FileObject) -> FileInput:
"""Convert workflow FileObject to multimodal FileInput."""
file_type = f.origin_file_type or ""
# Prefer mime_type for more accurate type detection
if not file_type and f.mime_type:
file_type = f.mime_type
resolved_type = FileType.trans(f.type) if isinstance(f.type, str) else f.type
if resolved_type != FileType.DOCUMENT:
raise ValueError(
f"Document extractor only supports document files, got type '{f.type}' "
f"(name={f.name or f.file_id or f.url})"
)
return FileInput(
type=resolved_type,
type=FileType.DOCUMENT,
transfer_method=TransferMethod(f.transfer_method),
url=f.url or None,
upload_file_id=f.file_id or None,
file_type=file_type,
file_type=f.origin_file_type or "",
)
@@ -91,7 +81,6 @@ class DocExtractorNode(BaseNode):
from app.services.multimodal_service import MultimodalService
svc = MultimodalService(db)
for f in files:
label = f.name or f.url or f.file_id
try:
file_input = _file_object_to_file_input(f)
# Ensure URL is populated for local files
@@ -100,11 +89,11 @@ class DocExtractorNode(BaseNode):
# Reuse cached bytes if already fetched
if f.get_content():
file_input.set_content(f.get_content())
text = await svc.extract_document_text(file_input)
text = await svc._extract_document_text(file_input)
chunks.append(text)
except Exception as e:
logger.error(
f"Node {self.node_id}: failed to extract file '{label}': {e}",
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
exc_info=True,
)
chunks.append("")

View File

@@ -24,7 +24,6 @@ class NodeType(StrEnum):
MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
UNKNOWN = "unknown"
NOTES = "notes"
@@ -46,8 +45,6 @@ class ComparisonOperator(StrEnum):
LE = "le"
GT = "gt"
GE = "ge"
IN = "in"
NOT_IN = "not_in"
class LogicOperator(StrEnum):

View File

@@ -72,9 +72,8 @@ class HttpContentTypeConfig(BaseModel):
@classmethod
def validate_data(cls, v, info):
content_type = info.data.get("content_type")
if content_type == HttpContentType.FROM_DATA and (
not isinstance(v, list) or not all(isinstance(item, HttpFormData) for item in v)):
raise ValueError("When content_type is 'form-data', data must be a list of HttpFormData")
if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData):
raise ValueError("When content_type is 'form-data', data must be of type HttpFormData")
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
raise ValueError("When content_type is JSON, data must be of type str")
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):

View File

@@ -260,22 +260,17 @@ class HttpRequestNode(BaseNode):
))
case HttpContentType.FROM_DATA:
data = {}
files = []
content["files"] = {}
for item in self.typed_config.body.data:
key = self._render_template(item.key, variable_pool)
if item.type == "text":
data[key] = self._render_template(item.value, variable_pool)
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
variable_pool)
elif item.type == "file":
file_instance = variable_pool.get_instance(item.value)
if isinstance(file_instance, ArrayVariable):
for v in file_instance.value:
if isinstance(v, FileVariable):
files.append((key, (uuid.uuid4().hex, await v.get_content())))
elif isinstance(file_instance, FileVariable):
files.append((key, (uuid.uuid4().hex, await file_instance.get_content())))
content["files"][self._render_template(item.key, variable_pool)] = (
uuid.uuid4().hex,
await variable_pool.get_instance(item.value).get_content()
)
content["data"] = data
if files:
content["files"] = files
case HttpContentType.BINARY:
content["files"] = []
file_instence = variable_pool.get_instance(self.typed_config.body.data)

View File

@@ -1,25 +1,19 @@
import asyncio
import logging
import uuid
from typing import Any
from langchain_core.documents import Document
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read
from app.models import knowledge_model, ModelType
from app.repositories import knowledge_repository
from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository, knowledgeshare_repository
from app.schemas.chunk_schema import RetrieveType
from app.services.model_service import ModelConfigService
@@ -30,27 +24,13 @@ class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
"output": VariableType.ARRAY_STRING
}
def _extract_output(self, business_result: Any) -> Any:
"""下游节点只拿 chunks 列表"""
if isinstance(business_result, dict) and "chunks" in business_result:
return business_result["chunks"]
return business_result
@staticmethod
def _extract_citations(business_result: Any) -> list:
if isinstance(business_result, dict):
return business_result.get("citations", [])
return []
def _extract_extra_fields(self, business_result: Any) -> dict:
return {"citations": self._extract_citations(business_result)}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {
"query": self._render_template(self.typed_config.query, variable_pool),
@@ -105,54 +85,46 @@ class KnowledgeRetrievalNode(BaseNode):
unique.append(doc)
return unique
def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
def _get_existing_kb_ids(self, db, kb_ids):
"""
Reorder the list of document blocks and return the top_k results most relevant to the query
Resolve all accessible and valid knowledge base IDs for retrieval.
This includes:
- Private knowledge bases owned by the user
- Shared knowledge bases
- Source knowledge bases mapped via knowledge sharing relationships
Args:
query: query string
docs: List of document chunk to be rearranged
top_k: The number of top-level documents returned
db: Database session.
kb_ids (list[UUID]): Knowledge base IDs from node configuration.
Returns:
Rearranged document chunk list (sorted in descending order of relevance)
Raises:
ValueError: If the input document list is empty or top_k is invalid
list[UUID]: Final list of valid knowledge base IDs.
"""
reranker = self.get_reranker_model()
# parameter validation
if not docs:
raise ValueError("retrieval chunks be empty")
if top_k <= 0:
raise ValueError("top_k must be a positive integer")
try:
# Convert to LangChain Document object
documents = [
Document(
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
metadata=doc.metadata or {} # Deal with possible None metadata
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
existing_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
for doc in docs
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
share_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
]
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
reranked_docs = list(reranker.compress_documents(documents, query))
# Sort in descending order based on relevance score
reranked_docs.sort(
key=lambda x: x.metadata.get("relevance_score", 0),
reverse=True
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
)
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
result = []
for item in reranked_docs[:top_k]:
for doc in docs:
if doc.page_content == item.page_content:
doc.metadata["score"] = item.metadata["relevance_score"]
result.append(doc)
return result
except Exception as e:
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
existing_ids.extend(items)
return existing_ids
def get_reranker_model(self) -> RedBearRerank:
"""
@@ -192,113 +164,49 @@ class KnowledgeRetrievalNode(BaseNode):
)
return reranker
async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
rs = []
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
tasks = []
for child in children:
if not (child and child.chunk_num > 0 and child.status == 1):
continue
child_kb_config = kb_config.model_copy()
child_kb_config.kb_id = child.id
tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
return rs
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
kb_config.kb_id = child.id
self.knowledge_retrieval(db, query, rs, child, kb_config)
return
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs.extend(
await asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
)
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold))
case RetrieveType.SEMANTIC:
rs.extend(
await asyncio.to_thread(
vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
)
case retrieve_type if retrieve_type in (RetrieveType.HYBRID, RetrieveType.Graph):
rs1_task = asyncio.to_thread(
vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
rs2_task = asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight))
case RetrieveType.HYBRID:
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight)
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs:
return []
return
if self.typed_config.reranker_id:
rs.extend(
await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
)
)
self.vector_service.reranker = self.get_reranker_model()
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
else:
rs.extend(sorted(
unique_rs,
key=lambda d: d.metadata.get("score", 0),
reverse=True
)[:kb_config.top_k])
if kb_config.retrieve_type == RetrieveType.Graph:
from app.core.rag.common.settings import kg_retriever
llm_key = self.model_balance(db_knowledge.llm)
emb_key = self.model_balance(db_knowledge.embedding)
chat_model = Base(
key=llm_key.api_key,
model_name=llm_key.model_name,
base_url=llm_key.api_base
)
embedding_model = OpenAIEmbed(
key=emb_key.api_key,
model_name=emb_key.model_name,
base_url=emb_key.api_base
)
doc = await asyncio.to_thread(
kg_retriever.retrieval,
question=query,
workspace_ids=[str(db_knowledge.workspace_id)],
kb_ids=[str(kb_config.kb_id)],
emb_mdl=embedding_model,
llm=chat_model
)
if doc:
rs.insert(0, DocumentChunk(
page_content=doc.get("page_content", ""),
metadata=doc.get("metadata", {})
))
case _:
raise RuntimeError("Unknown retrieval type")
return rs
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
@@ -330,24 +238,17 @@ class KnowledgeRetrievalNode(BaseNode):
knowledge_bases = self.typed_config.knowledge_bases
rs = []
tasks = []
for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
if not rs:
return []
if self.typed_config.reranker_id:
final_rs = await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
)
self.vector_service.reranker = self.get_reranker_model()
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
else:
final_rs = sorted(
rs,
@@ -358,20 +259,4 @@ class KnowledgeRetrievalNode(BaseNode):
logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
)
citations = []
seen_doc_ids = set()
for chunk in final_rs:
meta = chunk.metadata or {}
doc_id = meta.get("document_id") or meta.get("doc_id")
if doc_id and doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
citations.append({
"document_id": str(doc_id),
"file_name": meta.get("file_name", ""),
"knowledge_id": str(meta.get("knowledge_id", kb_config.kb_id)),
"score": meta.get("score", 0.0),
})
return {
"chunks": [chunk.page_content for chunk in final_rs],
"citations": citations,
}
return [chunk.page_content for chunk in final_rs]

View File

@@ -1,3 +0,0 @@
from .node import ListOperatorNode
__all__ = ["ListOperatorNode"]

View File

@@ -1,49 +0,0 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator
class FilterCondition(BaseModel):
key: str = ""
comparison_operator: ComparisonOperator = ComparisonOperator.CONTAINS
value: str | list[str] | bool = ""
class FilterBy(BaseModel):
enabled: bool = False
conditions: list[FilterCondition] = Field(default_factory=list)
class OrderByConfig(BaseModel):
enabled: bool = False
key: str = ""
value: str = "asc" # "asc" | "desc"
class Limit(BaseModel):
enabled: bool = False
size: int = -1
class ExtractConfig(BaseModel):
enabled: bool = False
serial: str = "1" # 1-based index string, e.g. "1" = first
@field_validator("serial", mode="before")
@classmethod
def coerce_serial(cls, v):
return str(v)
class ListOperatorNodeConfig(BaseNodeConfig):
"""
List Operator node config.
Operation order: filter -> extract -> order -> limit
"""
input_list: str = Field(..., description="Variable selector, e.g. {{ sys.files }} or {{ conv.uploaded_files }}")
filter_by: FilterBy = Field(default_factory=FilterBy)
order_by: OrderByConfig = Field(default_factory=OrderByConfig)
limit: Limit = Field(default_factory=Limit)
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)

View File

@@ -1,150 +0,0 @@
import logging
from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig, FilterCondition
from app.core.workflow.variable.base_variable import VariableType
logger = logging.getLogger(__name__)
# File object fields that hold string values
_FILE_STRING_KEYS = {"type", "name", "url", "extension", "mime_type", "transfer_method", "origin_file_type", "file_id"}
_FILE_NUMBER_KEYS = {"size"}
class ListOperatorNode(BaseNode):
def __init__(self, node_config: dict, workflow_config: dict, down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ListOperatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
"result": VariableType.ANY,
"first_record": VariableType.ANY,
"last_record": VariableType.ANY,
}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = ListOperatorNodeConfig(**self.config)
cfg = self.typed_config
# Resolve input variable from path selector
items: list = self.get_variable(cfg.input_list, variable_pool)
if not isinstance(items, list):
raise TypeError(f"Variable '{cfg.input_list}' must be an array, got {type(items)}")
result = list(items)
# 1. Filter
if cfg.filter_by.enabled and cfg.filter_by.conditions:
for condition in cfg.filter_by.conditions:
result = [item for item in result if self._match_condition(item, condition, variable_pool)]
# 2. Extract (take single item by 1-based serial index)
if cfg.extract_by.enabled:
serial_str = self._resolve_value(cfg.extract_by.serial, variable_pool)
idx = int(serial_str) - 1
if idx < 0 or idx >= len(result):
raise ValueError(f"extract_by.serial={cfg.extract_by.serial} out of range (list length={len(result)})")
result = [result[idx]]
# 3. Order
if cfg.order_by.enabled:
reverse = cfg.order_by.value == "desc"
key_fn = self._make_sort_key(cfg.order_by.key)
result = sorted(result, key=key_fn, reverse=reverse)
# 4. Limit (take first N)
if cfg.limit.enabled and cfg.limit.size > 0:
result = result[:cfg.limit.size]
return {
"result": result,
"first_record": result[0] if result else None,
"last_record": result[-1] if result else None,
}
@staticmethod
def _resolve_value(value: str, variable_pool: VariablePool) -> Any:
"""If value is a {{ namespace.key }} variable selector, resolve it from the pool.
Otherwise return the raw string."""
import re
m = re.fullmatch(r"\{\{\s*(\w+\.\w+)\s*}}", value.strip())
if m:
resolved = variable_pool.get_value(value, default=value, strict=False)
return resolved
return value
@staticmethod
def _make_sort_key(key: str):
def key_fn(item):
if isinstance(item, dict):
return item.get(key) or ""
return item
return key_fn
def _match_condition(self, item: Any, cond: FilterCondition, variable_pool: VariablePool) -> bool:
op = cond.comparison_operator
value = cond.value
# Resolve value if it's a variable reference {{ namespace.key }}
if isinstance(value, str):
value = self._resolve_value(value, variable_pool)
# Resolve left value
if isinstance(item, dict):
left = item.get(cond.key)
else:
left = item # primitive array: compare element directly
# Determine if this field should be compared as a string
is_string_field = isinstance(item, dict) and cond.key in _FILE_STRING_KEYS
# Numeric operators
if op == ComparisonOperator.EQ:
if is_string_field:
return str(left) == str(value)
return self._safe_num(left) == self._safe_num(value)
if op == ComparisonOperator.NE:
if is_string_field:
return str(left) != str(value)
return self._safe_num(left) != self._safe_num(value)
if op == ComparisonOperator.LT:
return self._safe_num(left) < self._safe_num(value)
if op == ComparisonOperator.LE:
return self._safe_num(left) <= self._safe_num(value)
if op == ComparisonOperator.GT:
return self._safe_num(left) > self._safe_num(value)
if op == ComparisonOperator.GE:
return self._safe_num(left) >= self._safe_num(value)
# String / sequence operators
left_str = str(left) if left is not None else ""
if op == ComparisonOperator.CONTAINS:
return str(value) in left_str
if op == ComparisonOperator.NOT_CONTAINS:
return str(value) not in left_str
if op == ComparisonOperator.START_WITH:
return left_str.startswith(str(value))
if op == ComparisonOperator.END_WITH:
return left_str.endswith(str(value))
if op == ComparisonOperator.IN:
return left_str in (value if isinstance(value, list) else [str(value)])
if op == ComparisonOperator.NOT_IN:
return left_str not in (value if isinstance(value, list) else [str(value)])
if op == ComparisonOperator.EMPTY:
return not left
if op == ComparisonOperator.NOT_EMPTY:
return bool(left)
raise ValueError(f"Unsupported operator: {op}")
@staticmethod
def _safe_num(v) -> float:
try:
return float(v)
except (TypeError, ValueError):
return 0.0

View File

@@ -213,10 +213,9 @@ class LLMNode(BaseNode):
messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)——包装为标准消息列表以兼容所有 provider
# 使用简单的 prompt 格式(向后兼容)
prompt_template = self.config.get("prompt", "")
rendered = self._render_template(prompt_template, variable_pool)
self.messages = [{"role": "user", "content": rendered}]
self.messages = self._render_template(prompt_template, variable_pool)
return llm
@@ -246,10 +245,7 @@ class LLMNode(BaseNode):
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据
return AIMessage(content=content, response_metadata={
**response.response_metadata,
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
})
return AIMessage(content=content, response_metadata=response.response_metadata)
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
@@ -308,16 +304,15 @@ class LLMNode(BaseNode):
# 调用 LLM流式支持字符串或消息列表
last_meta_data = {}
last_usage_metadata = {}
async for chunk in llm.astream(self.messages):
# 提取内容
if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content)
else:
content = str(chunk)
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
if hasattr(chunk, 'response_metadata'):
if chunk.response_metadata:
last_meta_data = chunk.response_metadata
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
last_usage_metadata = chunk.usage_metadata
# 只有当内容不为空时才处理
if content:
@@ -340,10 +335,7 @@ class LLMNode(BaseNode):
# 构建完整的 AIMessage包含元数据
final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
}
response_metadata=last_meta_data
)
# yield 完成标记

View File

@@ -27,7 +27,6 @@ from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode
from app.core.workflow.nodes.document_extractor import DocExtractorNode
from app.core.workflow.nodes.list_operator import ListOperatorNode
logger = logging.getLogger(__name__)
@@ -52,8 +51,7 @@ WorkflowNode = Union[
MemoryReadNode,
MemoryWriteNode,
CodeNode,
DocExtractorNode,
ListOperatorNode
DocExtractorNode
]
@@ -85,8 +83,7 @@ class NodeFactory:
NodeType.MEMORY_READ: MemoryReadNode,
NodeType.MEMORY_WRITE: MemoryWriteNode,
NodeType.CODE: CodeNode,
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode,
NodeType.LIST_OPERATOR: ListOperatorNode
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
}
@classmethod

View File

@@ -12,7 +12,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
@@ -45,12 +45,6 @@ class ParameterExtractorNode(BaseNode):
"model_id": str(self.typed_config.model_id),
}
def _extract_output(self, business_result: Any) -> Any:
final_output = {}
for param in self.typed_config.params:
final_output[param.name] = business_result.get(param.name) or DEFAULT_VALUE(self.output_types[param.name])
return final_output
def _output_types(self) -> dict[str, VariableType]:
outputs = {}
for param in self.typed_config.params:
@@ -115,7 +109,6 @@ class ParameterExtractorNode(BaseNode):
api_key = api_config.api_key
api_base = api_config.api_base
is_omni = api_config.is_omni
capability = api_config.capability
model_type = config.type
llm = RedBearLLM(
@@ -208,10 +201,7 @@ class ParameterExtractorNode(BaseNode):
])
model_resp = await llm.ainvoke(messages)
self.response_metadata = {
**model_resp.response_metadata,
"token_usage": getattr(model_resp, 'usage_metadata', None) or model_resp.response_metadata.get('token_usage')
}
self.response_metadata = model_resp.response_metadata
model_message = self.process_model_output(model_resp.content)
result = json_repair.repair_json(model_message, return_objects=True)
logger.info(f"node: {self.node_id} get params:{result}")

View File

@@ -62,7 +62,6 @@ class QuestionClassifierNode(BaseNode):
api_key = api_config.api_key
base_url = api_config.api_base
is_omni = api_config.is_omni
capability = api_config.capability
model_type = config.type
return RedBearLLM(
@@ -136,10 +135,7 @@ class QuestionClassifierNode(BaseNode):
response = await llm.ainvoke(messages)
result = self.process_model_output(response.content)
self.response_metadata = {
**response.response_metadata,
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
}
self.response_metadata = response.response_metadata
if result in category_names:
category = result

Some files were not shown because too many files have changed in this diff Show More