Merge branch 'develop' into feature/node_run
164
.github/workflows/release-notify-wechat.yml
vendored
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
name: Release Notify Workflow
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [closed]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
notify:
|
||||||
|
if: >
|
||||||
|
github.event.pull_request.merged == true &&
|
||||||
|
startsWith(github.event.pull_request.base.ref, 'release')
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
# 防止 GitHub HEAD 未同步
|
||||||
|
- run: sleep 3
|
||||||
|
|
||||||
|
# 1️⃣ 获取分支 HEAD
|
||||||
|
- name: Get HEAD
|
||||||
|
id: head
|
||||||
|
run: |
|
||||||
|
HEAD_SHA=$(curl -s \
|
||||||
|
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||||
|
https://api.github.com/repos/${{ github.repository }}/git/ref/heads/${{ github.event.pull_request.base.ref }} \
|
||||||
|
| jq -r '.object.sha')
|
||||||
|
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
# 2️⃣ 判断是否最终PR
|
||||||
|
- name: Check Latest
|
||||||
|
id: check
|
||||||
|
run: |
|
||||||
|
if [ "${{ github.event.pull_request.merge_commit_sha }}" = "${{ steps.head.outputs.head_sha }}" ]; then
|
||||||
|
echo "ok=true" >> $GITHUB_OUTPUT
|
||||||
|
else
|
||||||
|
echo "ok=false" >> $GITHUB_OUTPUT
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 3️⃣ 尝试从 PR body 提取 Sourcery 摘要
|
||||||
|
- name: Extract Sourcery Summary
|
||||||
|
if: steps.check.outputs.ok == 'true'
|
||||||
|
id: sourcery
|
||||||
|
env:
|
||||||
|
PR_BODY: ${{ github.event.pull_request.body }}
|
||||||
|
run: |
|
||||||
|
python3 << 'PYEOF'
|
||||||
|
import os, re
|
||||||
|
|
||||||
|
body = os.environ.get("PR_BODY", "") or ""
|
||||||
|
match = re.search(
|
||||||
|
r"## Summary by Sourcery\s*\n(.*?)(?=\n## |\Z)",
|
||||||
|
body,
|
||||||
|
re.DOTALL
|
||||||
|
)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
summary = match.group(1).strip()
|
||||||
|
found = "true"
|
||||||
|
else:
|
||||||
|
summary = ""
|
||||||
|
found = "false"
|
||||||
|
|
||||||
|
with open("sourcery_summary.txt", "w", encoding="utf-8") as f:
|
||||||
|
f.write(summary)
|
||||||
|
|
||||||
|
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||||
|
gh.write(f"found={found}\n")
|
||||||
|
gh.write("summary<<EOF\n")
|
||||||
|
gh.write(summary + "\n")
|
||||||
|
gh.write("EOF\n")
|
||||||
|
PYEOF
|
||||||
|
|
||||||
|
# 4️⃣ Fallback: 获取 commits + 通义千问总结
|
||||||
|
- name: Get Commits
|
||||||
|
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||||
|
run: |
|
||||||
|
curl -s \
|
||||||
|
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||||
|
${{ github.event.pull_request.commits_url }} \
|
||||||
|
| jq -r '.[].commit.message' | head -n 20 > commits.txt
|
||||||
|
|
||||||
|
- name: AI Summary (Qwen Fallback)
|
||||||
|
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||||
|
id: qwen
|
||||||
|
env:
|
||||||
|
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||||
|
run: |
|
||||||
|
python3 << 'PYEOF'
|
||||||
|
import json, os, urllib.request
|
||||||
|
|
||||||
|
with open("commits.txt", "r") as f:
|
||||||
|
commits = f.read().strip()
|
||||||
|
|
||||||
|
prompt = "请用中文总结以下代码提交,输出3-5条要点,面向测试人员。直接输出编号列表,不要输出标题或前言:\n" + commits
|
||||||
|
payload = {"model": "qwen-plus", "input": {"prompt": prompt}}
|
||||||
|
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||||
|
|
||||||
|
req = urllib.request.Request(
|
||||||
|
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
|
||||||
|
data=data,
|
||||||
|
headers={
|
||||||
|
"Authorization": "Bearer " + os.environ["DASHSCOPE_API_KEY"],
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
resp = urllib.request.urlopen(req)
|
||||||
|
result = json.loads(resp.read().decode())
|
||||||
|
summary = result.get("output", {}).get("text", "AI 摘要生成失败")
|
||||||
|
|
||||||
|
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||||
|
gh.write("summary<<EOF\n")
|
||||||
|
gh.write(summary + "\n")
|
||||||
|
gh.write("EOF\n")
|
||||||
|
PYEOF
|
||||||
|
|
||||||
|
# 5️⃣ 企业微信通知(Markdown)
|
||||||
|
- name: Notify WeChat
|
||||||
|
if: steps.check.outputs.ok == 'true'
|
||||||
|
env:
|
||||||
|
WECHAT_WEBHOOK: ${{ secrets.WECHAT_WEBHOOK }}
|
||||||
|
BRANCH: ${{ github.event.pull_request.base.ref }}
|
||||||
|
AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||||
|
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||||
|
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||||
|
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||||
|
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||||
|
SOURCERY_FOUND: ${{ steps.sourcery.outputs.found }}
|
||||||
|
SOURCERY_SUMMARY: ${{ steps.sourcery.outputs.summary }}
|
||||||
|
QWEN_SUMMARY: ${{ steps.qwen.outputs.summary }}
|
||||||
|
run: |
|
||||||
|
python3 << 'PYEOF'
|
||||||
|
import json, os, urllib.request
|
||||||
|
|
||||||
|
if os.environ.get("SOURCERY_FOUND") == "true":
|
||||||
|
label = "Summary by Sourcery"
|
||||||
|
summary = os.environ.get("SOURCERY_SUMMARY", "")
|
||||||
|
else:
|
||||||
|
label = "AI变更摘要"
|
||||||
|
summary = os.environ.get("QWEN_SUMMARY", "AI 摘要生成失败")
|
||||||
|
|
||||||
|
pr_number = os.environ.get("PR_NUMBER", "")
|
||||||
|
short_sha = os.environ.get("MERGE_SHA", "")[:7]
|
||||||
|
|
||||||
|
content = (
|
||||||
|
"## 🚀 Release 发布通知\n"
|
||||||
|
"> <20> **分支**: " + os.environ["BRANCH"] + "\n"
|
||||||
|
"> 👤 **提交人**: " + os.environ["AUTHOR"] + "\n"
|
||||||
|
"> 📝 **标题**: " + os.environ["PR_TITLE"] + "\n"
|
||||||
|
"> 🔢 **PR编号**: #" + pr_number + "\n"
|
||||||
|
"> 🔖 **Commit**: " + short_sha + "\n\n"
|
||||||
|
"### 🧠 " + label + "\n" +
|
||||||
|
summary + "\n\n"
|
||||||
|
"---\n"
|
||||||
|
"🔗 [查看PR详情](" + os.environ["PR_URL"] + ")"
|
||||||
|
)
|
||||||
|
payload = {"msgtype": "markdown", "markdown": {"content": content}}
|
||||||
|
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||||
|
req = urllib.request.Request(
|
||||||
|
os.environ["WECHAT_WEBHOOK"],
|
||||||
|
data=data,
|
||||||
|
headers={"Content-Type": "application/json"}
|
||||||
|
)
|
||||||
|
resp = urllib.request.urlopen(req)
|
||||||
|
print(resp.read().decode())
|
||||||
|
PYEOF
|
||||||
1
.gitignore
vendored
@@ -27,6 +27,7 @@ time.log
|
|||||||
celerybeat-schedule.db
|
celerybeat-schedule.db
|
||||||
search_results.json
|
search_results.json
|
||||||
redbear-mem-metrics/
|
redbear-mem-metrics/
|
||||||
|
redbear-mem-benchmark/
|
||||||
pitch-deck/
|
pitch-deck/
|
||||||
|
|
||||||
api/migrations/versions
|
api/migrations/versions
|
||||||
|
|||||||
77
api/app/config/default_free_plan.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""
|
||||||
|
社区版默认免费套餐配置
|
||||||
|
当无法从 SaaS 版获取 premium 模块时,使用此配置作为兜底
|
||||||
|
|
||||||
|
可通过环境变量覆盖配额配置,格式:QUOTA_<QUOTA_NAME>
|
||||||
|
例如:QUOTA_END_USER_QUOTA=100
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def _get_quota_from_env():
|
||||||
|
"""从环境变量获取配额配置"""
|
||||||
|
quota_keys = [
|
||||||
|
"workspace_quota",
|
||||||
|
"skill_quota",
|
||||||
|
"app_quota",
|
||||||
|
"knowledge_capacity_quota",
|
||||||
|
"memory_engine_quota",
|
||||||
|
"end_user_quota",
|
||||||
|
"ontology_project_quota",
|
||||||
|
"model_quota",
|
||||||
|
"api_ops_rate_limit",
|
||||||
|
]
|
||||||
|
quotas = {}
|
||||||
|
for key in quota_keys:
|
||||||
|
env_key = f"QUOTA_{key.upper()}"
|
||||||
|
env_value = os.getenv(env_key)
|
||||||
|
if env_value is not None:
|
||||||
|
try:
|
||||||
|
quotas[key] = float(env_value) if '.' in env_value else int(env_value)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return quotas
|
||||||
|
|
||||||
|
|
||||||
|
def _build_default_free_plan():
|
||||||
|
"""构建默认免费套餐配置"""
|
||||||
|
base = {
|
||||||
|
"name": "记忆体验版",
|
||||||
|
"name_en": "Memory Experience",
|
||||||
|
"category": "saas_personal",
|
||||||
|
"tier_level": 0,
|
||||||
|
"version": "1.0",
|
||||||
|
"status": True,
|
||||||
|
"price": 0,
|
||||||
|
"billing_cycle": "permanent_free",
|
||||||
|
"core_value": "感受永久记忆",
|
||||||
|
"core_value_en": "Experience Permanent Memory",
|
||||||
|
"tech_support": "社群交流",
|
||||||
|
"tech_support_en": "Community Support",
|
||||||
|
"sla_compliance": "无",
|
||||||
|
"sla_compliance_en": "None",
|
||||||
|
"page_customization": "无",
|
||||||
|
"page_customization_en": "None",
|
||||||
|
"theme_color": "#64748B",
|
||||||
|
"quotas": {
|
||||||
|
"workspace_quota": 1,
|
||||||
|
"skill_quota": 5,
|
||||||
|
"app_quota": 2,
|
||||||
|
"knowledge_capacity_quota": 0.3,
|
||||||
|
"memory_engine_quota": 1,
|
||||||
|
"end_user_quota": 1,
|
||||||
|
"ontology_project_quota": 3,
|
||||||
|
"model_quota": 1,
|
||||||
|
"api_ops_rate_limit": 50,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
env_quotas = _get_quota_from_env()
|
||||||
|
if env_quotas:
|
||||||
|
base["quotas"].update(env_quotas)
|
||||||
|
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_FREE_PLAN = _build_default_free_plan()
|
||||||
@@ -47,7 +47,8 @@ from . import (
|
|||||||
user_memory_controllers,
|
user_memory_controllers,
|
||||||
workspace_controller,
|
workspace_controller,
|
||||||
ontology_controller,
|
ontology_controller,
|
||||||
skill_controller
|
skill_controller,
|
||||||
|
tenant_subscription_controller,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建管理端 API 路由器
|
# 创建管理端 API 路由器
|
||||||
@@ -98,5 +99,7 @@ manager_router.include_router(file_storage_controller.router)
|
|||||||
manager_router.include_router(ontology_controller.router)
|
manager_router.include_router(ontology_controller.router)
|
||||||
manager_router.include_router(skill_controller.router)
|
manager_router.include_router(skill_controller.router)
|
||||||
manager_router.include_router(i18n_controller.router)
|
manager_router.include_router(i18n_controller.router)
|
||||||
|
manager_router.include_router(tenant_subscription_controller.router)
|
||||||
|
manager_router.include_router(tenant_subscription_controller.public_router)
|
||||||
|
|
||||||
__all__ = ["manager_router"]
|
__all__ = ["manager_router"]
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from app.services.app_statistics_service import AppStatisticsService
|
|||||||
from app.services.workflow_import_service import WorkflowImportService
|
from app.services.workflow_import_service import WorkflowImportService
|
||||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||||
from app.services.app_dsl_service import AppDslService
|
from app.services.app_dsl_service import AppDslService
|
||||||
|
from app.core.quota_stub import check_app_quota
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -35,6 +36,7 @@ logger = get_business_logger()
|
|||||||
|
|
||||||
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
|
@check_app_quota
|
||||||
def create_app(
|
def create_app(
|
||||||
payload: app_schema.AppCreate,
|
payload: app_schema.AppCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -269,6 +271,19 @@ def update_agent_config(
|
|||||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/model/parameters/default", summary="获取 Agent 模型参数默认配置")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def get_agent_model_parameters(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
service = AppService(db)
|
||||||
|
model_parameters = service.get_default_model_parameters(app_id=app_id)
|
||||||
|
return success(data=model_parameters, msg="获取 Agent 模型参数默认配置")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_agent_config(
|
def get_agent_config(
|
||||||
@@ -1250,9 +1265,11 @@ async def export_app(
|
|||||||
async def import_app(
|
async def import_app(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
app_id: Optional[str] = Form(None),
|
||||||
):
|
):
|
||||||
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
||||||
|
传入 app_id 时覆盖该应用的配置(类型必须一致),否则创建新应用。
|
||||||
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
||||||
"""
|
"""
|
||||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||||
@@ -1263,13 +1280,15 @@ async def import_app(
|
|||||||
if not dsl or "app" not in dsl:
|
if not dsl or "app" not in dsl:
|
||||||
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
new_app, warnings = AppDslService(db).import_dsl(
|
target_app_id = uuid.UUID(app_id) if app_id else None
|
||||||
|
result_app, warnings = AppDslService(db).import_dsl(
|
||||||
dsl=dsl,
|
dsl=dsl,
|
||||||
workspace_id=current_user.current_workspace_id,
|
workspace_id=current_user.current_workspace_id,
|
||||||
tenant_id=current_user.tenant_id,
|
tenant_id=current_user.tenant_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
app_id=target_app_id,
|
||||||
)
|
)
|
||||||
return success(
|
return success(
|
||||||
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
|
data={"app": app_schema.App.model_validate(result_app), "warnings": warnings},
|
||||||
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -443,10 +443,10 @@ async def retrieve_chunks(
|
|||||||
match retrieve_data.retrieve_type:
|
match retrieve_data.retrieve_type:
|
||||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||||
return success(data=rs, msg="retrieval successful")
|
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||||
case chunk_schema.RetrieveType.SEMANTIC:
|
case chunk_schema.RetrieveType.SEMANTIC:
|
||||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||||
return success(data=rs, msg="retrieval successful")
|
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||||
case _:
|
case _:
|
||||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from app.models.user_model import User
|
|||||||
from app.schemas import file_schema, document_schema
|
from app.schemas import file_schema, document_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import file_service, document_service
|
from app.services import file_service, document_service
|
||||||
|
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||||
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
@@ -131,6 +132,7 @@ async def create_folder(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/file", response_model=ApiResponse)
|
@router.post("/file", response_model=ApiResponse)
|
||||||
|
@check_knowledge_capacity_quota
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from app.schemas import knowledge_schema
|
|||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import knowledge_service, document_service
|
from app.services import knowledge_service, document_service
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -179,6 +180,7 @@ async def get_knowledges(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/knowledge", response_model=ApiResponse)
|
@router.post("/knowledge", response_model=ApiResponse)
|
||||||
|
@check_knowledge_capacity_quota
|
||||||
async def create_knowledge(
|
async def create_knowledge(
|
||||||
create_data: knowledge_schema.KnowledgeCreate,
|
create_data: knowledge_schema.KnowledgeCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from app.services.memory_storage_service import (
|
|||||||
search_entity,
|
search_entity,
|
||||||
search_statement,
|
search_statement,
|
||||||
)
|
)
|
||||||
|
from app.core.quota_stub import check_memory_engine_quota
|
||||||
from fastapi import APIRouter, Depends, Header
|
from fastapi import APIRouter, Depends, Header
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -76,6 +77,7 @@ async def get_storage_info(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||||
|
@check_memory_engine_quota
|
||||||
def create_config(
|
def create_config(
|
||||||
payload: ConfigParamsCreate,
|
payload: ConfigParamsCreate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from app.core.response_utils import success
|
|||||||
from app.schemas.response_schema import ApiResponse, PageData
|
from app.schemas.response_schema import ApiResponse, PageData
|
||||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.quota_stub import check_model_quota, check_model_activation_quota
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -236,6 +237,7 @@ def delete_model_base(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
|
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
|
||||||
|
@check_model_quota
|
||||||
def add_model_from_plaza(
|
def add_model_from_plaza(
|
||||||
model_base_id: uuid.UUID,
|
model_base_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -273,6 +275,7 @@ def get_model_by_id(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ApiResponse)
|
@router.post("", response_model=ApiResponse)
|
||||||
|
@check_model_quota
|
||||||
async def create_model(
|
async def create_model(
|
||||||
model_data: model_schema.ModelConfigCreate,
|
model_data: model_schema.ModelConfigCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -303,6 +306,7 @@ async def create_model(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/composite", response_model=ApiResponse)
|
@router.post("/composite", response_model=ApiResponse)
|
||||||
|
@check_model_quota
|
||||||
async def create_composite_model(
|
async def create_composite_model(
|
||||||
model_data: model_schema.CompositeModelCreate,
|
model_data: model_schema.CompositeModelCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -329,6 +333,7 @@ async def create_composite_model(
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||||
|
@check_model_activation_quota
|
||||||
async def update_composite_model(
|
async def update_composite_model(
|
||||||
model_id: uuid.UUID,
|
model_id: uuid.UUID,
|
||||||
model_data: model_schema.CompositeModelCreate,
|
model_data: model_schema.CompositeModelCreate,
|
||||||
@@ -370,6 +375,7 @@ def delete_composite_model(
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/{model_id}", response_model=ApiResponse)
|
@router.put("/{model_id}", response_model=ApiResponse)
|
||||||
|
@check_model_activation_quota
|
||||||
def update_model(
|
def update_model(
|
||||||
model_id: uuid.UUID,
|
model_id: uuid.UUID,
|
||||||
model_data: model_schema.ModelConfigUpdate,
|
model_data: model_schema.ModelConfigUpdate,
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, H
|
|||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.quota_stub import check_ontology_project_quota
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
@@ -163,7 +165,7 @@ def _get_ontology_service(
|
|||||||
api_key=api_key_config.api_key,
|
api_key=api_key_config.api_key,
|
||||||
base_url=api_key_config.api_base,
|
base_url=api_key_config.api_base,
|
||||||
is_omni=api_key_config.is_omni,
|
is_omni=api_key_config.is_omni,
|
||||||
support_thinking="thinking" in (api_key_config.capability or []),
|
capability=api_key_config.capability,
|
||||||
max_retries=3,
|
max_retries=3,
|
||||||
timeout=60.0
|
timeout=60.0
|
||||||
)
|
)
|
||||||
@@ -287,6 +289,7 @@ async def extract_ontology(
|
|||||||
# ==================== 本体场景管理接口 ====================
|
# ==================== 本体场景管理接口 ====================
|
||||||
|
|
||||||
@router.post("/scene", response_model=ApiResponse)
|
@router.post("/scene", response_model=ApiResponse)
|
||||||
|
@check_ontology_project_quota
|
||||||
async def create_scene(
|
async def create_scene(
|
||||||
request: SceneCreateRequest,
|
request: SceneCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
|||||||
@@ -124,10 +124,11 @@ async def get_prompt_opt(
|
|||||||
skill=data.skill
|
skill=data.skill
|
||||||
):
|
):
|
||||||
# chunk 是 prompt 的增量内容
|
# chunk 是 prompt 的增量内容
|
||||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
yield f"event:message\ndata: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield f"event:error\ndata: {json.dumps(
|
yield f"event:error\ndata: {json.dumps(
|
||||||
{"error": str(e)}
|
{"error": str(e)},
|
||||||
|
ensure_ascii=False
|
||||||
)}\n\n"
|
)}\n\n"
|
||||||
yield "event:end\ndata: {}\n\n"
|
yield "event:end\ndata: {}\n\n"
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.quota_manager import check_end_user_quota
|
||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success, fail
|
||||||
from app.db import get_db, get_db_read
|
from app.db import get_db, get_db_read
|
||||||
from app.dependencies import get_share_user_id, ShareTokenData
|
from app.dependencies import get_share_user_id, ShareTokenData
|
||||||
@@ -308,6 +309,7 @@ def get_conversation(
|
|||||||
"/chat",
|
"/chat",
|
||||||
summary="发送消息(支持流式和非流式)"
|
summary="发送消息(支持流式和非流式)"
|
||||||
)
|
)
|
||||||
|
@check_end_user_quota
|
||||||
async def chat(
|
async def chat(
|
||||||
payload: conversation_schema.ChatRequest,
|
payload: conversation_schema.ChatRequest,
|
||||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||||
|
|||||||
@@ -4,7 +4,17 @@
|
|||||||
认证方式: API Key
|
认证方式: API Key
|
||||||
"""
|
"""
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller
|
|
||||||
|
from . import (
|
||||||
|
app_api_controller,
|
||||||
|
end_user_api_controller,
|
||||||
|
memory_api_controller,
|
||||||
|
memory_config_api_controller,
|
||||||
|
rag_api_chunk_controller,
|
||||||
|
rag_api_document_controller,
|
||||||
|
rag_api_file_controller,
|
||||||
|
rag_api_knowledge_controller,
|
||||||
|
)
|
||||||
|
|
||||||
# 创建 V1 API 路由器
|
# 创建 V1 API 路由器
|
||||||
service_router = APIRouter()
|
service_router = APIRouter()
|
||||||
@@ -17,5 +27,6 @@ service_router.include_router(rag_api_file_controller.router)
|
|||||||
service_router.include_router(rag_api_chunk_controller.router)
|
service_router.include_router(rag_api_chunk_controller.router)
|
||||||
service_router.include_router(memory_api_controller.router)
|
service_router.include_router(memory_api_controller.router)
|
||||||
service_router.include_router(end_user_api_controller.router)
|
service_router.include_router(end_user_api_controller.router)
|
||||||
|
service_router.include_router(memory_config_api_controller.router)
|
||||||
|
|
||||||
__all__ = ["service_router"]
|
__all__ = ["service_router"]
|
||||||
|
|||||||
@@ -5,23 +5,44 @@ import uuid
|
|||||||
from fastapi import APIRouter, Body, Depends, Request
|
from fastapi import APIRouter, Body, Depends, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.controllers import user_memory_controllers
|
||||||
from app.core.api_key_auth import require_api_key
|
from app.core.api_key_auth import require_api_key
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.quota_stub import check_end_user_quota
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
|
from app.schemas.end_user_info_schema import EndUserInfoUpdate
|
||||||
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
||||||
|
from app.services import api_key_service
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
||||||
|
"""Build a current_user object from API key auth
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key_auth: Validated API key auth info
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User object with current_workspace_id set
|
||||||
|
"""
|
||||||
|
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||||
|
current_user = api_key.creator
|
||||||
|
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create")
|
@router.post("/create")
|
||||||
@require_api_key(scopes=["memory"])
|
@require_api_key(scopes=["memory"])
|
||||||
|
@check_end_user_quota
|
||||||
async def create_end_user(
|
async def create_end_user(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
@@ -37,6 +58,7 @@ async def create_end_user(
|
|||||||
|
|
||||||
Optionally accepts a memory_config_id to connect the end user to a specific
|
Optionally accepts a memory_config_id to connect the end user to a specific
|
||||||
memory configuration. If not provided, falls back to the workspace default config.
|
memory configuration. If not provided, falls back to the workspace default config.
|
||||||
|
Optionally accepts an app_id to bind the end user to a specific app.
|
||||||
"""
|
"""
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
payload = CreateEndUserRequest(**body)
|
payload = CreateEndUserRequest(**body)
|
||||||
@@ -71,14 +93,26 @@ async def create_end_user(
|
|||||||
else:
|
else:
|
||||||
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
||||||
|
|
||||||
|
# Resolve app_id: explicit from payload, otherwise None
|
||||||
|
app_id = None
|
||||||
|
if payload.app_id:
|
||||||
|
try:
|
||||||
|
app_id = uuid.UUID(payload.app_id)
|
||||||
|
except ValueError:
|
||||||
|
raise BusinessException(
|
||||||
|
f"Invalid app_id format: {payload.app_id}",
|
||||||
|
BizCode.INVALID_PARAMETER
|
||||||
|
)
|
||||||
|
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
end_user = end_user_repo.get_or_create_end_user_with_config(
|
end_user = end_user_repo.get_or_create_end_user_with_config(
|
||||||
app_id=api_key_auth.resource_id,
|
app_id=app_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
other_id=payload.other_id,
|
other_id=payload.other_id,
|
||||||
memory_config_id=memory_config_id,
|
memory_config_id=memory_config_id,
|
||||||
|
other_name=payload.other_name,
|
||||||
)
|
)
|
||||||
|
end_user.other_name = payload.other_name
|
||||||
logger.info(f"End user ready: {end_user.id}")
|
logger.info(f"End user ready: {end_user.id}")
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
@@ -90,3 +124,50 @@ async def create_end_user(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/info")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_end_user_info(
|
||||||
|
request: Request,
|
||||||
|
end_user_id: str,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get end user info.
|
||||||
|
|
||||||
|
Retrieves the info record (aliases, meta_data, etc.) for the specified end user.
|
||||||
|
Delegates to the manager-side controller for shared logic.
|
||||||
|
"""
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
return await user_memory_controllers.get_end_user_info(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/info/update")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_end_user_info(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update end user info.
|
||||||
|
|
||||||
|
Updates the info record (other_name, aliases, meta_data) for the specified end user.
|
||||||
|
Delegates to the manager-side controller for shared logic.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = EndUserInfoUpdate(**body)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
return await user_memory_controllers.update_end_user_info(
|
||||||
|
info_update=payload,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,45 +1,75 @@
|
|||||||
"""Memory 服务接口 - 基于 API Key 认证"""
|
"""Memory 服务接口 - 基于 API Key 认证"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.api_key_auth import require_api_key
|
from app.core.api_key_auth import require_api_key
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.quota_stub import check_end_user_quota
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
from app.schemas.memory_api_schema import (
|
from app.schemas.memory_api_schema import (
|
||||||
CreateEndUserRequest,
|
|
||||||
CreateEndUserResponse,
|
|
||||||
ListConfigsResponse,
|
|
||||||
MemoryReadRequest,
|
MemoryReadRequest,
|
||||||
MemoryReadResponse,
|
MemoryReadResponse,
|
||||||
|
MemoryReadSyncResponse,
|
||||||
MemoryWriteRequest,
|
MemoryWriteRequest,
|
||||||
MemoryWriteResponse,
|
MemoryWriteResponse,
|
||||||
|
MemoryWriteSyncResponse,
|
||||||
)
|
)
|
||||||
from app.services.memory_api_service import MemoryAPIService
|
from app.services.memory_api_service import MemoryAPIService
|
||||||
from fastapi import APIRouter, Body, Depends, Request
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_task_result(result: dict) -> dict:
|
||||||
|
"""Make Celery task result JSON-serializable.
|
||||||
|
|
||||||
|
Converts UUID and other non-serializable values to strings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Raw task result dict from task_service
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-safe dict
|
||||||
|
"""
|
||||||
|
import uuid as _uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
def _convert(obj):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: _convert(v) for k, v in obj.items()}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_convert(i) for i in obj]
|
||||||
|
if isinstance(obj, _uuid.UUID):
|
||||||
|
return str(obj)
|
||||||
|
if isinstance(obj, datetime):
|
||||||
|
return obj.isoformat()
|
||||||
|
return obj
|
||||||
|
|
||||||
|
return _convert(result)
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
async def get_memory_info():
|
async def get_memory_info():
|
||||||
"""获取记忆服务信息(占位)"""
|
"""获取记忆服务信息(占位)"""
|
||||||
return success(data={}, msg="Memory API - Coming Soon")
|
return success(data={}, msg="Memory API - Coming Soon")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/write_api_service")
|
@router.post("/write")
|
||||||
@require_api_key(scopes=["memory"])
|
@require_api_key(scopes=["memory"])
|
||||||
async def write_memory_api_service(
|
async def write_memory(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
message: str = Body(..., description="Message content"),
|
message: str = Body(..., description="Message content"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Write memory to storage.
|
Submit a memory write task.
|
||||||
|
|
||||||
Stores memory content for the specified end user using the Memory API Service.
|
Validates the end user, then dispatches the write to a Celery background task
|
||||||
|
with per-user fair locking. Returns a task_id for status polling.
|
||||||
"""
|
"""
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
payload = MemoryWriteRequest(**body)
|
payload = MemoryWriteRequest(**body)
|
||||||
@@ -47,7 +77,7 @@ async def write_memory_api_service(
|
|||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
result = await memory_api_service.write_memory(
|
result = memory_api_service.write_memory(
|
||||||
workspace_id=api_key_auth.workspace_id,
|
workspace_id=api_key_auth.workspace_id,
|
||||||
end_user_id=payload.end_user_id,
|
end_user_id=payload.end_user_id,
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -56,22 +86,44 @@ async def write_memory_api_service(
|
|||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")
|
logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
||||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory written successfully")
|
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/read_api_service")
|
@router.get("/write/status")
|
||||||
@require_api_key(scopes=["memory"])
|
@require_api_key(scopes=["memory"])
|
||||||
async def read_memory_api_service(
|
async def get_write_task_status(
|
||||||
|
request: Request,
|
||||||
|
task_id: str = Query(..., description="Celery task ID"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Check the status of a memory write task.
|
||||||
|
|
||||||
|
Returns the current status and result (if completed) of a previously submitted write task.
|
||||||
|
"""
|
||||||
|
logger.info(f"Write task status check - task_id: {task_id}")
|
||||||
|
|
||||||
|
from app.services.task_service import get_task_memory_write_result
|
||||||
|
result = get_task_memory_write_result(task_id)
|
||||||
|
|
||||||
|
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/read")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_memory(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
message: str = Body(..., description="Query message"),
|
message: str = Body(..., description="Query message"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Read memory from storage.
|
Submit a memory read task.
|
||||||
|
|
||||||
Queries and retrieves memories for the specified end user with context-aware responses.
|
Validates the end user, then dispatches the read to a Celery background task.
|
||||||
|
Returns a task_id for status polling.
|
||||||
"""
|
"""
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
payload = MemoryReadRequest(**body)
|
payload = MemoryReadRequest(**body)
|
||||||
@@ -79,7 +131,7 @@ async def read_memory_api_service(
|
|||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
result = await memory_api_service.read_memory(
|
result = memory_api_service.read_memory(
|
||||||
workspace_id=api_key_auth.workspace_id,
|
workspace_id=api_key_auth.workspace_id,
|
||||||
end_user_id=payload.end_user_id,
|
end_user_id=payload.end_user_id,
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -89,57 +141,94 @@ async def read_memory_api_service(
|
|||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
logger.info(f"Memory read task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
|
||||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read task submitted")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/configs")
|
@router.get("/read/status")
|
||||||
@require_api_key(scopes=["memory"])
|
@require_api_key(scopes=["memory"])
|
||||||
async def list_memory_configs(
|
async def get_read_task_status(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
task_id: str = Query(..., description="Celery task ID"),
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
List all memory configs for the workspace.
|
Check the status of a memory read task.
|
||||||
|
|
||||||
Returns all available memory configurations associated with the authorized workspace.
|
Returns the current status and result (if completed) of a previously submitted read task.
|
||||||
"""
|
"""
|
||||||
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
logger.info(f"Read task status check - task_id: {task_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
from app.services.task_service import get_task_memory_read_result
|
||||||
|
result = get_task_memory_read_result(task_id)
|
||||||
|
|
||||||
result = memory_api_service.list_memory_configs(
|
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||||
workspace_id=api_key_auth.workspace_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
|
||||||
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/end_users")
|
@router.post("/write/sync")
|
||||||
@require_api_key(scopes=["memory"])
|
@require_api_key(scopes=["memory"])
|
||||||
async def create_end_user(
|
@check_end_user_quota
|
||||||
|
async def write_memory_sync(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(..., description="Message content"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create an end user.
|
Write memory synchronously.
|
||||||
|
|
||||||
Creates a new end user for the authorized workspace.
|
Blocks until the write completes and returns the result directly.
|
||||||
If an end user with the same other_id already exists, returns the existing one.
|
For async processing with task polling, use /write instead.
|
||||||
"""
|
"""
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
payload = CreateEndUserRequest(**body)
|
payload = MemoryWriteRequest(**body)
|
||||||
logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}")
|
logger.info(f"Memory write (sync) request - end_user_id: {payload.end_user_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
result = memory_api_service.create_end_user(
|
result = await memory_api_service.write_memory_sync(
|
||||||
workspace_id=api_key_auth.workspace_id,
|
workspace_id=api_key_auth.workspace_id,
|
||||||
other_id=payload.other_id,
|
end_user_id=payload.end_user_id,
|
||||||
|
message=payload.message,
|
||||||
|
config_id=payload.config_id,
|
||||||
|
storage_type=payload.storage_type,
|
||||||
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"End user ready: {result['id']}")
|
logger.info(f"Memory write (sync) successful for end_user: {payload.end_user_id}")
|
||||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
return success(data=MemoryWriteSyncResponse(**result).model_dump(), msg="Memory written successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/read/sync")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_memory_sync(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(..., description="Query message"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Read memory synchronously.
|
||||||
|
|
||||||
|
Blocks until the read completes and returns the answer directly.
|
||||||
|
For async processing with task polling, use /read instead.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = MemoryReadRequest(**body)
|
||||||
|
logger.info(f"Memory read (sync) request - end_user_id: {payload.end_user_id}")
|
||||||
|
|
||||||
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
|
result = await memory_api_service.read_memory_sync(
|
||||||
|
workspace_id=api_key_auth.workspace_id,
|
||||||
|
end_user_id=payload.end_user_id,
|
||||||
|
message=payload.message,
|
||||||
|
search_switch=payload.search_switch,
|
||||||
|
config_id=payload.config_id,
|
||||||
|
storage_type=payload.storage_type,
|
||||||
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Memory read (sync) successful for end_user: {payload.end_user_id}")
|
||||||
|
return success(data=MemoryReadSyncResponse(**result).model_dump(), msg="Memory read successfully")
|
||||||
|
|||||||
491
api/app/controllers/service/memory_config_api_controller.py
Normal file
@@ -0,0 +1,491 @@
|
|||||||
|
"""Memory Config 服务接口 - 基于 API Key 认证"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends, Header, Query, Request
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.controllers import memory_storage_controller
|
||||||
|
from app.controllers import memory_forget_controller
|
||||||
|
from app.controllers import ontology_controller
|
||||||
|
from app.controllers import emotion_config_controller
|
||||||
|
from app.controllers import memory_reflection_controller
|
||||||
|
from app.schemas.memory_storage_schema import ForgettingConfigUpdateRequest
|
||||||
|
from app.controllers.emotion_config_controller import EmotionConfigUpdate
|
||||||
|
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||||
|
from app.core.api_key_auth import require_api_key
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||||
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
|
from app.schemas.memory_api_schema import (
|
||||||
|
ConfigUpdateExtractedRequest,
|
||||||
|
ConfigUpdateRequest,
|
||||||
|
ListConfigsResponse,
|
||||||
|
ConfigCreateRequest,
|
||||||
|
ConfigUpdateForgettingRequest,
|
||||||
|
EmotionConfigUpdateRequest,
|
||||||
|
ReflectionConfigUpdateRequest,
|
||||||
|
)
|
||||||
|
from app.schemas.memory_storage_schema import (
|
||||||
|
ConfigUpdate,
|
||||||
|
ConfigUpdateExtracted,
|
||||||
|
ConfigParamsCreate,
|
||||||
|
)
|
||||||
|
from app.services import api_key_service
|
||||||
|
from app.services.memory_api_service import MemoryAPIService
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/memory_config", tags=["V1 - Memory Config API"])
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
|
||||||
|
"""Build a current_user object from API key auth
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key_auth: Validated API key auth info
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User object with current_workspace_id set
|
||||||
|
"""
|
||||||
|
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||||
|
current_user = api_key.creator
|
||||||
|
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_config_ownership(config_id:str, workspace_id:uuid.UUID, db:Session):
|
||||||
|
"""Verify that the config belongs to the workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_id: The ID of the config to verify
|
||||||
|
workspace_id: The workspace ID tocheck against
|
||||||
|
db: Database session for querying
|
||||||
|
Raises:
|
||||||
|
BusinessException: If the config does not exist or does not belong to the workspace
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
resolved_id = resolve_config_id(config_id, db)
|
||||||
|
except ValueError as e:
|
||||||
|
raise BusinessException(
|
||||||
|
message=f"Invalid config_id: {e}",
|
||||||
|
code=BizCode.INVALID_PARAMETER,
|
||||||
|
)
|
||||||
|
config = MemoryConfigRepository.get_by_id(db, resolved_id)
|
||||||
|
if not config or config.workspace_id != workspace_id:
|
||||||
|
raise BusinessException(
|
||||||
|
message="Config not found or access denied",
|
||||||
|
code=BizCode.MEMORY_CONFIG_NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# @router.get("/configs")
|
||||||
|
# @require_api_key(scopes=["memory"])
|
||||||
|
# async def list_memory_configs(
|
||||||
|
# request: Request,
|
||||||
|
# api_key_auth: ApiKeyAuth = None,
|
||||||
|
# db: Session = Depends(get_db),
|
||||||
|
# ):
|
||||||
|
# """
|
||||||
|
# List all memory configs for the workspace.
|
||||||
|
|
||||||
|
# Returns all available memory configurations associated with the authorized workspace.
|
||||||
|
# """
|
||||||
|
# logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
# memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
|
# result = memory_api_service.list_memory_configs(
|
||||||
|
# workspace_id=api_key_auth.workspace_id,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||||
|
# return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||||
|
|
||||||
|
@router.get("/read_all_config")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_all_config(
|
||||||
|
request:Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List all memory configs with full details (enhanced version).
|
||||||
|
|
||||||
|
Returns complete config fields for the authorized workspace.
|
||||||
|
No config_id ownership check needed — results are filtered by workspace.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 get all configs (full) - workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return memory_storage_controller.read_all_config(
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/scenes/simple")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def get_ontology_scenes(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get available ontology scenes for the workspace.
|
||||||
|
|
||||||
|
Returns a simple list of scene_id and scene_name for dropdown selection.
|
||||||
|
Used before creating a memory config to choose which ontology scene to associate.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 get scenes - workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return await ontology_controller.get_scenes_simple(
|
||||||
|
db=db,
|
||||||
|
current_user=current_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/read_config_extracted")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_config_extracted(
|
||||||
|
request: Request,
|
||||||
|
config_id: str = Query(..., description="config_id"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get extraction engine config details for a specific config.
|
||||||
|
|
||||||
|
Only configs belonging to the authorized workspace can be queried.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 read extracted config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return memory_storage_controller.read_config_extracted(
|
||||||
|
config_id = config_id,
|
||||||
|
current_user = current_user,
|
||||||
|
db = db,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/read_config_forgetting")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_config_forgetting(
|
||||||
|
request: Request,
|
||||||
|
config_id: str = Query(..., description="config_id"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get forgetting settings for a specific memory config.
|
||||||
|
|
||||||
|
Only configs belonging to the authorized workspace can be queried.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 read forgetting config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
result = await memory_forget_controller.read_forgetting_config(
|
||||||
|
config_id = config_id,
|
||||||
|
current_user = current_user,
|
||||||
|
db = db,
|
||||||
|
)
|
||||||
|
return jsonable_encoder(result)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/read_config_emotion")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_config_emotion(
|
||||||
|
request: Request,
|
||||||
|
config_id: str = Query(..., description="config_id"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get emotion engine config details for a specific config.
|
||||||
|
|
||||||
|
Only configs belonging to the authorized workspace can be queried.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 read emotion config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return jsonable_encoder(emotion_config_controller.get_emotion_config(
|
||||||
|
config_id=config_id,
|
||||||
|
db=db,
|
||||||
|
current_user=current_user,
|
||||||
|
))
|
||||||
|
|
||||||
|
@router.get("/read_config_reflection")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def read_config_reflection(
|
||||||
|
request: Request,
|
||||||
|
config_id: str = Query(..., description="config_id"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get reflection engine config details for a specific config.
|
||||||
|
|
||||||
|
Only configs belonging to the authorized workspace can be queried.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 read reflection config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return jsonable_encoder(await memory_reflection_controller.start_reflection_configs(
|
||||||
|
config_id=config_id,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/create_config")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def create_memory_config(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new memory config for the workspace.
|
||||||
|
|
||||||
|
The config will be associated with the workspace of the API Key.
|
||||||
|
config_name is required, other fields are optional.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = ConfigCreateRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 create config - workspace: {api_key_auth.workspace_id}, config_name: {payload.config_name}")
|
||||||
|
|
||||||
|
# 构造管理端 Schema,workspace_id 从 API Key 注入
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
mgmt_payload = ConfigParamsCreate(
|
||||||
|
config_name=payload.config_name,
|
||||||
|
config_desc=payload.config_desc or "",
|
||||||
|
scene_id=payload.scene_id,
|
||||||
|
llm_id=payload.llm_id,
|
||||||
|
embedding_id=payload.embedding_id,
|
||||||
|
rerank_id=payload.rerank_id,
|
||||||
|
reflection_model_id=payload.reflection_model_id,
|
||||||
|
emotion_model_id=payload.emotion_model_id,
|
||||||
|
)
|
||||||
|
#将返回数据中UUID序列化处理
|
||||||
|
result =memory_storage_controller.create_config(
|
||||||
|
payload=mgmt_payload,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
x_language_type=x_language_type,
|
||||||
|
)
|
||||||
|
return jsonable_encoder(result)
|
||||||
|
|
||||||
|
@router.put("/update_config")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_memory_config(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update memory config basic info (name, description, scene).
|
||||||
|
|
||||||
|
Requires API Key with 'memory' scope
|
||||||
|
Only configs belonging to the authorized workspace can be updated.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = ConfigUpdateRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 update config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
mgmt_payload = ConfigUpdate(
|
||||||
|
config_id = payload.config_id,
|
||||||
|
config_name = payload.config_name,
|
||||||
|
config_desc = payload.config_desc,
|
||||||
|
scene_id = payload.scene_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return memory_storage_controller.update_config(
|
||||||
|
payload = mgmt_payload,
|
||||||
|
current_user = current_user,
|
||||||
|
db = db,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.put("/update_config_extracted")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_memory_config_extracted(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
update memory config extraction engine config (models, thresholds, chunking, pruning, etc.).
|
||||||
|
|
||||||
|
Requires API Key with 'memory' scope.
|
||||||
|
Only configs belonging to the authorized workspace can be updated.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = ConfigUpdateExtractedRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 update extracted config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
#校验权限
|
||||||
|
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
update_fields = payload.model_dump(exclude_unset=True)
|
||||||
|
mgmt_payload = ConfigUpdateExtracted(**update_fields)
|
||||||
|
|
||||||
|
return memory_storage_controller.update_config_extracted(
|
||||||
|
payload = mgmt_payload,
|
||||||
|
current_user = current_user,
|
||||||
|
db = db,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.put("/update_config_forgetting")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_memory_config_forgetting(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
update memory config forgetting settings (forgetting strategy, parameters, etc.).
|
||||||
|
|
||||||
|
Requires API Key with 'memory' scope.
|
||||||
|
Only configs belonging to the authorized workspace can be updated.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = ConfigUpdateForgettingRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 update forgetting config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
#校验权限
|
||||||
|
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
update_fields = payload.model_dump(exclude_unset=True)
|
||||||
|
mgmt_payload = ForgettingConfigUpdateRequest(**update_fields)
|
||||||
|
|
||||||
|
#将返回数据中UUID序列化处理
|
||||||
|
result = await memory_forget_controller.update_forgetting_config(
|
||||||
|
payload = mgmt_payload,
|
||||||
|
current_user = current_user,
|
||||||
|
db = db,
|
||||||
|
)
|
||||||
|
return jsonable_encoder(result)
|
||||||
|
|
||||||
|
@router.put("/update_config_emotion")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_config_emotion(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update emotion engine config (full update).
|
||||||
|
|
||||||
|
All fields except emotion_model_id are required.
|
||||||
|
Only configs belonging to the authorized workspace can be updated.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = EmotionConfigUpdateRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 update emotion config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
update_fields = payload.model_dump(exclude_unset=True)
|
||||||
|
mgmt_payload = EmotionConfigUpdate(**update_fields)
|
||||||
|
return jsonable_encoder(emotion_config_controller.update_emotion_config(
|
||||||
|
config=mgmt_payload,
|
||||||
|
db=db,
|
||||||
|
current_user=current_user,
|
||||||
|
))
|
||||||
|
|
||||||
|
@router.put("/update_config_reflection")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def update_config_reflection(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(None, description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update reflection engine config (full update).
|
||||||
|
|
||||||
|
All fields are required.
|
||||||
|
Only configs belonging to the authorized workspace can be updated.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = ReflectionConfigUpdateRequest(**body)
|
||||||
|
|
||||||
|
logger.info(f"V1 update reflection config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
update_fields = payload.model_dump(exclude_unset=True)
|
||||||
|
mgmt_payload = Memory_Reflection(**update_fields)
|
||||||
|
|
||||||
|
return jsonable_encoder(await memory_reflection_controller.save_reflection_config(
|
||||||
|
request=mgmt_payload,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
))
|
||||||
|
|
||||||
|
@router.delete("/delete_config")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def delete_memory_config(
|
||||||
|
config_id: str,
|
||||||
|
request: Request,
|
||||||
|
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete a memory config.
|
||||||
|
|
||||||
|
- Default configs cannot be deleted.
|
||||||
|
- If end users are connected and force=False, returns a warning.
|
||||||
|
- If force=True, clears end user references and deletes the config.
|
||||||
|
|
||||||
|
Only configs belonging to the authorized workspace can be deleted.
|
||||||
|
"""
|
||||||
|
logger.info(f"V1 delete config - config_id: {config_id}, force: {force}, workspace: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
_verify_config_ownership(config_id, api_key_auth.workspace_id, db)
|
||||||
|
|
||||||
|
current_user = _get_current_user(api_key_auth, db)
|
||||||
|
|
||||||
|
return memory_storage_controller.delete_config(
|
||||||
|
config_id=config_id,
|
||||||
|
force=force,
|
||||||
|
current_user=current_user,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
@@ -11,11 +11,13 @@ from app.schemas import skill_schema
|
|||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.services.skill_service import SkillService
|
from app.services.skill_service import SkillService
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
|
from app.core.quota_stub import check_skill_quota
|
||||||
|
|
||||||
router = APIRouter(prefix="/skills", tags=["Skills"])
|
router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("", summary="创建技能")
|
@router.post("", summary="创建技能")
|
||||||
|
@check_skill_quota
|
||||||
def create_skill(
|
def create_skill(
|
||||||
data: skill_schema.SkillCreate,
|
data: skill_schema.SkillCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
|||||||
173
api/app/controllers/tenant_subscription_controller.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""
|
||||||
|
租户套餐查询接口(普通用户可访问)
|
||||||
|
"""
|
||||||
|
import datetime
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import success, fail
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.i18n.dependencies import get_translator
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
|
||||||
|
logger = get_api_logger()
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/tenant", tags=["Tenant"])
|
||||||
|
public_router = APIRouter(tags=["Tenant"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/subscription", response_model=ApiResponse, summary="获取当前用户所属租户的套餐信息")
|
||||||
|
async def get_my_tenant_subscription(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取当前登录用户所属租户的有效套餐订阅信息。
|
||||||
|
包含套餐名称、版本、配额、到期时间等。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||||
|
|
||||||
|
if not current_user.tenant:
|
||||||
|
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
||||||
|
|
||||||
|
tenant_id = current_user.tenant.id
|
||||||
|
svc = TenantSubscriptionService(db)
|
||||||
|
sub = svc.get_subscription(tenant_id)
|
||||||
|
|
||||||
|
if not sub:
|
||||||
|
# 无订阅记录时,兜底返回免费套餐信息
|
||||||
|
free_plan = svc.plan_repo.get_free_plan()
|
||||||
|
if not free_plan:
|
||||||
|
return success(data=None, msg="暂无有效套餐")
|
||||||
|
return success(data={
|
||||||
|
"subscription_id": None,
|
||||||
|
"tenant_id": str(tenant_id),
|
||||||
|
"package_plan_id": str(free_plan.id),
|
||||||
|
"package_version": free_plan.version,
|
||||||
|
"package_plan": {
|
||||||
|
"id": str(free_plan.id),
|
||||||
|
"name": free_plan.name,
|
||||||
|
"name_en": free_plan.name_en,
|
||||||
|
"version": free_plan.version,
|
||||||
|
"category": free_plan.category,
|
||||||
|
"tier_level": free_plan.tier_level,
|
||||||
|
"price": float(free_plan.price) if free_plan.price is not None else 0.0,
|
||||||
|
"billing_cycle": free_plan.billing_cycle,
|
||||||
|
"core_value": free_plan.core_value,
|
||||||
|
"core_value_en": free_plan.core_value_en,
|
||||||
|
"tech_support": free_plan.tech_support,
|
||||||
|
"tech_support_en": free_plan.tech_support_en,
|
||||||
|
"sla_compliance": free_plan.sla_compliance,
|
||||||
|
"sla_compliance_en": free_plan.sla_compliance_en,
|
||||||
|
"page_customization": free_plan.page_customization,
|
||||||
|
"page_customization_en": free_plan.page_customization_en,
|
||||||
|
"theme_color": free_plan.theme_color,
|
||||||
|
},
|
||||||
|
"started_at": None,
|
||||||
|
"expired_at": None,
|
||||||
|
"status": "active",
|
||||||
|
"quotas": free_plan.quotas or {},
|
||||||
|
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||||
|
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||||
|
}, msg="免费套餐")
|
||||||
|
|
||||||
|
return success(data=svc.build_response(sub))
|
||||||
|
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# 社区版无 premium 模块,从配置文件读取免费套餐
|
||||||
|
if not current_user.tenant:
|
||||||
|
return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户"))
|
||||||
|
|
||||||
|
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||||
|
|
||||||
|
plan = DEFAULT_FREE_PLAN
|
||||||
|
response_data = {
|
||||||
|
"subscription_id": None,
|
||||||
|
"tenant_id": str(current_user.tenant.id),
|
||||||
|
"package_plan_id": None,
|
||||||
|
"package_version": plan["version"],
|
||||||
|
"package_plan": {
|
||||||
|
"id": None,
|
||||||
|
"name": plan["name"],
|
||||||
|
"name_en": plan.get("name_en"),
|
||||||
|
"version": plan["version"],
|
||||||
|
"category": plan["category"],
|
||||||
|
"tier_level": plan["tier_level"],
|
||||||
|
"price": float(plan["price"]),
|
||||||
|
"billing_cycle": plan["billing_cycle"],
|
||||||
|
"core_value": plan.get("core_value"),
|
||||||
|
"core_value_en": plan.get("core_value_en"),
|
||||||
|
"tech_support": plan.get("tech_support"),
|
||||||
|
"tech_support_en": plan.get("tech_support_en"),
|
||||||
|
"sla_compliance": plan.get("sla_compliance"),
|
||||||
|
"sla_compliance_en": plan.get("sla_compliance_en"),
|
||||||
|
"page_customization": plan.get("page_customization"),
|
||||||
|
"page_customization_en": plan.get("page_customization_en"),
|
||||||
|
"theme_color": plan.get("theme_color"),
|
||||||
|
},
|
||||||
|
"started_at": None,
|
||||||
|
"expired_at": None,
|
||||||
|
"status": "active",
|
||||||
|
"quotas": plan["quotas"],
|
||||||
|
"created_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||||
|
"updated_at": int(datetime.datetime.utcnow().timestamp() * 1000),
|
||||||
|
}
|
||||||
|
return success(data=response_data, msg="社区版免费套餐")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取租户套餐信息失败: {e}", exc_info=True)
|
||||||
|
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐信息失败"))
|
||||||
|
|
||||||
|
|
||||||
|
@public_router.get("/package-plans", response_model=ApiResponse, summary="获取套餐列表(公开)")
|
||||||
|
async def list_package_plans_public(
|
||||||
|
category: Optional[str] = None,
|
||||||
|
status: Optional[bool] = None,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
公开接口,无需鉴权。
|
||||||
|
SaaS 版从数据库读取套餐列表;社区版降级返回 default_free_plan.py 中的免费套餐。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from premium.platform_admin.package_plan_service import PackagePlanService
|
||||||
|
from premium.platform_admin.package_plan_schema import PackagePlanResponse
|
||||||
|
svc = PackagePlanService(db)
|
||||||
|
result = svc.get_list(page=1, size=9999, category=category, status=status, search=search)
|
||||||
|
return success(data=[PackagePlanResponse.model_validate(p).model_dump(mode="json") for p in result["items"]])
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||||
|
plan = DEFAULT_FREE_PLAN
|
||||||
|
return success(data=[{
|
||||||
|
"id": None,
|
||||||
|
"name": plan["name"],
|
||||||
|
"name_en": plan.get("name_en"),
|
||||||
|
"version": plan["version"],
|
||||||
|
"category": plan["category"],
|
||||||
|
"tier_level": plan["tier_level"],
|
||||||
|
"price": float(plan["price"]),
|
||||||
|
"billing_cycle": plan["billing_cycle"],
|
||||||
|
"core_value": plan.get("core_value"),
|
||||||
|
"core_value_en": plan.get("core_value_en"),
|
||||||
|
"tech_support": plan.get("tech_support"),
|
||||||
|
"tech_support_en": plan.get("tech_support_en"),
|
||||||
|
"sla_compliance": plan.get("sla_compliance"),
|
||||||
|
"sla_compliance_en": plan.get("sla_compliance_en"),
|
||||||
|
"page_customization": plan.get("page_customization"),
|
||||||
|
"page_customization_en": plan.get("page_customization_en"),
|
||||||
|
"theme_color": plan.get("theme_color"),
|
||||||
|
"status": plan.get("status", True),
|
||||||
|
"quotas": plan["quotas"],
|
||||||
|
}])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取套餐列表失败: {e}", exc_info=True)
|
||||||
|
return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐列表失败"))
|
||||||
@@ -114,11 +114,14 @@ def get_current_user_info(
|
|||||||
|
|
||||||
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||||
if current_user.external_source:
|
if current_user.external_source:
|
||||||
from premium.sso.models import SSOSource
|
try:
|
||||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
from premium.sso.models import SSOSource
|
||||||
if source and source.permissions:
|
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||||
result_schema.permissions = source.permissions
|
if source and source.permissions:
|
||||||
else:
|
result_schema.permissions = source.permissions
|
||||||
|
else:
|
||||||
|
result_schema.permissions = []
|
||||||
|
except ModuleNotFoundError:
|
||||||
result_schema.permissions = []
|
result_schema.permissions = []
|
||||||
else:
|
else:
|
||||||
result_schema.permissions = ["all"]
|
result_schema.permissions = ["all"]
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from app.schemas.workspace_schema import (
|
|||||||
WorkspaceUpdate,
|
WorkspaceUpdate,
|
||||||
)
|
)
|
||||||
from app.services import workspace_service
|
from app.services import workspace_service
|
||||||
|
from app.core.quota_stub import check_workspace_quota
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -106,6 +107,7 @@ def get_workspaces(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ApiResponse)
|
@router.post("", response_model=ApiResponse)
|
||||||
|
@check_workspace_quota
|
||||||
def create_workspace(
|
def create_workspace(
|
||||||
workspace: WorkspaceCreate,
|
workspace: WorkspaceCreate,
|
||||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import time
|
|||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.errors import GraphRecursionError
|
from langgraph.errors import GraphRecursionError
|
||||||
|
|
||||||
@@ -41,6 +41,7 @@ class LangChainAgent:
|
|||||||
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
|
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
|
||||||
deep_thinking: bool = False, # 是否启用深度思考模式
|
deep_thinking: bool = False, # 是否启用深度思考模式
|
||||||
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
|
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
|
||||||
|
json_output: bool = False, # 是否强制 JSON 输出
|
||||||
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
|
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
|
||||||
):
|
):
|
||||||
"""初始化 LangChain Agent
|
"""初始化 LangChain Agent
|
||||||
@@ -64,7 +65,6 @@ class LangChainAgent:
|
|||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
self.is_omni = is_omni
|
self.is_omni = is_omni
|
||||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||||
self.deep_thinking = deep_thinking and ("thinking" in (capability or []))
|
|
||||||
|
|
||||||
# 工具调用计数器:记录每个工具的连续调用次数
|
# 工具调用计数器:记录每个工具的连续调用次数
|
||||||
self.tool_call_counter: Dict[str, int] = {}
|
self.tool_call_counter: Dict[str, int] = {}
|
||||||
@@ -80,6 +80,17 @@ class LangChainAgent:
|
|||||||
|
|
||||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||||
|
|
||||||
|
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format
|
||||||
|
# 在 system prompt 中注入 JSON 要求
|
||||||
|
from app.models.models_model import ModelProvider
|
||||||
|
if json_output and (
|
||||||
|
(provider.lower() == ModelProvider.DASHSCOPE and not is_omni)
|
||||||
|
or provider.lower() == ModelProvider.VOLCANO
|
||||||
|
# 有工具时 response_format 会被移除,所有 provider 都需要 system prompt 注入保证 JSON 输出
|
||||||
|
or bool(tools)
|
||||||
|
):
|
||||||
|
self.system_prompt += "\n请以JSON格式输出。"
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||||
f"tool_count={len(self.tools)}, "
|
f"tool_count={len(self.tools)}, "
|
||||||
@@ -87,23 +98,17 @@ class LangChainAgent:
|
|||||||
f"auto_calculated={max_iterations is None}"
|
f"auto_calculated={max_iterations is None}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 根据 capability 校验是否真正支持深度思考
|
# 创建 RedBearLLM,capability 校验由 RedBearModelConfig 统一处理
|
||||||
actual_deep_thinking = self.deep_thinking
|
|
||||||
if deep_thinking and not actual_deep_thinking:
|
|
||||||
logger.warning(
|
|
||||||
f"模型 {model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 RedBearLLM(支持多提供商)
|
|
||||||
model_config = RedBearModelConfig(
|
model_config = RedBearModelConfig(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
is_omni=is_omni,
|
is_omni=is_omni,
|
||||||
deep_thinking=actual_deep_thinking,
|
capability=capability,
|
||||||
thinking_budget_tokens=thinking_budget_tokens if actual_deep_thinking else None,
|
deep_thinking=deep_thinking,
|
||||||
support_thinking="thinking" in (capability or []),
|
thinking_budget_tokens=thinking_budget_tokens,
|
||||||
|
json_output=json_output,
|
||||||
extra_params={
|
extra_params={
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
@@ -112,6 +117,9 @@ class LangChainAgent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
||||||
|
# 从经过校验的 config 读取实际生效的能力开关
|
||||||
|
self.deep_thinking = model_config.deep_thinking
|
||||||
|
self.json_output = model_config.json_output
|
||||||
|
|
||||||
# 获取底层模型用于真正的流式调用
|
# 获取底层模型用于真正的流式调用
|
||||||
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
|
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
|
||||||
@@ -237,9 +245,7 @@ class LangChainAgent:
|
|||||||
Returns:
|
Returns:
|
||||||
List[BaseMessage]: 消息列表
|
List[BaseMessage]: 消息列表
|
||||||
"""
|
"""
|
||||||
messages:list = [SystemMessage(content=self.system_prompt)]
|
messages: list = []
|
||||||
|
|
||||||
# 添加系统提示词
|
|
||||||
|
|
||||||
# 添加历史消息
|
# 添加历史消息
|
||||||
if history:
|
if history:
|
||||||
|
|||||||
@@ -96,6 +96,38 @@ def require_api_key(
|
|||||||
resource_id=api_key_obj.resource_id,
|
resource_id=api_key_obj.resource_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ── Tenant 级别限速(来自套餐配额 api_ops_rate_limit)──────────
|
||||||
|
try:
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||||
|
|
||||||
|
workspace = db.query(Workspace).filter(
|
||||||
|
Workspace.id == api_key_obj.workspace_id
|
||||||
|
).first()
|
||||||
|
if workspace:
|
||||||
|
quota = TenantSubscriptionService(db).get_effective_quota(workspace.tenant_id)
|
||||||
|
tenant_qps_limit = quota.get("api_ops_rate_limit") if quota else None
|
||||||
|
if tenant_qps_limit:
|
||||||
|
rate_limiter = RateLimiterService()
|
||||||
|
tenant_ok, tenant_info = await rate_limiter.check_tenant_rate_limit(
|
||||||
|
workspace.tenant_id, tenant_qps_limit
|
||||||
|
)
|
||||||
|
if not tenant_ok:
|
||||||
|
raise RateLimitException(
|
||||||
|
"租户 API 调用速率超限",
|
||||||
|
BizCode.API_KEY_QPS_LIMIT_EXCEEDED,
|
||||||
|
rate_headers={
|
||||||
|
"X-RateLimit-Tenant-Limit": str(tenant_info["limit"]),
|
||||||
|
"X-RateLimit-Tenant-Remaining": str(tenant_info["remaining"]),
|
||||||
|
"X-RateLimit-Tenant-Reset": str(tenant_info["reset"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except RateLimitException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Tenant 限速检查异常,跳过: {e}")
|
||||||
|
# ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
rate_limiter = RateLimiterService()
|
rate_limiter = RateLimiterService()
|
||||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
||||||
if not is_allowed:
|
if not is_allowed:
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from dotenv import load_dotenv
|
|||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||||
|
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES
|
||||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||||
memory_summary_generation
|
memory_summary_generation
|
||||||
@@ -191,15 +192,37 @@ async def write(
|
|||||||
if success:
|
if success:
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
|
||||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
|
||||||
if all_entity_nodes:
|
if all_entity_nodes:
|
||||||
|
end_user_id = all_entity_nodes[0].end_user_id
|
||||||
|
|
||||||
|
# Neo4j 写入完成后,用 PgSQL 权威 aliases 覆盖 Neo4j 用户实体
|
||||||
|
try:
|
||||||
|
from app.repositories.end_user_info_repository import EndUserInfoRepository
|
||||||
|
if end_user_id:
|
||||||
|
with get_db_context() as db_session:
|
||||||
|
info = EndUserInfoRepository(db_session).get_by_end_user_id(uuid.UUID(end_user_id))
|
||||||
|
pg_aliases = info.aliases if info and info.aliases else []
|
||||||
|
if info is not None:
|
||||||
|
# 将 Python 侧占位名集合作为参数传入,避免 Cypher 硬编码
|
||||||
|
placeholder_names = list(_USER_PLACEHOLDER_NAMES)
|
||||||
|
await neo4j_connector.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (e:ExtractedEntity)
|
||||||
|
WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $placeholder_names
|
||||||
|
SET e.aliases = $aliases
|
||||||
|
""",
|
||||||
|
end_user_id=end_user_id, aliases=pg_aliases,
|
||||||
|
placeholder_names=placeholder_names,
|
||||||
|
)
|
||||||
|
logger.info(f"[AliasSync] Neo4j 用户实体 aliases 已用 PgSQL 权威源覆盖: {pg_aliases}")
|
||||||
|
except Exception as sync_err:
|
||||||
|
logger.warning(f"[AliasSync] PgSQL→Neo4j aliases 同步失败(不影响主流程): {sync_err}")
|
||||||
|
|
||||||
|
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||||
try:
|
try:
|
||||||
from app.tasks import run_incremental_clustering
|
from app.tasks import run_incremental_clustering
|
||||||
|
|
||||||
end_user_id = all_entity_nodes[0].end_user_id
|
|
||||||
new_entity_ids = [e.id for e in all_entity_nodes]
|
new_entity_ids = [e.id for e in all_entity_nodes]
|
||||||
|
|
||||||
# 异步提交 Celery 任务
|
|
||||||
task = run_incremental_clustering.apply_async(
|
task = run_incremental_clustering.apply_async(
|
||||||
kwargs={
|
kwargs={
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
@@ -207,7 +230,6 @@ async def write(
|
|||||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||||
},
|
},
|
||||||
# 设置任务优先级(低优先级,不影响主业务)
|
|
||||||
priority=3,
|
priority=3,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -215,7 +237,6 @@ async def write(
|
|||||||
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 聚类任务提交失败不影响主流程
|
|
||||||
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -61,9 +61,9 @@ from app.core.memory.models.triplet_models import (
|
|||||||
# User metadata models
|
# User metadata models
|
||||||
from app.core.memory.models.metadata_models import (
|
from app.core.memory.models.metadata_models import (
|
||||||
UserMetadata,
|
UserMetadata,
|
||||||
UserMetadataBehavioralHints,
|
|
||||||
UserMetadataProfile,
|
UserMetadataProfile,
|
||||||
MetadataExtractionResponse,
|
MetadataExtractionResponse,
|
||||||
|
MetadataFieldChange,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ontology scenario models (LLM extracted from scenarios)
|
# Ontology scenario models (LLM extracted from scenarios)
|
||||||
@@ -133,9 +133,9 @@ __all__ = [
|
|||||||
"Triplet",
|
"Triplet",
|
||||||
"TripletExtractionResponse",
|
"TripletExtractionResponse",
|
||||||
"UserMetadata",
|
"UserMetadata",
|
||||||
"UserMetadataBehavioralHints",
|
|
||||||
"UserMetadataProfile",
|
"UserMetadataProfile",
|
||||||
"MetadataExtractionResponse",
|
"MetadataExtractionResponse",
|
||||||
|
"MetadataFieldChange",
|
||||||
# Ontology models
|
# Ontology models
|
||||||
"OntologyClass",
|
"OntologyClass",
|
||||||
"OntologyExtractionResponse",
|
"OntologyExtractionResponse",
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Independent from triplet_models.py - these models are used by the
|
|||||||
standalone metadata extraction pipeline (post-dedup async Celery task).
|
standalone metadata extraction pipeline (post-dedup async Celery task).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
@@ -13,8 +13,8 @@ class UserMetadataProfile(BaseModel):
|
|||||||
"""用户画像信息"""
|
"""用户画像信息"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
model_config = ConfigDict(extra="ignore")
|
||||||
role: str = Field(default="", description="用户职业或角色")
|
role: List[str] = Field(default_factory=list, description="用户职业或角色")
|
||||||
domain: str = Field(default="", description="用户所在领域")
|
domain: List[str] = Field(default_factory=list, description="用户所在领域")
|
||||||
expertise: List[str] = Field(
|
expertise: List[str] = Field(
|
||||||
default_factory=list, description="用户擅长的技能或工具"
|
default_factory=list, description="用户擅长的技能或工具"
|
||||||
)
|
)
|
||||||
@@ -23,31 +23,37 @@ class UserMetadataProfile(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserMetadataBehavioralHints(BaseModel):
|
|
||||||
"""行为偏好"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
|
||||||
learning_stage: str = Field(default="", description="学习阶段")
|
|
||||||
preferred_depth: str = Field(default="", description="偏好深度")
|
|
||||||
tone_preference: str = Field(default="", description="语气偏好")
|
|
||||||
|
|
||||||
|
|
||||||
class UserMetadata(BaseModel):
|
class UserMetadata(BaseModel):
|
||||||
"""用户元数据顶层结构"""
|
"""用户元数据顶层结构"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
model_config = ConfigDict(extra="ignore")
|
||||||
profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile)
|
profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile)
|
||||||
behavioral_hints: UserMetadataBehavioralHints = Field(
|
|
||||||
default_factory=UserMetadataBehavioralHints
|
|
||||||
|
class MetadataFieldChange(BaseModel):
|
||||||
|
"""单个元数据字段的变更操作"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
field_path: str = Field(
|
||||||
|
description="字段路径,用点号分隔,如 'profile.role'、'profile.expertise'"
|
||||||
|
)
|
||||||
|
action: Literal["set", "remove"] = Field(
|
||||||
|
description="操作类型:'set' 表示新增或修改,'remove' 表示移除"
|
||||||
|
)
|
||||||
|
value: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="字段的新值(action='set' 时必填)。标量字段直接填值,列表字段填单个要新增的元素"
|
||||||
)
|
)
|
||||||
knowledge_tags: List[str] = Field(default_factory=list, description="知识标签")
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataExtractionResponse(BaseModel):
|
class MetadataExtractionResponse(BaseModel):
|
||||||
"""元数据提取 LLM 响应结构"""
|
"""元数据提取 LLM 响应结构(增量模式)"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
model_config = ConfigDict(extra="ignore")
|
||||||
user_metadata: UserMetadata = Field(default_factory=UserMetadata)
|
metadata_changes: List[MetadataFieldChange] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="元数据的增量变更列表,每项描述一个字段的新增、修改或移除操作",
|
||||||
|
)
|
||||||
aliases_to_add: List[str] = Field(
|
aliases_to_add: List[str] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="本次新发现的用户别名(用户自我介绍或他人对用户的称呼)",
|
description="本次新发现的用户别名(用户自我介绍或他人对用户的称呼)",
|
||||||
|
|||||||
@@ -82,51 +82,38 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
|||||||
canonical.connect_strength = next(iter(pair))
|
canonical.connect_strength = next(iter(pair))
|
||||||
|
|
||||||
# 别名合并(去重保序,使用标准化工具)
|
# 别名合并(去重保序,使用标准化工具)
|
||||||
|
# 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,去重合并时不修改
|
||||||
try:
|
try:
|
||||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||||
incoming_name = (getattr(ent, "name", "") or "").strip()
|
if canonical_name.lower() not in _USER_PLACEHOLDER_NAMES:
|
||||||
|
incoming_name = (getattr(ent, "name", "") or "").strip()
|
||||||
|
|
||||||
# 收集所有需要合并的别名
|
# 收集所有需要合并的别名,过滤掉用户占位名避免污染非用户实体
|
||||||
all_aliases = []
|
all_aliases = list(getattr(canonical, "aliases", []) or [])
|
||||||
|
if incoming_name and incoming_name != canonical_name and incoming_name.lower() not in _USER_PLACEHOLDER_NAMES:
|
||||||
|
all_aliases.append(incoming_name)
|
||||||
|
all_aliases.extend(
|
||||||
|
a for a in (getattr(ent, "aliases", []) or [])
|
||||||
|
if a and a.strip().lower() not in _USER_PLACEHOLDER_NAMES
|
||||||
|
)
|
||||||
|
|
||||||
# 1. 添加canonical现有的别名
|
try:
|
||||||
existing = getattr(canonical, "aliases", []) or []
|
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||||
all_aliases.extend(existing)
|
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||||
|
except Exception:
|
||||||
# 2. 添加incoming实体的名称(如果不同于canonical的名称)
|
seen_normalized = set()
|
||||||
if incoming_name and incoming_name != canonical_name:
|
unique_aliases = []
|
||||||
all_aliases.append(incoming_name)
|
for alias in all_aliases:
|
||||||
|
if not alias:
|
||||||
# 3. 添加incoming实体的所有别名
|
continue
|
||||||
incoming = getattr(ent, "aliases", []) or []
|
alias_stripped = str(alias).strip()
|
||||||
all_aliases.extend(incoming)
|
if not alias_stripped or alias_stripped == canonical_name:
|
||||||
|
continue
|
||||||
# 4. 标准化并去重(优先使用alias_utils工具函数)
|
alias_normalized = alias_stripped.lower()
|
||||||
try:
|
if alias_normalized not in seen_normalized:
|
||||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
seen_normalized.add(alias_normalized)
|
||||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
unique_aliases.append(alias_stripped)
|
||||||
except Exception:
|
canonical.aliases = sorted(unique_aliases)
|
||||||
# 如果导入失败,使用增强的去重逻辑
|
|
||||||
seen_normalized = set()
|
|
||||||
unique_aliases = []
|
|
||||||
|
|
||||||
for alias in all_aliases:
|
|
||||||
if not alias:
|
|
||||||
continue
|
|
||||||
|
|
||||||
alias_stripped = str(alias).strip()
|
|
||||||
if not alias_stripped or alias_stripped == canonical_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 标准化:转小写用于去重判断
|
|
||||||
alias_normalized = alias_stripped.lower()
|
|
||||||
|
|
||||||
if alias_normalized not in seen_normalized:
|
|
||||||
seen_normalized.add(alias_normalized)
|
|
||||||
unique_aliases.append(alias_stripped)
|
|
||||||
|
|
||||||
# 排序并赋值
|
|
||||||
canonical.aliases = sorted(unique_aliases)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -733,66 +720,37 @@ def fuzzy_match(
|
|||||||
|
|
||||||
|
|
||||||
def _merge_entities_with_aliases(canonical: ExtractedEntityNode, losing: ExtractedEntityNode):
|
def _merge_entities_with_aliases(canonical: ExtractedEntityNode, losing: ExtractedEntityNode):
|
||||||
""" 模糊匹配中的实体合并。
|
"""模糊匹配中的实体合并(别名部分)。
|
||||||
|
|
||||||
合并策略:
|
用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,跳过合并。
|
||||||
1. 保留canonical的主名称不变
|
|
||||||
2. 将losing的主名称添加为alias(如果不同)
|
|
||||||
3. 合并两个实体的所有aliases
|
|
||||||
4. 自动去重(case-insensitive)并排序
|
|
||||||
|
|
||||||
Args:
|
|
||||||
canonical: 规范实体(保留)
|
|
||||||
losing: 被合并实体(删除)
|
|
||||||
|
|
||||||
Note:
|
|
||||||
使用alias_utils.normalize_aliases进行标准化去重
|
|
||||||
"""
|
"""
|
||||||
# 获取规范实体的名称
|
|
||||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||||
|
if canonical_name.lower() in _USER_PLACEHOLDER_NAMES:
|
||||||
|
return
|
||||||
|
|
||||||
losing_name = (getattr(losing, "name", "") or "").strip()
|
losing_name = (getattr(losing, "name", "") or "").strip()
|
||||||
|
|
||||||
# 收集所有需要合并的别名
|
all_aliases = list(getattr(canonical, "aliases", []) or [])
|
||||||
all_aliases = []
|
|
||||||
|
|
||||||
# 1. 添加canonical现有的别名
|
|
||||||
current_aliases = getattr(canonical, "aliases", []) or []
|
|
||||||
all_aliases.extend(current_aliases)
|
|
||||||
|
|
||||||
# 2. 添加losing实体的名称(如果不同于canonical的名称)
|
|
||||||
if losing_name and losing_name != canonical_name:
|
if losing_name and losing_name != canonical_name:
|
||||||
all_aliases.append(losing_name)
|
all_aliases.append(losing_name)
|
||||||
|
all_aliases.extend(getattr(losing, "aliases", []) or [])
|
||||||
|
|
||||||
# 3. 添加losing实体的所有别名
|
|
||||||
losing_aliases = getattr(losing, "aliases", []) or []
|
|
||||||
all_aliases.extend(losing_aliases)
|
|
||||||
|
|
||||||
# 4. 标准化并去重(使用标准化后的字符串进行去重)
|
|
||||||
try:
|
try:
|
||||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||||
except Exception:
|
except Exception:
|
||||||
# 如果导入失败,使用增强的去重逻辑
|
|
||||||
# 使用标准化后的字符串作为key进行去重
|
|
||||||
seen_normalized = set()
|
seen_normalized = set()
|
||||||
unique_aliases = []
|
unique_aliases = []
|
||||||
|
|
||||||
for alias in all_aliases:
|
for alias in all_aliases:
|
||||||
if not alias:
|
if not alias:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
alias_stripped = str(alias).strip()
|
alias_stripped = str(alias).strip()
|
||||||
if not alias_stripped or alias_stripped == canonical_name:
|
if not alias_stripped or alias_stripped == canonical_name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 标准化:转小写用于去重判断
|
|
||||||
alias_normalized = alias_stripped.lower()
|
alias_normalized = alias_stripped.lower()
|
||||||
|
|
||||||
if alias_normalized not in seen_normalized:
|
if alias_normalized not in seen_normalized:
|
||||||
seen_normalized.add(alias_normalized)
|
seen_normalized.add(alias_normalized)
|
||||||
unique_aliases.append(alias_stripped)
|
unique_aliases.append(alias_stripped)
|
||||||
|
|
||||||
# 排序并赋值
|
|
||||||
canonical.aliases = sorted(unique_aliases)
|
canonical.aliases = sorted(unique_aliases)
|
||||||
|
|
||||||
# ========== 主循环:遍历所有实体对进行模糊匹配 ==========
|
# ========== 主循环:遍历所有实体对进行模糊匹配 ==========
|
||||||
|
|||||||
@@ -1391,18 +1391,18 @@ class ExtractionOrchestrator:
|
|||||||
"""
|
"""
|
||||||
将本轮提取的用户别名同步到 end_user 和 end_user_info 表。
|
将本轮提取的用户别名同步到 end_user 和 end_user_info 表。
|
||||||
|
|
||||||
注意:此方法在 Neo4j 写入之前调用,因此不能依赖 Neo4j 作为别名的权威数据源。
|
PgSQL end_user_info.aliases 是用户别名的唯一权威源。
|
||||||
改为直接使用内存中去重后的 entity_nodes 的 aliases,与 PgSQL 已有的 aliases 合并。
|
此方法仅将本轮 LLM 从对话中新提取的别名增量追加到 PgSQL,
|
||||||
|
不再从 Neo4j 二层去重合并历史别名,避免脏数据反向污染 PgSQL。
|
||||||
|
|
||||||
策略:
|
策略:
|
||||||
1. 从内存中的 entity_nodes 提取本轮用户别名(current_aliases)
|
1. 从本轮对话原始发言中提取用户别名(current_aliases)
|
||||||
2. 从去重后的 entity_nodes 中提取完整别名(含 Neo4j 二层去重合并的历史别名)
|
2. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases)
|
||||||
3. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases)
|
3. 合并 db_aliases + current_aliases,去重保序
|
||||||
4. 合并 db_aliases + deduped_aliases + current_aliases,去重保序
|
4. 写回 PgSQL
|
||||||
5. 写回 PgSQL
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entity_nodes: 去重后的实体节点列表(内存中,含二层去重合并结果)
|
entity_nodes: 去重后的实体节点列表(内存中)
|
||||||
dialog_data_list: 对话数据列表
|
dialog_data_list: 对话数据列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -1418,11 +1418,6 @@ class ExtractionOrchestrator:
|
|||||||
# 1. 提取本轮对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
# 1. 提取本轮对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
||||||
current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list)
|
current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list)
|
||||||
|
|
||||||
# 1.5 从去重后的 entity_nodes 中提取完整别名
|
|
||||||
# 二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 中,
|
|
||||||
# 这里提取出来确保 PgSQL 与 Neo4j 的别名保持同步
|
|
||||||
deduped_aliases = self._extract_deduped_entity_aliases(entity_nodes)
|
|
||||||
|
|
||||||
# 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源
|
# 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源
|
||||||
# (防止 LLM 未提取出 AI 助手实体时,AI 别名泄漏到用户别名中)
|
# (防止 LLM 未提取出 AI 助手实体时,AI 别名泄漏到用户别名中)
|
||||||
neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id)
|
neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id)
|
||||||
@@ -1434,19 +1429,12 @@ class ExtractionOrchestrator:
|
|||||||
]
|
]
|
||||||
if len(current_aliases) < before_count:
|
if len(current_aliases) < before_count:
|
||||||
logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名")
|
logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名")
|
||||||
# 同样过滤 deduped_aliases
|
|
||||||
deduped_aliases = [
|
|
||||||
a for a in deduped_aliases
|
|
||||||
if a.strip().lower() not in neo4j_assistant_aliases
|
|
||||||
]
|
|
||||||
|
|
||||||
if not current_aliases and not deduped_aliases:
|
if not current_aliases:
|
||||||
logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}")
|
logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"本轮对话提取的 aliases: {current_aliases}")
|
logger.info(f"本轮对话提取的 aliases: {current_aliases}")
|
||||||
if deduped_aliases:
|
|
||||||
logger.info(f"去重后实体的完整 aliases(含历史): {deduped_aliases}")
|
|
||||||
|
|
||||||
# 2. 同步到数据库
|
# 2. 同步到数据库
|
||||||
end_user_uuid = uuid.UUID(end_user_id)
|
end_user_uuid = uuid.UUID(end_user_id)
|
||||||
@@ -1457,21 +1445,15 @@ class ExtractionOrchestrator:
|
|||||||
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
|
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 3. 从 PgSQL 读取已有 aliases 并与本轮合并
|
# 3. 从 PgSQL 读取已有 aliases 并与本轮新增合并
|
||||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||||
db_aliases = (info.aliases if info and info.aliases else [])
|
db_aliases = (info.aliases if info and info.aliases else [])
|
||||||
# 过滤掉占位名称
|
# 过滤掉占位名称
|
||||||
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
|
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
|
||||||
|
|
||||||
# 合并:已有 + 去重后完整别名 + 本轮新增,去重保序
|
# 合并:PgSQL 已有 + 本轮新增,去重保序(不再合并 Neo4j 历史别名)
|
||||||
merged_aliases = list(db_aliases)
|
merged_aliases = list(db_aliases)
|
||||||
seen_lower = {a.strip().lower() for a in merged_aliases}
|
seen_lower = {a.strip().lower() for a in merged_aliases}
|
||||||
# 先合并去重后实体的完整别名(含 Neo4j 历史别名)
|
|
||||||
for alias in deduped_aliases:
|
|
||||||
if alias.strip().lower() not in seen_lower:
|
|
||||||
merged_aliases.append(alias)
|
|
||||||
seen_lower.add(alias.strip().lower())
|
|
||||||
# 再合并本轮新提取的别名
|
|
||||||
for alias in current_aliases:
|
for alias in current_aliases:
|
||||||
if alias.strip().lower() not in seen_lower:
|
if alias.strip().lower() not in seen_lower:
|
||||||
merged_aliases.append(alias)
|
merged_aliases.append(alias)
|
||||||
@@ -1505,9 +1487,7 @@ class ExtractionOrchestrator:
|
|||||||
info.aliases = merged_aliases
|
info.aliases = merged_aliases
|
||||||
logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}")
|
logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}")
|
||||||
else:
|
else:
|
||||||
first_alias = current_aliases[0].strip() if current_aliases else (
|
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||||
deduped_aliases[0].strip() if deduped_aliases else ""
|
|
||||||
)
|
|
||||||
# 确保 first_alias 不是占位名称
|
# 确保 first_alias 不是占位名称
|
||||||
if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
|
if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
|
||||||
db.add(EndUserInfo(
|
db.add(EndUserInfo(
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ class MetadataExtractor:
|
|||||||
existing_aliases: Optional[List[str]] = None,
|
existing_aliases: Optional[List[str]] = None,
|
||||||
) -> Optional[tuple]:
|
) -> Optional[tuple]:
|
||||||
"""
|
"""
|
||||||
对筛选后的 statement 列表调用 LLM 提取元数据和用户别名。
|
对筛选后的 statement 列表调用 LLM 提取元数据增量变更和用户别名。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
statements: 用户发言的 statement 文本列表
|
statements: 用户发言的 statement 文本列表
|
||||||
@@ -126,7 +126,8 @@ class MetadataExtractor:
|
|||||||
existing_aliases: 数据库已有的用户别名列表(可选)
|
existing_aliases: 数据库已有的用户别名列表(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(UserMetadata, List[str], List[str]) tuple: (metadata, aliases_to_add, aliases_to_remove) on success, None on failure
|
(List[MetadataFieldChange], List[str], List[str]) tuple:
|
||||||
|
(metadata_changes, aliases_to_add, aliases_to_remove) on success, None on failure
|
||||||
"""
|
"""
|
||||||
if not statements:
|
if not statements:
|
||||||
return None
|
return None
|
||||||
@@ -160,12 +161,12 @@ class MetadataExtractor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
metadata = response.user_metadata if response.user_metadata else None
|
changes = response.metadata_changes if response.metadata_changes else []
|
||||||
to_add = response.aliases_to_add if response.aliases_to_add else []
|
to_add = response.aliases_to_add if response.aliases_to_add else []
|
||||||
to_remove = (
|
to_remove = (
|
||||||
response.aliases_to_remove if response.aliases_to_remove else []
|
response.aliases_to_remove if response.aliases_to_remove else []
|
||||||
)
|
)
|
||||||
return metadata, to_add, to_remove
|
return changes, to_add, to_remove
|
||||||
|
|
||||||
logger.warning("LLM 返回的响应为空")
|
logger.warning("LLM 返回的响应为空")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -4,11 +4,6 @@
|
|||||||
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
|
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
|
||||||
|
|
||||||
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
|
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
|
||||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||||
from app.core.memory.storage_services.search.search_strategy import (
|
from app.core.memory.storage_services.search.search_strategy import (
|
||||||
@@ -29,115 +24,87 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# 向后兼容的函数式API
|
# 向后兼容的函数式API (DEPRECATED - 未被使用)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# 为了兼容旧代码,提供与 src/search.py 相同的函数式接口
|
# 所有调用方均直接使用 app.core.memory.src.search.run_hybrid_search
|
||||||
|
# 保留注释以备参考
|
||||||
|
|
||||||
|
# async def run_hybrid_search(
|
||||||
async def run_hybrid_search(
|
# query_text: str,
|
||||||
query_text: str,
|
# search_type: str = "hybrid",
|
||||||
search_type: str = "hybrid",
|
# end_user_id: str | None = None,
|
||||||
end_user_id: str | None = None,
|
# apply_id: str | None = None,
|
||||||
apply_id: str | None = None,
|
# user_id: str | None = None,
|
||||||
user_id: str | None = None,
|
# limit: int = 50,
|
||||||
limit: int = 50,
|
# include: list[str] | None = None,
|
||||||
include: list[str] | None = None,
|
# alpha: float = 0.6,
|
||||||
alpha: float = 0.6,
|
# use_forgetting_curve: bool = False,
|
||||||
use_forgetting_curve: bool = False,
|
# memory_config: "MemoryConfig" = None,
|
||||||
memory_config: "MemoryConfig" = None,
|
# **kwargs
|
||||||
**kwargs
|
# ) -> dict:
|
||||||
) -> dict:
|
# """运行混合搜索(向后兼容的函数式API)"""
|
||||||
"""运行混合搜索(向后兼容的函数式API)
|
# from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||||
|
# from app.core.models.base import RedBearModelConfig
|
||||||
这是一个向后兼容的包装函数,将旧的函数式API转换为新的基于类的API。
|
# from app.db import get_db_context
|
||||||
|
# from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
Args:
|
# from app.services.memory_config_service import MemoryConfigService
|
||||||
query_text: 查询文本
|
#
|
||||||
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
# if not memory_config:
|
||||||
end_user_id: 组ID过滤
|
# raise ValueError("memory_config is required for search")
|
||||||
apply_id: 应用ID过滤
|
#
|
||||||
user_id: 用户ID过滤
|
# connector = Neo4jConnector()
|
||||||
limit: 每个类别的最大结果数
|
# with get_db_context() as db:
|
||||||
include: 要包含的搜索类别列表
|
# config_service = MemoryConfigService(db)
|
||||||
alpha: BM25分数权重(0.0-1.0)
|
# embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||||
use_forgetting_curve: 是否使用遗忘曲线
|
# embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||||
memory_config: MemoryConfig object containing embedding_model_id
|
# embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||||
**kwargs: 其他参数
|
#
|
||||||
|
# try:
|
||||||
Returns:
|
# if search_type == "keyword":
|
||||||
dict: 搜索结果字典,格式与旧API兼容
|
# strategy = KeywordSearchStrategy(connector=connector)
|
||||||
"""
|
# elif search_type == "semantic":
|
||||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
# strategy = SemanticSearchStrategy(
|
||||||
from app.core.models.base import RedBearModelConfig
|
# connector=connector,
|
||||||
from app.db import get_db_context
|
# embedder_client=embedder_client
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
# )
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
# else:
|
||||||
|
# strategy = HybridSearchStrategy(
|
||||||
if not memory_config:
|
# connector=connector,
|
||||||
raise ValueError("memory_config is required for search")
|
# embedder_client=embedder_client,
|
||||||
|
# alpha=alpha,
|
||||||
# 初始化客户端
|
# use_forgetting_curve=use_forgetting_curve
|
||||||
connector = Neo4jConnector()
|
# )
|
||||||
with get_db_context() as db:
|
#
|
||||||
config_service = MemoryConfigService(db)
|
# result = await strategy.search(
|
||||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
# query_text=query_text,
|
||||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
# end_user_id=end_user_id,
|
||||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
# limit=limit,
|
||||||
|
# include=include,
|
||||||
try:
|
# alpha=alpha,
|
||||||
# 根据搜索类型选择策略
|
# use_forgetting_curve=use_forgetting_curve,
|
||||||
if search_type == "keyword":
|
# **kwargs
|
||||||
strategy = KeywordSearchStrategy(connector=connector)
|
# )
|
||||||
elif search_type == "semantic":
|
#
|
||||||
strategy = SemanticSearchStrategy(
|
# result_dict = result.to_dict()
|
||||||
connector=connector,
|
#
|
||||||
embedder_client=embedder_client
|
# output_path = kwargs.get('output_path', 'search_results.json')
|
||||||
)
|
# if output_path:
|
||||||
else: # hybrid
|
# import json
|
||||||
strategy = HybridSearchStrategy(
|
# import os
|
||||||
connector=connector,
|
# from datetime import datetime
|
||||||
embedder_client=embedder_client,
|
#
|
||||||
alpha=alpha,
|
# try:
|
||||||
use_forgetting_curve=use_forgetting_curve
|
# out_dir = os.path.dirname(output_path)
|
||||||
)
|
# if out_dir:
|
||||||
|
# os.makedirs(out_dir, exist_ok=True)
|
||||||
# 执行搜索
|
# with open(output_path, "w", encoding="utf-8") as f:
|
||||||
result = await strategy.search(
|
# json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str)
|
||||||
query_text=query_text,
|
# print(f"Search results saved to {output_path}")
|
||||||
end_user_id=end_user_id,
|
# except Exception as e:
|
||||||
limit=limit,
|
# print(f"Error saving search results: {e}")
|
||||||
include=include,
|
# return result_dict
|
||||||
alpha=alpha,
|
#
|
||||||
use_forgetting_curve=use_forgetting_curve,
|
# finally:
|
||||||
**kwargs
|
# await connector.close()
|
||||||
)
|
#
|
||||||
|
# __all__.append("run_hybrid_search")
|
||||||
# 转换为旧格式
|
|
||||||
result_dict = result.to_dict()
|
|
||||||
|
|
||||||
# 保存到文件(如果指定了output_path)
|
|
||||||
output_path = kwargs.get('output_path', 'search_results.json')
|
|
||||||
if output_path:
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 确保目录存在
|
|
||||||
out_dir = os.path.dirname(output_path)
|
|
||||||
if out_dir:
|
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# 保存结果
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str)
|
|
||||||
print(f"Search results saved to {output_path}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving search results: {e}")
|
|
||||||
return result_dict
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
|
|
||||||
__all__.append("run_hybrid_search")
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
===Task===
|
===Task===
|
||||||
Extract user metadata from the following conversation statements spoken by the user.
|
Extract user metadata changes from the following conversation statements spoken by the user.
|
||||||
|
|
||||||
{% if language == "zh" %}
|
{% if language == "zh" %}
|
||||||
**"三度原则"判断标准:**
|
**"三度原则"判断标准:**
|
||||||
@@ -10,28 +10,36 @@ Extract user metadata from the following conversation statements spoken by the u
|
|||||||
**提取规则:**
|
**提取规则:**
|
||||||
- **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息
|
- **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息
|
||||||
- 仅提取文本中明确提到的信息,不要推测
|
- 仅提取文本中明确提到的信息,不要推测
|
||||||
- 如果文本中没有可提取的用户画像信息,返回空的 user_metadata 对象
|
|
||||||
- **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值)
|
- **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值)
|
||||||
|
|
||||||
|
**增量模式(重要):**
|
||||||
|
你只需要输出**本次对话引起的变更操作**,不要输出完整的元数据。每个变更是一个对象,包含:
|
||||||
|
- `field_path`:字段路径,用点号分隔(如 `profile.role`、`profile.expertise`)
|
||||||
|
- `action`:操作类型
|
||||||
|
* `set`:新增或修改一个字段的值
|
||||||
|
* `remove`:移除一个字段的值
|
||||||
|
- `value`:字段的新值(`action="set"` 时必填,`action="remove"` 时填要移除的元素值)
|
||||||
|
* 所有字段均为列表类型,每个元素一条变更记录
|
||||||
|
|
||||||
|
**判断规则:**
|
||||||
|
- 用户提到新信息 → `action="set"`,填入新值
|
||||||
|
- 用户明确否定已有信息(如"我不再做老师了"、"我已经不学Python了")→ `action="remove"`,`value` 填要移除的元素值
|
||||||
|
- 如果本次对话没有任何可提取的变更,返回空的 `metadata_changes` 数组 `[]`
|
||||||
|
- **不要为未被提及的字段生成任何变更操作**
|
||||||
|
|
||||||
{% if existing_metadata %}
|
{% if existing_metadata %}
|
||||||
**重要:合并已有元数据**
|
**已有元数据(仅供参考,用于判断是否需要变更):**
|
||||||
下方提供了数据库中已有的用户元数据。请结合用户最新发言,输出**合并后的完整元数据**:
|
请对比已有数据和用户最新发言,只输出差异部分的变更操作。
|
||||||
- 如果用户明确否定了已有信息(如"我不再教高中物理了"),在输出中**移除**该信息
|
- 如果用户说的信息和已有数据一致,不需要输出变更
|
||||||
- 如果用户提到了新信息,**添加**到对应字段中
|
- 如果用户否定了已有数据中的某个值,输出 `remove` 操作
|
||||||
- 如果已有信息未被用户否定,**保留**在输出中
|
- 如果用户提到了新信息,输出 `set` 操作
|
||||||
- 标量字段(如 role、domain):如果用户提到了新值,用新值替换;否则保留已有值
|
|
||||||
- 最终输出应该是完整的、合并后的元数据,不是增量
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
**字段说明:**
|
**字段说明:**
|
||||||
- profile.role:用户的职业或角色,如 教师、医生、后端工程师
|
- profile.role:用户的职业或角色(列表),如 教师、医生、后端工程师,一个人可以有多个角色
|
||||||
- profile.domain:用户所在领域,如 教育、医疗、软件开发
|
- profile.domain:用户所在领域(列表),如 教育、医疗、软件开发,一个人可以涉及多个领域
|
||||||
- profile.expertise:用户擅长的技能或工具(通用,不限于编程),如 Python、心理咨询、高中物理
|
- profile.expertise:用户擅长的技能或工具(列表),如 Python、心理咨询、高中物理
|
||||||
- profile.interests:用户主动表达兴趣的话题或领域标签
|
- profile.interests:用户主动表达兴趣的话题或领域标签(列表)
|
||||||
- behavioral_hints.learning_stage:学习阶段(初学者/中级/高级)
|
|
||||||
- behavioral_hints.preferred_depth:偏好深度(概览/技术细节/深入探讨)
|
|
||||||
- behavioral_hints.tone_preference:语气偏好(轻松随意/专业简洁/学术严谨)
|
|
||||||
- knowledge_tags:用户涉及的知识领域标签
|
|
||||||
|
|
||||||
**用户别名变更(增量模式):**
|
**用户别名变更(增量模式):**
|
||||||
- **aliases_to_add**:本次新发现的用户别名,包括:
|
- **aliases_to_add**:本次新发现的用户别名,包括:
|
||||||
@@ -43,7 +51,6 @@ Extract user metadata from the following conversation statements spoken by the u
|
|||||||
- **aliases_to_remove**:用户明确否认的别名,包括:
|
- **aliases_to_remove**:用户明确否认的别名,包括:
|
||||||
* 用户说"我不叫XX了"、"别叫我XX"、"我改名了,不叫XX" → 将 XX 放入此数组
|
* 用户说"我不叫XX了"、"别叫我XX"、"我改名了,不叫XX" → 将 XX 放入此数组
|
||||||
* **严格限制**:只将用户原文中**逐字提到**的被否认名字放入,不要推断关联的其他别名
|
* **严格限制**:只将用户原文中**逐字提到**的被否认名字放入,不要推断关联的其他别名
|
||||||
* 例如:用户说"我不叫陈小刀了" → 只移除"陈小刀",不要移除"陈哥"、"老陈"等未被提及的别名
|
|
||||||
* 如果没有要移除的别名,返回空数组 `[]`
|
* 如果没有要移除的别名,返回空数组 `[]`
|
||||||
{% if existing_aliases %}
|
{% if existing_aliases %}
|
||||||
- 已有别名:{{ existing_aliases | tojson }}(仅供参考,不需要在输出中重复)
|
- 已有别名:{{ existing_aliases | tojson }}(仅供参考,不需要在输出中重复)
|
||||||
@@ -57,28 +64,36 @@ Extract user metadata from the following conversation statements spoken by the u
|
|||||||
**Extraction rules:**
|
**Extraction rules:**
|
||||||
- **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user
|
- **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user
|
||||||
- Only extract information explicitly mentioned in the text, do not speculate
|
- Only extract information explicitly mentioned in the text, do not speculate
|
||||||
- If no user profile information can be extracted, return an empty user_metadata object
|
|
||||||
- **Output language must match the input text language**
|
- **Output language must match the input text language**
|
||||||
|
|
||||||
|
**Incremental mode (important):**
|
||||||
|
You should only output **the change operations caused by this conversation**, not the complete metadata. Each change is an object containing:
|
||||||
|
- `field_path`: Field path separated by dots (e.g. `profile.role`, `profile.expertise`)
|
||||||
|
- `action`: Operation type
|
||||||
|
* `set`: Add or update a field value
|
||||||
|
* `remove`: Remove a field value
|
||||||
|
- `value`: The new value for the field (required when `action="set"`, for `action="remove"` fill in the element value to remove)
|
||||||
|
* All fields are list types, one change record per element
|
||||||
|
|
||||||
|
**Decision rules:**
|
||||||
|
- User mentions new information → `action="set"`, fill in the new value
|
||||||
|
- User explicitly negates existing info (e.g. "I'm no longer a teacher", "I stopped learning Python") → `action="remove"`, `value` is the element to remove
|
||||||
|
- If this conversation has no extractable changes, return an empty `metadata_changes` array `[]`
|
||||||
|
- **Do NOT generate any change operations for fields not mentioned in the conversation**
|
||||||
|
|
||||||
{% if existing_metadata %}
|
{% if existing_metadata %}
|
||||||
**Important: Merge with existing metadata**
|
**Existing metadata (for reference only, to determine if changes are needed):**
|
||||||
Existing user metadata from the database is provided below. Combine with the user's latest statements to output the **complete merged metadata**:
|
Compare existing data with the user's latest statements, and only output change operations for the differences.
|
||||||
- If the user explicitly negates existing info (e.g. "I no longer teach high school physics"), **remove** it from output
|
- If the user's statement matches existing data, no change is needed
|
||||||
- If the user mentions new info, **add** it to the corresponding field
|
- If the user negates a value in existing data, output a `remove` operation
|
||||||
- If existing info is not negated by the user, **keep** it in the output
|
- If the user mentions new information, output a `set` operation
|
||||||
- Scalar fields (e.g. role, domain): replace with new value if user mentions one; otherwise keep existing
|
|
||||||
- The final output should be the complete, merged metadata — not an incremental update
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
**Field descriptions:**
|
**Field descriptions:**
|
||||||
- profile.role: User's occupation or role, e.g. teacher, doctor, software engineer
|
- profile.role: User's occupation or role (list), e.g. teacher, doctor, software engineer. A person can have multiple roles
|
||||||
- profile.domain: User's domain, e.g. education, healthcare, software development
|
- profile.domain: User's domain (list), e.g. education, healthcare, software development. A person can span multiple domains
|
||||||
- profile.expertise: User's skills or tools (general, not limited to programming)
|
- profile.expertise: User's skills or tools (list), e.g. Python, counseling, physics
|
||||||
- profile.interests: Topics or domain tags the user actively expressed interest in
|
- profile.interests: Topics or domain tags the user actively expressed interest in (list)
|
||||||
- behavioral_hints.learning_stage: Learning stage (beginner/intermediate/advanced)
|
|
||||||
- behavioral_hints.preferred_depth: Preferred depth (overview/detailed/deep dive)
|
|
||||||
- behavioral_hints.tone_preference: Tone preference (casual/professional/academic)
|
|
||||||
- knowledge_tags: Knowledge domain tags related to the user
|
|
||||||
|
|
||||||
**User alias changes (incremental mode):**
|
**User alias changes (incremental mode):**
|
||||||
- **aliases_to_add**: Newly discovered user aliases from this conversation, including:
|
- **aliases_to_add**: Newly discovered user aliases from this conversation, including:
|
||||||
@@ -90,7 +105,6 @@ Existing user metadata from the database is provided below. Combine with the use
|
|||||||
- **aliases_to_remove**: Aliases the user explicitly denies, including:
|
- **aliases_to_remove**: Aliases the user explicitly denies, including:
|
||||||
* User says "Don't call me XX anymore", "I'm not called XX", "I changed my name from XX" → put XX in this array
|
* User says "Don't call me XX anymore", "I'm not called XX", "I changed my name from XX" → put XX in this array
|
||||||
* **Strict rule**: Only include the exact name the user **verbatim mentions** as denied. Do NOT infer or remove related aliases
|
* **Strict rule**: Only include the exact name the user **verbatim mentions** as denied. Do NOT infer or remove related aliases
|
||||||
* Example: User says "I'm not called John anymore" → only remove "John", do NOT remove "Johnny", "J" or other related aliases not mentioned
|
|
||||||
* If no aliases to remove, return empty array `[]`
|
* If no aliases to remove, return empty array `[]`
|
||||||
{% if existing_aliases %}
|
{% if existing_aliases %}
|
||||||
- Existing aliases: {{ existing_aliases | tojson }} (for reference only, do not repeat in output)
|
- Existing aliases: {{ existing_aliases | tojson }} (for reference only, do not repeat in output)
|
||||||
@@ -113,20 +127,11 @@ Existing user metadata from the database is provided below. Combine with the use
|
|||||||
Return a JSON object with the following structure:
|
Return a JSON object with the following structure:
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"user_metadata": {
|
"metadata_changes": [
|
||||||
"profile": {
|
{"field_path": "profile.role", "action": "set", "value": "后端工程师"},
|
||||||
"role": "",
|
{"field_path": "profile.expertise", "action": "set", "value": "Python"},
|
||||||
"domain": "",
|
{"field_path": "profile.expertise", "action": "remove", "value": "Java"}
|
||||||
"expertise": [],
|
],
|
||||||
"interests": []
|
|
||||||
},
|
|
||||||
"behavioral_hints": {
|
|
||||||
"learning_stage": "",
|
|
||||||
"preferred_depth": "",
|
|
||||||
"tone_preference": ""
|
|
||||||
},
|
|
||||||
"knowledge_tags": []
|
|
||||||
},
|
|
||||||
"aliases_to_add": [],
|
"aliases_to_add": [],
|
||||||
"aliases_to_remove": []
|
"aliases_to_remove": []
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Optional, TypeVar
|
from typing import Any, Dict, List, Optional, TypeVar
|
||||||
|
|
||||||
from langchain_aws import ChatBedrock
|
from langchain_aws import ChatBedrock
|
||||||
from langchain_community.chat_models import ChatTongyi
|
from langchain_community.chat_models import ChatTongyi
|
||||||
@@ -9,12 +9,12 @@ from langchain_core.embeddings import Embeddings
|
|||||||
from langchain_core.language_models import BaseLLM
|
from langchain_core.language_models import BaseLLM
|
||||||
from langchain_ollama import OllamaLLM
|
from langchain_ollama import OllamaLLM
|
||||||
from langchain_openai import ChatOpenAI, OpenAI
|
from langchain_openai import ChatOpenAI, OpenAI
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.models.models_model import ModelProvider, ModelType
|
from app.models.models_model import ModelProvider, ModelType
|
||||||
from app.core.models.volcano_chat import VolcanoChatOpenAI
|
from app.core.models.compatible_chat import CompatibleChatOpenAI
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@@ -25,10 +25,11 @@ class RedBearModelConfig(BaseModel):
|
|||||||
provider: str
|
provider: str
|
||||||
api_key: str
|
api_key: str
|
||||||
base_url: Optional[str] = None
|
base_url: Optional[str] = None
|
||||||
|
capability: List[str] = Field(default_factory=list) # 模型能力列表,驱动所有能力开关
|
||||||
is_omni: bool = False # 是否为 Omni 模型
|
is_omni: bool = False # 是否为 Omni 模型
|
||||||
deep_thinking: bool = False # 是否启用深度思考模式
|
deep_thinking: bool = False # 是否启用深度思考模式
|
||||||
thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算
|
thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算
|
||||||
support_thinking: bool = False # 模型是否支持 enable_thinking 参数(capability 含 thinking)
|
json_output: bool = False # 是否强制 JSON 输出
|
||||||
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
||||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||||
@@ -36,6 +37,23 @@ class RedBearModelConfig(BaseModel):
|
|||||||
concurrency: int = 5 # 并发限流
|
concurrency: int = 5 # 并发限流
|
||||||
extra_params: Dict[str, Any] = {}
|
extra_params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _resolve_capabilities(self) -> "RedBearModelConfig":
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
logger = get_business_logger()
|
||||||
|
if self.deep_thinking and "thinking" not in self.capability:
|
||||||
|
logger.warning(
|
||||||
|
f"模型 {self.model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking"
|
||||||
|
)
|
||||||
|
self.deep_thinking = False
|
||||||
|
self.thinking_budget_tokens = None
|
||||||
|
if self.json_output and "json_output" not in self.capability:
|
||||||
|
logger.warning(
|
||||||
|
f"模型 {self.model_name} 不支持 JSON 输出(capability 中无 'json_output'),已自动关闭 json_output"
|
||||||
|
)
|
||||||
|
self.json_output = False
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class RedBearModelFactory:
|
class RedBearModelFactory:
|
||||||
"""模型工厂类"""
|
"""模型工厂类"""
|
||||||
@@ -74,18 +92,19 @@ class RedBearModelFactory:
|
|||||||
is_streaming = bool(config.extra_params.get("streaming"))
|
is_streaming = bool(config.extra_params.get("streaming"))
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
params["stream_usage"] = True
|
params["stream_usage"] = True
|
||||||
# 只有支持 thinking 的模型才传 enable_thinking
|
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||||
if config.support_thinking:
|
if "thinking" in config.capability:
|
||||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
extra_body = params.setdefault("extra_body", {})
|
||||||
if is_streaming:
|
if config.deep_thinking:
|
||||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
extra_body["enable_thinking"] = False
|
||||||
if config.deep_thinking:
|
if is_streaming:
|
||||||
model_kwargs["incremental_output"] = True
|
extra_body["enable_thinking"] = True
|
||||||
if config.thinking_budget_tokens:
|
if config.thinking_budget_tokens:
|
||||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
||||||
else:
|
# JSON 输出模式
|
||||||
model_kwargs["enable_thinking"] = False
|
if config.json_output:
|
||||||
params["model_kwargs"] = model_kwargs
|
model_kwargs = params.setdefault("model_kwargs", {})
|
||||||
|
model_kwargs["response_format"] = {"type": "json_object"}
|
||||||
return params
|
return params
|
||||||
|
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||||
@@ -108,26 +127,31 @@ class RedBearModelFactory:
|
|||||||
**config.extra_params
|
**config.extra_params
|
||||||
}
|
}
|
||||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||||
if config.extra_params.get("streaming"):
|
|
||||||
params["stream_usage"] = True
|
|
||||||
# 深度思考模式
|
|
||||||
is_streaming = bool(config.extra_params.get("streaming"))
|
is_streaming = bool(config.extra_params.get("streaming"))
|
||||||
if is_streaming and not config.is_omni:
|
if is_streaming:
|
||||||
|
params["stream_usage"] = True
|
||||||
|
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||||
|
if "thinking" in config.capability:
|
||||||
|
# VOLCANO 深度思考仅流式支持
|
||||||
if provider == ModelProvider.VOLCANO:
|
if provider == ModelProvider.VOLCANO:
|
||||||
# 火山引擎深度思考仅流式调用支持,非流式时不传 thinking 参数
|
thinking_config: Dict[str, Any] = {"type": "enabled" if config.deep_thinking else "disabled"}
|
||||||
thinking_config: Dict[str, Any] = {
|
|
||||||
"type": "enabled" if config.deep_thinking else "disabled"
|
|
||||||
}
|
|
||||||
if config.deep_thinking and config.thinking_budget_tokens:
|
if config.deep_thinking and config.thinking_budget_tokens:
|
||||||
thinking_config["budget_tokens"] = config.thinking_budget_tokens
|
thinking_config["budget_tokens"] = config.thinking_budget_tokens
|
||||||
params["extra_body"] = {"thinking": thinking_config}
|
params["extra_body"] = {"thinking": thinking_config}
|
||||||
else:
|
else:
|
||||||
# 始终显式传递 enable_thinking,不支持该参数的模型(如 DeepSeek-R1)会直接忽略
|
extra_body = params.setdefault("extra_body", {})
|
||||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
if config.deep_thinking:
|
||||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
extra_body["enable_thinking"] = False
|
||||||
if config.deep_thinking and config.thinking_budget_tokens:
|
if is_streaming:
|
||||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
extra_body["enable_thinking"] = True
|
||||||
params["model_kwargs"] = model_kwargs
|
if config.thinking_budget_tokens:
|
||||||
|
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
||||||
|
# JSON 输出模式
|
||||||
|
if config.json_output:
|
||||||
|
model_kwargs = params.setdefault("model_kwargs", {})
|
||||||
|
# VOLCANO 模型不支持 response_format,JSON 输出由 system prompt 注入实现
|
||||||
|
if provider != ModelProvider.VOLCANO:
|
||||||
|
model_kwargs["response_format"] = {"type": "json_object"}
|
||||||
return params
|
return params
|
||||||
elif provider == ModelProvider.DASHSCOPE:
|
elif provider == ModelProvider.DASHSCOPE:
|
||||||
params = {
|
params = {
|
||||||
@@ -136,19 +160,20 @@ class RedBearModelFactory:
|
|||||||
"max_retries": config.max_retries,
|
"max_retries": config.max_retries,
|
||||||
**config.extra_params
|
**config.extra_params
|
||||||
}
|
}
|
||||||
# 只有支持 thinking 的模型才传 enable_thinking
|
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||||
if config.support_thinking:
|
if "thinking" in config.capability:
|
||||||
is_streaming = bool(config.extra_params.get("streaming"))
|
is_streaming = bool(config.extra_params.get("streaming"))
|
||||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
model_kwargs = params.setdefault("model_kwargs", {})
|
||||||
if is_streaming:
|
if config.deep_thinking:
|
||||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
|
||||||
if config.deep_thinking:
|
|
||||||
model_kwargs["incremental_output"] = True
|
|
||||||
if config.thinking_budget_tokens:
|
|
||||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
|
||||||
else:
|
|
||||||
model_kwargs["enable_thinking"] = False
|
model_kwargs["enable_thinking"] = False
|
||||||
params["model_kwargs"] = model_kwargs
|
if is_streaming:
|
||||||
|
model_kwargs["enable_thinking"] = True
|
||||||
|
model_kwargs["incremental_output"] = True
|
||||||
|
if config.thinking_budget_tokens:
|
||||||
|
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||||
|
if config.json_output:
|
||||||
|
model_kwargs = params.setdefault("model_kwargs", {})
|
||||||
|
model_kwargs["response_format"] = {"type": "json_object"}
|
||||||
return params
|
return params
|
||||||
elif provider == ModelProvider.BEDROCK:
|
elif provider == ModelProvider.BEDROCK:
|
||||||
# Bedrock 使用 AWS 凭证
|
# Bedrock 使用 AWS 凭证
|
||||||
@@ -195,6 +220,10 @@ class RedBearModelFactory:
|
|||||||
params["additional_model_request_fields"] = {
|
params["additional_model_request_fields"] = {
|
||||||
"thinking": {"type": "enabled", "budget_tokens": budget}
|
"thinking": {"type": "enabled", "budget_tokens": budget}
|
||||||
}
|
}
|
||||||
|
# JSON 输出模式
|
||||||
|
if config.json_output:
|
||||||
|
model_kwargs = params.setdefault("model_kwargs", {})
|
||||||
|
model_kwargs["response_format"] = {"type": "json_object"}
|
||||||
return params
|
return params
|
||||||
else:
|
else:
|
||||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||||
@@ -223,18 +252,19 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
|||||||
"""根据模型提供商获取对应的模型类"""
|
"""根据模型提供商获取对应的模型类"""
|
||||||
provider = config.provider.lower()
|
provider = config.provider.lower()
|
||||||
|
|
||||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
# dashscope的omni模型 和 volcano模型使用
|
||||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||||
return ChatOpenAI
|
return CompatibleChatOpenAI
|
||||||
if provider == ModelProvider.VOLCANO:
|
if provider == ModelProvider.VOLCANO:
|
||||||
return VolcanoChatOpenAI
|
return CompatibleChatOpenAI
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||||
if type == ModelType.LLM:
|
return CompatibleChatOpenAI
|
||||||
return OpenAI
|
# if type == ModelType.LLM:
|
||||||
elif type == ModelType.CHAT:
|
# return OpenAI
|
||||||
return ChatOpenAI
|
# elif type == ModelType.CHAT:
|
||||||
else:
|
# return CompatibleChatOpenAI
|
||||||
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
# else:
|
||||||
|
# raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||||
elif provider == ModelProvider.DASHSCOPE:
|
elif provider == ModelProvider.DASHSCOPE:
|
||||||
return ChatTongyi
|
return ChatTongyi
|
||||||
elif provider == ModelProvider.OLLAMA:
|
elif provider == ModelProvider.OLLAMA:
|
||||||
|
|||||||
@@ -8,12 +8,33 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
class VolcanoChatOpenAI(ChatOpenAI):
|
class CompatibleChatOpenAI(ChatOpenAI):
|
||||||
"""火山引擎 Chat 模型,支持深度思考内容(reasoning_content)的流式和非流式透传。"""
|
"""火山和千问的omni兼容模型,支持深度思考内容(reasoning_content)的流式和非流式透传。
|
||||||
|
|
||||||
|
同时修复 json_output + tools 同时使用时 langchain_openai 强制走 .parse()/.stream()
|
||||||
|
导致 strict 校验报错的问题:有工具时从 payload 中移除 response_format,
|
||||||
|
让父类走普通 .create()/.astream() 路径,JSON 输出由 system prompt 指令保证。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_request_payload(
|
||||||
|
self,
|
||||||
|
input_: list[BaseMessage],
|
||||||
|
*,
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> dict:
|
||||||
|
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||||
|
# 有工具时 langchain_openai 检测到 response_format 会切换到 .parse()/.stream()
|
||||||
|
# 接口,OpenAI SDK 要求此时所有工具必须 strict=True,动态生成的工具不满足。
|
||||||
|
# 移除 response_format,让父类走普通路径,JSON 输出由 system prompt 指令保证。
|
||||||
|
if payload.get("tools") and "response_format" in payload:
|
||||||
|
payload.pop("response_format")
|
||||||
|
return payload
|
||||||
|
|
||||||
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
|
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
|
||||||
result = super()._create_chat_result(response, generation_info)
|
result = super()._create_chat_result(response, generation_info)
|
||||||
@@ -6,7 +6,8 @@ models:
|
|||||||
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -20,6 +21,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -38,6 +40,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -54,7 +57,8 @@ models:
|
|||||||
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -72,6 +76,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -87,7 +92,8 @@ models:
|
|||||||
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -101,7 +107,8 @@ models:
|
|||||||
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -115,7 +122,8 @@ models:
|
|||||||
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -130,7 +138,8 @@ models:
|
|||||||
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -22,6 +23,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -36,6 +38,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -48,7 +51,8 @@ models:
|
|||||||
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -61,7 +65,8 @@ models:
|
|||||||
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -74,7 +79,8 @@ models:
|
|||||||
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -87,7 +93,8 @@ models:
|
|||||||
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -100,7 +107,8 @@ models:
|
|||||||
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -115,7 +123,8 @@ models:
|
|||||||
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -133,6 +142,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -150,6 +160,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -180,6 +191,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -210,7 +222,7 @@ models:
|
|||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -376,6 +388,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -448,6 +461,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -466,6 +480,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -481,7 +496,8 @@ models:
|
|||||||
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -498,6 +514,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -513,7 +530,7 @@ models:
|
|||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -530,6 +547,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -546,6 +564,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -561,7 +580,7 @@ models:
|
|||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -578,6 +597,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -594,6 +614,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -610,6 +631,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -626,6 +648,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -641,7 +664,7 @@ models:
|
|||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -656,7 +679,7 @@ models:
|
|||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -672,6 +695,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -687,6 +711,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -702,6 +727,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -719,6 +745,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -736,6 +763,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -752,6 +780,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -768,7 +797,7 @@ models:
|
|||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -785,6 +814,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -803,6 +833,8 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- audio
|
- audio
|
||||||
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: true
|
is_omni: true
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -822,7 +854,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -844,6 +876,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -864,7 +897,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -886,6 +919,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -907,6 +941,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -928,6 +963,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -947,6 +983,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -964,6 +1001,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -979,6 +1017,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -994,6 +1033,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- audio
|
- audio
|
||||||
- video
|
- video
|
||||||
|
- json_output
|
||||||
is_omni: true
|
is_omni: true
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -27,7 +28,8 @@ models:
|
|||||||
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -42,7 +44,8 @@ models:
|
|||||||
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -57,7 +60,8 @@ models:
|
|||||||
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -84,7 +88,8 @@ models:
|
|||||||
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -99,7 +104,8 @@ models:
|
|||||||
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -114,7 +120,8 @@ models:
|
|||||||
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -131,6 +138,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -146,7 +154,8 @@ models:
|
|||||||
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -163,6 +172,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -194,6 +204,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -213,6 +224,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -231,6 +243,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -248,6 +261,7 @@ models:
|
|||||||
is_official: true
|
is_official: true
|
||||||
capability:
|
capability:
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -266,6 +280,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -284,6 +299,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -302,6 +318,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -321,6 +338,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -340,6 +358,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -26,6 +27,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -41,6 +43,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -56,6 +59,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -72,6 +76,7 @@ models:
|
|||||||
capability:
|
capability:
|
||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -87,6 +92,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -102,6 +108,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -117,6 +124,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -132,6 +140,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -148,6 +157,7 @@ models:
|
|||||||
- vision
|
- vision
|
||||||
- video
|
- video
|
||||||
- thinking
|
- thinking
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -175,7 +185,8 @@ models:
|
|||||||
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
|
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
@@ -187,7 +198,8 @@ models:
|
|||||||
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||||
is_deprecated: false
|
is_deprecated: false
|
||||||
is_official: true
|
is_official: true
|
||||||
capability: []
|
capability:
|
||||||
|
- json_output
|
||||||
is_omni: false
|
is_omni: false
|
||||||
tags:
|
tags:
|
||||||
- 大语言模型
|
- 大语言模型
|
||||||
|
|||||||
485
api/app/core/quota_manager.py
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
"""
|
||||||
|
统一配额管理器 - 社区版和 SaaS 版共用
|
||||||
|
|
||||||
|
配额来源策略:
|
||||||
|
1. 优先从 premium 模块的 tenant_subscriptions 表读取(SaaS 版)
|
||||||
|
2. 降级到 default_free_plan.py 配置文件(社区版兜底)
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Optional, Callable, Dict, Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging_config import get_auth_logger
|
||||||
|
from app.i18n.exceptions import QuotaExceededError
|
||||||
|
|
||||||
|
logger = get_auth_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_from_kwargs(kwargs: dict):
|
||||||
|
"""从 kwargs 中获取 user 对象"""
|
||||||
|
for key in ["user", "current_user"]:
|
||||||
|
if key in kwargs:
|
||||||
|
return kwargs[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
|
||||||
|
"""从 kwargs 中获取 tenant_id"""
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if user and hasattr(user, 'tenant_id'):
|
||||||
|
return user.tenant_id
|
||||||
|
|
||||||
|
workspace_id = kwargs.get("workspace_id")
|
||||||
|
if workspace_id:
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||||
|
if workspace:
|
||||||
|
return workspace.tenant_id
|
||||||
|
|
||||||
|
api_key_auth = kwargs.get("api_key_auth")
|
||||||
|
if api_key_auth and hasattr(api_key_auth, 'workspace_id'):
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
workspace = db.query(Workspace).filter(Workspace.id == api_key_auth.workspace_id).first()
|
||||||
|
if workspace:
|
||||||
|
return workspace.tenant_id
|
||||||
|
|
||||||
|
data = kwargs.get("data") or kwargs.get("body") or kwargs.get("payload")
|
||||||
|
if data and hasattr(data, "workspace_id"):
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
workspace = db.query(Workspace).filter(Workspace.id == data.workspace_id).first()
|
||||||
|
if workspace:
|
||||||
|
return workspace.tenant_id
|
||||||
|
|
||||||
|
share_data = kwargs.get("share_data")
|
||||||
|
if share_data and hasattr(share_data, 'share_token'):
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
from app.models.app_model import App
|
||||||
|
share_token = share_data.share_token
|
||||||
|
from app.models.release_share_model import ReleaseShare
|
||||||
|
share_record = db.query(ReleaseShare).filter(ReleaseShare.share_token == share_token).first()
|
||||||
|
if share_record:
|
||||||
|
app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first()
|
||||||
|
if app:
|
||||||
|
return app.workspace.tenant_id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取租户的配额配置
|
||||||
|
|
||||||
|
优先级:
|
||||||
|
1. premium 模块的 tenant_subscriptions(SaaS 版)
|
||||||
|
2. default_free_plan.py 配置文件(社区版兜底)
|
||||||
|
"""
|
||||||
|
# 尝试从 premium 模块获取
|
||||||
|
try:
|
||||||
|
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||||
|
quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id)
|
||||||
|
if quota_config:
|
||||||
|
logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置")
|
||||||
|
return quota_config
|
||||||
|
except (ModuleNotFoundError, ImportError, Exception) as e:
|
||||||
|
logger.debug(f"无法从 premium 模块获取配额配置: {e}")
|
||||||
|
|
||||||
|
# 降级到配置文件
|
||||||
|
try:
|
||||||
|
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||||
|
logger.info(f"使用配置文件中的免费套餐配额: tenant={tenant_id}")
|
||||||
|
return DEFAULT_FREE_PLAN.get("quotas")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"无法从配置文件获取配额: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaUsageRepository:
|
||||||
|
"""配额使用量数据访问层"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def count_workspaces(self, tenant_id: UUID) -> int:
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
return self.db.query(Workspace).filter(
|
||||||
|
Workspace.tenant_id == tenant_id,
|
||||||
|
Workspace.is_active.is_(True)
|
||||||
|
).count()
|
||||||
|
|
||||||
|
def count_apps(self, tenant_id: UUID) -> int:
|
||||||
|
from app.models.app_model import App
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
return self.db.query(App).join(
|
||||||
|
Workspace, App.workspace_id == Workspace.id
|
||||||
|
).filter(
|
||||||
|
Workspace.tenant_id == tenant_id,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).count()
|
||||||
|
|
||||||
|
def count_skills(self, tenant_id: UUID) -> int:
|
||||||
|
from app.models.skill_model import Skill
|
||||||
|
return self.db.query(Skill).filter(
|
||||||
|
Skill.tenant_id == tenant_id,
|
||||||
|
Skill.is_active.is_(True)
|
||||||
|
).count()
|
||||||
|
|
||||||
|
def sum_knowledge_capacity_gb(self, tenant_id: UUID) -> float:
|
||||||
|
from app.models.document_model import Document
|
||||||
|
from app.models.knowledge_model import Knowledge
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
result = self.db.query(func.coalesce(func.sum(Document.file_size), 0)).join(
|
||||||
|
Knowledge, Document.kb_id == Knowledge.id
|
||||||
|
).join(
|
||||||
|
Workspace, Knowledge.workspace_id == Workspace.id
|
||||||
|
).filter(
|
||||||
|
Workspace.tenant_id == tenant_id,
|
||||||
|
Document.status == 1,
|
||||||
|
).scalar()
|
||||||
|
return float(result) / (1024 ** 3) if result else 0.0
|
||||||
|
|
||||||
|
def count_memory_engines(self, tenant_id: UUID) -> int:
|
||||||
|
from app.models.memory_config_model import MemoryConfig
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
return self.db.query(MemoryConfig).join(
|
||||||
|
Workspace, MemoryConfig.workspace_id == Workspace.id
|
||||||
|
).filter(
|
||||||
|
Workspace.tenant_id == tenant_id
|
||||||
|
).count()
|
||||||
|
|
||||||
|
def count_end_users(self, tenant_id: UUID) -> int:
|
||||||
|
from app.models.end_user_model import EndUser
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
return self.db.query(EndUser).join(
|
||||||
|
Workspace, EndUser.workspace_id == Workspace.id
|
||||||
|
).filter(
|
||||||
|
Workspace.tenant_id == tenant_id
|
||||||
|
).count()
|
||||||
|
|
||||||
|
def count_models(self, tenant_id: UUID) -> int:
|
||||||
|
from app.models.models_model import ModelConfig
|
||||||
|
return self.db.query(ModelConfig).filter(
|
||||||
|
ModelConfig.tenant_id == tenant_id,
|
||||||
|
ModelConfig.is_active == True
|
||||||
|
).count()
|
||||||
|
|
||||||
|
def count_ontology_projects(self, tenant_id: UUID) -> int:
|
||||||
|
from app.models.ontology_scene import OntologyScene
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
return self.db.query(OntologyScene).join(
|
||||||
|
Workspace, OntologyScene.workspace_id == Workspace.id
|
||||||
|
).filter(
|
||||||
|
Workspace.tenant_id == tenant_id
|
||||||
|
).count()
|
||||||
|
|
||||||
|
def get_usage_by_quota_type(self, tenant_id: UUID, quota_type: str):
|
||||||
|
"""按配额类型分发,返回当前使用量"""
|
||||||
|
dispatch = {
|
||||||
|
"workspace_quota": self.count_workspaces,
|
||||||
|
"app_quota": self.count_apps,
|
||||||
|
"skill_quota": self.count_skills,
|
||||||
|
"knowledge_capacity_quota": self.sum_knowledge_capacity_gb,
|
||||||
|
"memory_engine_quota": self.count_memory_engines,
|
||||||
|
"end_user_quota": self.count_end_users,
|
||||||
|
"model_quota": self.count_models,
|
||||||
|
"ontology_project_quota": self.count_ontology_projects,
|
||||||
|
}
|
||||||
|
fn = dispatch.get(quota_type)
|
||||||
|
return fn(tenant_id) if fn else 0
|
||||||
|
|
||||||
|
|
||||||
|
def _check_quota(
|
||||||
|
db: Session,
|
||||||
|
tenant_id: UUID,
|
||||||
|
quota_type: str,
|
||||||
|
resource_name: str,
|
||||||
|
usage_func: Optional[Callable] = None,
|
||||||
|
) -> None:
|
||||||
|
"""核心配额检查逻辑:对比使用量和配额限制"""
|
||||||
|
try:
|
||||||
|
quota_config = _get_quota_config(db, tenant_id)
|
||||||
|
if not quota_config:
|
||||||
|
logger.warning(f"租户 {tenant_id} 无有效配额配置,跳过配额检查")
|
||||||
|
return
|
||||||
|
|
||||||
|
quota_limit = quota_config.get(quota_type)
|
||||||
|
if quota_limit is None:
|
||||||
|
logger.warning(f"配额配置未包含 {quota_type},跳过配额检查")
|
||||||
|
return
|
||||||
|
|
||||||
|
if usage_func:
|
||||||
|
current_usage = usage_func(db, tenant_id)
|
||||||
|
else:
|
||||||
|
current_usage = QuotaUsageRepository(db).get_usage_by_quota_type(tenant_id, quota_type)
|
||||||
|
|
||||||
|
if current_usage >= quota_limit:
|
||||||
|
logger.warning(
|
||||||
|
f"配额不足: tenant={tenant_id}, type={quota_type}, "
|
||||||
|
f"usage={current_usage}, limit={quota_limit}"
|
||||||
|
)
|
||||||
|
raise QuotaExceededError(
|
||||||
|
resource=resource_name,
|
||||||
|
current_usage=current_usage,
|
||||||
|
quota_limit=quota_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"配额检查通过: tenant={tenant_id}, type={quota_type}, "
|
||||||
|
f"usage={current_usage}, limit={quota_limit}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"配额检查异常: tenant={tenant_id}, type={quota_type}, "
|
||||||
|
f"error_type={type(e).__name__}, error={str(e)}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# ─── 具名装饰器 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def check_workspace_quota(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def check_skill_quota(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
_check_quota(db, user.tenant_id, "skill_quota", "skill")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def check_app_quota(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
_check_quota(db, user.tenant_id, "app_quota", "app")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def check_knowledge_capacity_quota(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
if not db:
|
||||||
|
logger.warning("配额检查失败:缺少 db 参数")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||||
|
if not tenant_id:
|
||||||
|
logger.warning("配额检查失败:无法获取 tenant_id")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
_check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
_check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def check_memory_engine_quota(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def check_end_user_quota(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
if not db:
|
||||||
|
logger.warning("配额检查失败:缺少 db 参数")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||||
|
if not tenant_id:
|
||||||
|
logger.warning("配额检查失败:无法获取 tenant_id")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
_check_quota(db, tenant_id, "end_user_quota", "end_user")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
if not db:
|
||||||
|
logger.warning("配额检查失败:缺少 db 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||||
|
if not tenant_id:
|
||||||
|
logger.warning("配额检查失败:无法获取 tenant_id")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
_check_quota(db, tenant_id, "end_user_quota", "end_user")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def check_ontology_project_quota(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_quota(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_activation_quota(func: Callable) -> Callable:
|
||||||
|
"""模型激活时的配额检查装饰器"""
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
|
||||||
|
model_data = kwargs.get("model_data")
|
||||||
|
|
||||||
|
if not model_id or not model_data:
|
||||||
|
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
if model_data.is_active is True:
|
||||||
|
try:
|
||||||
|
from app.models.models_model import ModelConfig
|
||||||
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
|
existing_model = ModelConfigService.get_model_by_id(
|
||||||
|
db=db,
|
||||||
|
model_id=model_id,
|
||||||
|
tenant_id=user.tenant_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not existing_model.is_active:
|
||||||
|
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
|
||||||
|
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None):
|
||||||
|
"""通用配额检查装饰器,支持自定义使用量获取函数"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
# ─── 配额使用统计 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
|
||||||
|
"""获取租户所有配额的使用情况"""
|
||||||
|
quota_config = _get_quota_config(db, tenant_id)
|
||||||
|
if not quota_config:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
repo = QuotaUsageRepository(db)
|
||||||
|
|
||||||
|
def pct(used, limit):
|
||||||
|
return round(used / limit * 100, 1) if limit else None
|
||||||
|
|
||||||
|
workspace_count = repo.count_workspaces(tenant_id)
|
||||||
|
skill_count = repo.count_skills(tenant_id)
|
||||||
|
app_count = repo.count_apps(tenant_id)
|
||||||
|
knowledge_gb = repo.sum_knowledge_capacity_gb(tenant_id)
|
||||||
|
memory_count = repo.count_memory_engines(tenant_id)
|
||||||
|
end_user_count = repo.count_end_users(tenant_id)
|
||||||
|
model_count = repo.count_models(tenant_id)
|
||||||
|
ontology_count = repo.count_ontology_projects(tenant_id)
|
||||||
|
|
||||||
|
api_ops_current = 0
|
||||||
|
try:
|
||||||
|
from app.core.config import settings
|
||||||
|
import redis
|
||||||
|
_now = time.time()
|
||||||
|
_rk = f"rate_limit:tenant_qps:{tenant_id}"
|
||||||
|
_r = redis.StrictRedis(
|
||||||
|
host=settings.REDIS_HOST, port=settings.REDIS_PORT,
|
||||||
|
db=settings.REDIS_DB, password=settings.REDIS_PASSWORD,
|
||||||
|
decode_responses=True
|
||||||
|
)
|
||||||
|
api_ops_current = int(_r.zcount(_rk, _now - 1, "+inf"))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {
|
||||||
|
"workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))},
|
||||||
|
"skill": {"used": skill_count, "limit": quota_config.get("skill_quota"), "percentage": pct(skill_count, quota_config.get("skill_quota"))},
|
||||||
|
"app": {"used": app_count, "limit": quota_config.get("app_quota"), "percentage": pct(app_count, quota_config.get("app_quota"))},
|
||||||
|
"knowledge_capacity": {"used": round(knowledge_gb, 2), "limit": quota_config.get("knowledge_capacity_quota"), "percentage": pct(knowledge_gb, quota_config.get("knowledge_capacity_quota")), "unit": "GB"},
|
||||||
|
"memory_engine": {"used": memory_count, "limit": quota_config.get("memory_engine_quota"), "percentage": pct(memory_count, quota_config.get("memory_engine_quota"))},
|
||||||
|
"end_user": {"used": end_user_count, "limit": quota_config.get("end_user_quota"), "percentage": pct(end_user_count, quota_config.get("end_user_quota"))},
|
||||||
|
"ontology_project": {"used": ontology_count, "limit": quota_config.get("ontology_project_quota"), "percentage": pct(ontology_count, quota_config.get("ontology_project_quota"))},
|
||||||
|
"model": {"used": model_count, "limit": quota_config.get("model_quota"), "percentage": pct(model_count, quota_config.get("model_quota"))},
|
||||||
|
"api_ops_rate_limit": {"current": api_ops_current, "limit": quota_config.get("api_ops_rate_limit"), "percentage": None, "unit": "次/秒"},
|
||||||
|
}
|
||||||
36
api/app/core/quota_stub.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""
|
||||||
|
配额检查 stub - 社区版和 SaaS 版统一使用 core.quota_manager 实现
|
||||||
|
|
||||||
|
所有配额检查逻辑统一在 core 层实现,两个版本共用:
|
||||||
|
- 社区版:从 default_free_plan.py 读取配额限制
|
||||||
|
- SaaS 版:优先从 tenant_subscriptions 表读取,降级到配置文件
|
||||||
|
"""
|
||||||
|
from app.core.quota_manager import (
|
||||||
|
check_workspace_quota,
|
||||||
|
check_skill_quota,
|
||||||
|
check_app_quota,
|
||||||
|
check_knowledge_capacity_quota,
|
||||||
|
check_memory_engine_quota,
|
||||||
|
check_end_user_quota,
|
||||||
|
check_ontology_project_quota,
|
||||||
|
check_model_quota,
|
||||||
|
check_model_activation_quota,
|
||||||
|
get_quota_usage,
|
||||||
|
_check_quota,
|
||||||
|
QuotaUsageRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"check_workspace_quota",
|
||||||
|
"check_skill_quota",
|
||||||
|
"check_app_quota",
|
||||||
|
"check_knowledge_capacity_quota",
|
||||||
|
"check_memory_engine_quota",
|
||||||
|
"check_end_user_quota",
|
||||||
|
"check_ontology_project_quota",
|
||||||
|
"check_model_quota",
|
||||||
|
"check_model_activation_quota",
|
||||||
|
"get_quota_usage",
|
||||||
|
"_check_quota",
|
||||||
|
"QuotaUsageRepository",
|
||||||
|
]
|
||||||
@@ -672,10 +672,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||||||
excel_parser = ExcelParser()
|
excel_parser = ExcelParser()
|
||||||
if parser_config.get("html4excel") and parser_config.get("html4excel").lower() == "true":
|
if parser_config.get("html4excel") and parser_config.get("html4excel").lower() == "true":
|
||||||
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
|
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
|
||||||
parser_config["chunk_token_num"] = 0
|
|
||||||
else:
|
else:
|
||||||
sections = [(_, "") for _ in excel_parser(binary) if _]
|
sections = [(_, "") for _ in excel_parser(binary) if _]
|
||||||
parser_config["chunk_token_num"] = 12800
|
callback(0.8, "Finish parsing.")
|
||||||
|
# Excel 每行直接作为一个 chunk,不经过 naive_merge 避免被 delimiter 拆分
|
||||||
|
chunks = [s for s, _ in sections]
|
||||||
|
res.extend(tokenize_chunks(chunks, doc, is_english, None))
|
||||||
|
res.extend(embed_res)
|
||||||
|
res.extend(url_res)
|
||||||
|
return res
|
||||||
|
|
||||||
elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE):
|
elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE):
|
||||||
callback(0.1, "Start to parse.")
|
callback(0.1, "Start to parse.")
|
||||||
|
|||||||
@@ -232,14 +232,14 @@ class RAGExcelParser:
|
|||||||
t = str(ti[i].value) if i < len(ti) else ""
|
t = str(ti[i].value) if i < len(ti) else ""
|
||||||
t += (":" if t else "") + str(c.value)
|
t += (":" if t else "") + str(c.value)
|
||||||
fields.append(t)
|
fields.append(t)
|
||||||
line = "; ".join(fields)
|
line = "\n".join(fields)
|
||||||
if sheetname.lower().find("sheet") < 0:
|
if sheetname.lower().find("sheet") < 0:
|
||||||
line += " ——" + sheetname
|
line += "\n——" + sheetname
|
||||||
res.append(line)
|
res.append(line)
|
||||||
else:
|
else:
|
||||||
# 只有表头的情况
|
# 只有表头的情况
|
||||||
if header_fields:
|
if header_fields:
|
||||||
line = "; ".join(header_fields)
|
line = "\n".join(header_fields)
|
||||||
if sheetname.lower().find("sheet") < 0:
|
if sheetname.lower().find("sheet") < 0:
|
||||||
line += " ——" + sheetname
|
line += " ——" + sheetname
|
||||||
res.append(line)
|
res.append(line)
|
||||||
|
|||||||
@@ -50,7 +50,9 @@ class OpenAIEmbed(Base):
|
|||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
# OpenAI requires batch size <=16
|
# OpenAI requires batch size <=16
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
texts = [truncate(t, 8191) for t in texts]
|
# Use 8000 instead of 8191 to leave safety margin for tokenizer differences
|
||||||
|
# between cl100k_base (used by truncate) and the actual embedding model
|
||||||
|
texts = [truncate(t, 8000) for t in texts]
|
||||||
ress = []
|
ress = []
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
@@ -63,7 +65,7 @@ class OpenAIEmbed(Base):
|
|||||||
return np.array(ress), total_tokens
|
return np.array(ress), total_tokens
|
||||||
|
|
||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
||||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||||
|
|
||||||
|
|
||||||
@@ -79,6 +81,7 @@ class LocalAIEmbed(Base):
|
|||||||
|
|
||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
|
texts = [truncate(t, 8000) for t in texts]
|
||||||
ress = []
|
ress = []
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||||
@@ -173,6 +176,7 @@ class XinferenceEmbed(Base):
|
|||||||
|
|
||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
|
texts = [truncate(t, 8000) for t in texts]
|
||||||
ress = []
|
ress = []
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
@@ -188,7 +192,7 @@ class XinferenceEmbed(Base):
|
|||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
res = None
|
res = None
|
||||||
try:
|
try:
|
||||||
res = self.client.embeddings.create(input=[text], model=self.model_name)
|
res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name)
|
||||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
log_exception(_e, res)
|
log_exception(_e, res)
|
||||||
|
|||||||
@@ -253,9 +253,9 @@ class DateTimeTool(BuiltinTool):
|
|||||||
return {
|
return {
|
||||||
"datetime": input_value,
|
"datetime": input_value,
|
||||||
"timezone": timezone_str,
|
"timezone": timezone_str,
|
||||||
"timestamp": int(dt.timestamp()) * 1000,
|
"timestamp": int(dt.timestamp() * 1000),
|
||||||
"iso_format": dt.isoformat(),
|
"iso_format": dt.isoformat(),
|
||||||
"result_data": int(dt.timestamp()) * 1000
|
"result_data": int(dt.timestamp() * 1000)
|
||||||
}
|
}
|
||||||
|
|
||||||
def _calculate_datetime(self, kwargs) -> dict:
|
def _calculate_datetime(self, kwargs) -> dict:
|
||||||
|
|||||||
@@ -201,12 +201,15 @@ class VariablePool:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_field(struct: "VariableStruct", field: str | None) -> Any:
|
def _extract_field(struct: "VariableStruct", field: str | None) -> Any:
|
||||||
"""If field is given, drill into a dict/object variable's value."""
|
"""If field is given, drill into a dict/object/array[file] variable's value."""
|
||||||
if field is None:
|
if field is None:
|
||||||
return struct.instance.get_value()
|
return struct.instance.get_value()
|
||||||
value = struct.instance.get_value()
|
value = struct.instance.get_value()
|
||||||
|
# array[file]: extract the field from every element, return a list
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [item.get(field) if isinstance(item, dict) else getattr(item, field, None) for item in value]
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
raise KeyError(f"Variable is not an object, cannot access field '{field}'")
|
raise KeyError(f"Variable is not an object or array, cannot access field '{field}'")
|
||||||
return value.get(field)
|
return value.get(field)
|
||||||
|
|
||||||
def get_instance(
|
def get_instance(
|
||||||
|
|||||||
@@ -28,86 +28,135 @@ class IterationRuntime:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
start_id: str,
|
|
||||||
stream: bool,
|
stream: bool,
|
||||||
graph: CompiledStateGraph,
|
|
||||||
node_id: str,
|
node_id: str,
|
||||||
config: dict[str, Any],
|
config: dict[str, Any],
|
||||||
state: WorkflowState,
|
state: WorkflowState,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
child_variable_pool: VariablePool,
|
cycle_nodes: list,
|
||||||
|
cycle_edges: list,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the iteration runtime.
|
Initialize the iteration runtime.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph: Compiled workflow graph capable of async invocation.
|
stream: Whether to run in streaming mode. When True, each iteration
|
||||||
node_id: Unique identifier of the loop node.
|
uses graph.astream and emits cycle_item events in real time.
|
||||||
config: Dictionary containing iteration node configuration.
|
When False, graph.ainvoke is used instead.
|
||||||
state: Current workflow state at the point of iteration.
|
node_id: The unique identifier of the iteration node in the workflow.
|
||||||
|
Also used as the variable namespace for item/index inside
|
||||||
|
the subgraph (e.g. {{ node_id.item }}).
|
||||||
|
config: Raw configuration dict for the iteration node, parsed into
|
||||||
|
IterationNodeConfig. Controls input/output variable selectors,
|
||||||
|
parallel execution settings, and output flattening.
|
||||||
|
state: The parent workflow state at the point the iteration node is
|
||||||
|
entered. Each task receives a copy of this state as its
|
||||||
|
starting point.
|
||||||
|
variable_pool: The parent VariablePool containing all variables available
|
||||||
|
at the time the iteration node executes, including sys.*,
|
||||||
|
conv.*, and outputs from upstream nodes. Used as the source
|
||||||
|
for deep-copying into each task's independent child pool.
|
||||||
|
cycle_nodes: List of node config dicts belonging to this iteration's
|
||||||
|
subgraph (i.e. nodes whose cycle field equals node_id).
|
||||||
|
Passed to GraphBuilder when constructing each task's subgraph.
|
||||||
|
cycle_edges: List of edge config dicts connecting nodes within the subgraph.
|
||||||
|
Passed to GraphBuilder alongside cycle_nodes.
|
||||||
"""
|
"""
|
||||||
self.start_id = start_id
|
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.graph = graph
|
|
||||||
self.state = state
|
self.state = state
|
||||||
self.node_id = node_id
|
self.node_id = node_id
|
||||||
self.typed_config = IterationNodeConfig(**config)
|
self.typed_config = IterationNodeConfig(**config)
|
||||||
self.looping = True
|
self.looping = True
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
self.child_variable_pool = child_variable_pool
|
self.cycle_nodes = cycle_nodes
|
||||||
|
self.cycle_edges = cycle_edges
|
||||||
self.event_write = get_stream_writer()
|
self.event_write = get_stream_writer()
|
||||||
self.checkpoint = RunnableConfig(
|
|
||||||
configurable={
|
|
||||||
"thread_id": uuid.uuid4()
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.output_value = None
|
self.output_value = None
|
||||||
self.result: list = []
|
self.result: list = []
|
||||||
|
|
||||||
async def _init_iteration_state(self, item, idx):
|
def _build_child_graph(self) -> tuple[CompiledStateGraph, VariablePool, str]:
|
||||||
"""
|
"""
|
||||||
Initialize a per-iteration copy of the workflow state.
|
Build an independent compiled subgraph for a single iteration task.
|
||||||
|
|
||||||
Args:
|
Each call creates a brand-new VariablePool by deep-copying the parent pool,
|
||||||
item: Current element from the input array for this iteration.
|
then passes it to GraphBuilder. GraphBuilder binds this pool to every node's
|
||||||
idx: Index of the element in the input array.
|
execution closure at build time, so the pool and the subgraph always reference
|
||||||
|
the same object. This is the key design invariant: item/index written into the
|
||||||
|
pool after build will be visible to all nodes inside the subgraph.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A copy of the workflow state with iteration-specific variables set.
|
graph: The compiled LangGraph subgraph ready for invocation.
|
||||||
|
child_pool: The VariablePool bound to this subgraph's node closures.
|
||||||
|
Callers must write item/index into this pool before invoking
|
||||||
|
the graph, and read output from it after invocation.
|
||||||
|
start_node_id: The ID of the CYCLE_START node inside the subgraph,
|
||||||
|
used to set the initial activation signal in workflow state.
|
||||||
"""
|
"""
|
||||||
loopstate = WorkflowState(
|
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||||
**self.state
|
child_pool = VariablePool()
|
||||||
|
child_pool.copy(self.variable_pool)
|
||||||
|
builder = GraphBuilder(
|
||||||
|
{"nodes": self.cycle_nodes, "edges": self.cycle_edges},
|
||||||
|
stream=self.stream,
|
||||||
|
variable_pool=child_pool,
|
||||||
|
cycle=self.node_id,
|
||||||
)
|
)
|
||||||
self.child_variable_pool.copy(self.variable_pool)
|
graph = builder.build()
|
||||||
await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
|
return graph, builder.variable_pool, builder.start_node_id
|
||||||
await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True)
|
|
||||||
loopstate["node_outputs"][self.node_id] = {
|
async def _init_iteration_state(self, item, idx, child_pool: VariablePool, start_id: str):
|
||||||
"item": item,
|
"""
|
||||||
"index": idx,
|
Initialize the workflow state for a single iteration.
|
||||||
}
|
|
||||||
|
Writes the current item and its index into child_pool under the iteration
|
||||||
|
node's namespace (e.g. iteration_xxx.item, iteration_xxx.index), making them
|
||||||
|
accessible to downstream nodes inside the subgraph via variable selectors.
|
||||||
|
|
||||||
|
Also prepares a copy of the parent workflow state with:
|
||||||
|
- node_outputs[node_id] set to {item, index} so the state snapshot is consistent
|
||||||
|
with the pool values.
|
||||||
|
- looping flag set to 1 (active) to signal the subgraph is inside a cycle.
|
||||||
|
- activate[start_id] set to True to trigger the CYCLE_START node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item: The current element from the input array.
|
||||||
|
idx: The zero-based index of this element in the input array.
|
||||||
|
child_pool: The VariablePool bound to this iteration's subgraph.
|
||||||
|
Must be the same object returned by _build_child_graph.
|
||||||
|
start_id: The ID of the CYCLE_START node inside the subgraph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A WorkflowState instance ready to be passed to graph.ainvoke or graph.astream.
|
||||||
|
"""
|
||||||
|
loopstate = WorkflowState(**self.state)
|
||||||
|
await child_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
|
||||||
|
await child_pool.new(self.node_id, "index", idx, VariableType.type_map(idx), mut=True)
|
||||||
|
loopstate["node_outputs"][self.node_id] = {"item": item, "index": idx}
|
||||||
loopstate["looping"] = 1
|
loopstate["looping"] = 1
|
||||||
loopstate["activate"][self.start_id] = True
|
loopstate["activate"][start_id] = True
|
||||||
return loopstate
|
return loopstate
|
||||||
|
|
||||||
def merge_conv_vars(self):
|
def _merge_conv_vars(self, child_pool: VariablePool):
|
||||||
self.variable_pool.variables["conv"].update(
|
self.variable_pool.variables["conv"].update(child_pool.variables["conv"])
|
||||||
self.child_variable_pool.variables["conv"]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run_task(self, item, idx):
|
async def run_task(self, item, idx):
|
||||||
"""
|
"""
|
||||||
Execute a single iteration asynchronously.
|
Execute a single iteration asynchronously.
|
||||||
|
Each task builds its own subgraph so the variable pool closure is independent.
|
||||||
|
|
||||||
Args:
|
Returns:
|
||||||
item: The input element for this iteration.
|
Tuple of (idx, output, result, child_pool, stopped)
|
||||||
idx: The index of this iteration.
|
|
||||||
"""
|
"""
|
||||||
|
graph, child_pool, start_id = self._build_child_graph()
|
||||||
|
checkpoint = RunnableConfig(configurable={"thread_id": uuid.uuid4()})
|
||||||
|
init_state = await self._init_iteration_state(item, idx, child_pool, start_id)
|
||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
async for event in self.graph.astream(
|
async for event in graph.astream(
|
||||||
await self._init_iteration_state(item, idx),
|
init_state,
|
||||||
stream_mode=["debug"],
|
stream_mode=["debug"],
|
||||||
config=self.checkpoint
|
config=checkpoint
|
||||||
):
|
):
|
||||||
if isinstance(event, tuple) and len(event) == 2:
|
if isinstance(event, tuple) and len(event) == 2:
|
||||||
mode, data = event
|
mode, data = event
|
||||||
@@ -117,7 +166,6 @@ class IterationRuntime:
|
|||||||
event_type = data.get("type")
|
event_type = data.get("type")
|
||||||
payload = data.get("payload", {})
|
payload = data.get("payload", {})
|
||||||
node_name = payload.get("name")
|
node_name = payload.get("name")
|
||||||
|
|
||||||
if node_name and node_name.startswith("nop"):
|
if node_name and node_name.startswith("nop"):
|
||||||
continue
|
continue
|
||||||
if event_type == "task_result":
|
if event_type == "task_result":
|
||||||
@@ -140,17 +188,13 @@ class IterationRuntime:
|
|||||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
result = self.graph.get_state(config=self.checkpoint).values
|
result = graph.get_state(config=checkpoint).values
|
||||||
else:
|
else:
|
||||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
result = await graph.ainvoke(init_state)
|
||||||
output = self.child_variable_pool.get_value(self.output_value)
|
|
||||||
if isinstance(output, list) and self.typed_config.flatten:
|
output = child_pool.get_value(self.output_value)
|
||||||
self.result.extend(output)
|
stopped = result["looping"] == 2
|
||||||
else:
|
return idx, output, result, child_pool, stopped
|
||||||
self.result.append(output)
|
|
||||||
if result["looping"] == 2:
|
|
||||||
self.looping = False
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _create_iteration_tasks(self, array_obj, idx):
|
def _create_iteration_tasks(self, array_obj, idx):
|
||||||
"""
|
"""
|
||||||
@@ -196,16 +240,32 @@ class IterationRuntime:
|
|||||||
tasks = self._create_iteration_tasks(array_obj, idx)
|
tasks = self._create_iteration_tasks(array_obj, idx)
|
||||||
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
||||||
idx += self.typed_config.parallel_count
|
idx += self.typed_config.parallel_count
|
||||||
child_state.extend(await asyncio.gather(*tasks))
|
batch = await asyncio.gather(*tasks)
|
||||||
self.merge_conv_vars()
|
# Sort by idx to preserve order, then collect results
|
||||||
|
batch_sorted = sorted(batch, key=lambda x: x[0])
|
||||||
|
for _, output, result, child_pool, stopped in batch_sorted:
|
||||||
|
if isinstance(output, list) and self.typed_config.flatten:
|
||||||
|
self.result.extend(output)
|
||||||
|
else:
|
||||||
|
self.result.append(output)
|
||||||
|
child_state.append(result)
|
||||||
|
self._merge_conv_vars(child_pool)
|
||||||
|
if stopped:
|
||||||
|
self.looping = False
|
||||||
else:
|
else:
|
||||||
# Execute iterations sequentially
|
# Execute iterations sequentially
|
||||||
while idx < len(array_obj) and self.looping:
|
while idx < len(array_obj) and self.looping:
|
||||||
logger.info(f"Iteration node {self.node_id}: running")
|
logger.info(f"Iteration node {self.node_id}: running")
|
||||||
item = array_obj[idx]
|
item = array_obj[idx]
|
||||||
result = await self.run_task(item, idx)
|
_, output, result, child_pool, stopped = await self.run_task(item, idx)
|
||||||
self.merge_conv_vars()
|
if isinstance(output, list) and self.typed_config.flatten:
|
||||||
|
self.result.extend(output)
|
||||||
|
else:
|
||||||
|
self.result.append(output)
|
||||||
|
self._merge_conv_vars(child_pool)
|
||||||
child_state.append(result)
|
child_state.append(result)
|
||||||
|
if stopped:
|
||||||
|
self.looping = False
|
||||||
idx += 1
|
idx += 1
|
||||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
|
|
||||||
return cycle_nodes, cycle_edges
|
return cycle_nodes, cycle_edges
|
||||||
|
|
||||||
def build_graph(self):
|
def build_graph(self, variable_pool: VariablePool):
|
||||||
"""
|
"""
|
||||||
Build and compile the internal subgraph for this cycle node.
|
Build and compile the internal subgraph for this cycle node.
|
||||||
|
|
||||||
@@ -135,6 +135,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||||
|
|
||||||
self.child_variable_pool = VariablePool()
|
self.child_variable_pool = VariablePool()
|
||||||
|
self.child_variable_pool.copy(variable_pool)
|
||||||
builder = GraphBuilder(
|
builder = GraphBuilder(
|
||||||
{
|
{
|
||||||
"nodes": self.cycle_nodes,
|
"nodes": self.cycle_nodes,
|
||||||
@@ -165,8 +166,8 @@ class CycleGraphNode(BaseNode):
|
|||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If the node type is unsupported.
|
RuntimeError: If the node type is unsupported.
|
||||||
"""
|
"""
|
||||||
self.build_graph()
|
|
||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
|
self.build_graph(variable_pool)
|
||||||
return await LoopRuntime(
|
return await LoopRuntime(
|
||||||
start_id=self.start_node_id,
|
start_id=self.start_node_id,
|
||||||
stream=False,
|
stream=False,
|
||||||
@@ -179,20 +180,19 @@ class CycleGraphNode(BaseNode):
|
|||||||
).run()
|
).run()
|
||||||
if self.node_type == NodeType.ITERATION:
|
if self.node_type == NodeType.ITERATION:
|
||||||
return await IterationRuntime(
|
return await IterationRuntime(
|
||||||
start_id=self.start_node_id,
|
|
||||||
stream=False,
|
stream=False,
|
||||||
graph=self.graph,
|
|
||||||
node_id=self.node_id,
|
node_id=self.node_id,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
state=state,
|
state=state,
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
child_variable_pool=self.child_variable_pool
|
cycle_nodes=self.cycle_nodes,
|
||||||
|
cycle_edges=self.cycle_edges,
|
||||||
).run()
|
).run()
|
||||||
raise RuntimeError("Unknown cycle node type")
|
raise RuntimeError("Unknown cycle node type")
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||||
self.build_graph()
|
|
||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
|
self.build_graph(variable_pool)
|
||||||
yield {
|
yield {
|
||||||
"__final__": True,
|
"__final__": True,
|
||||||
"result": await LoopRuntime(
|
"result": await LoopRuntime(
|
||||||
@@ -211,14 +211,13 @@ class CycleGraphNode(BaseNode):
|
|||||||
yield {
|
yield {
|
||||||
"__final__": True,
|
"__final__": True,
|
||||||
"result": await IterationRuntime(
|
"result": await IterationRuntime(
|
||||||
start_id=self.start_node_id,
|
|
||||||
stream=True,
|
stream=True,
|
||||||
graph=self.graph,
|
|
||||||
node_id=self.node_id,
|
node_id=self.node_id,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
state=state,
|
state=state,
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
child_variable_pool=self.child_variable_pool
|
cycle_nodes=self.cycle_nodes,
|
||||||
|
cycle_edges=self.cycle_edges,
|
||||||
).run()
|
).run()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -72,8 +72,9 @@ class HttpContentTypeConfig(BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def validate_data(cls, v, info):
|
def validate_data(cls, v, info):
|
||||||
content_type = info.data.get("content_type")
|
content_type = info.data.get("content_type")
|
||||||
if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData):
|
if content_type == HttpContentType.FROM_DATA and (
|
||||||
raise ValueError("When content_type is 'form-data', data must be of type HttpFormData")
|
not isinstance(v, list) or not all(isinstance(item, HttpFormData) for item in v)):
|
||||||
|
raise ValueError("When content_type is 'form-data', data must be a list of HttpFormData")
|
||||||
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
|
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
|
||||||
raise ValueError("When content_type is JSON, data must be of type str")
|
raise ValueError("When content_type is JSON, data must be of type str")
|
||||||
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):
|
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):
|
||||||
|
|||||||
@@ -260,17 +260,22 @@ class HttpRequestNode(BaseNode):
|
|||||||
))
|
))
|
||||||
case HttpContentType.FROM_DATA:
|
case HttpContentType.FROM_DATA:
|
||||||
data = {}
|
data = {}
|
||||||
content["files"] = {}
|
files = []
|
||||||
for item in self.typed_config.body.data:
|
for item in self.typed_config.body.data:
|
||||||
|
key = self._render_template(item.key, variable_pool)
|
||||||
if item.type == "text":
|
if item.type == "text":
|
||||||
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
|
data[key] = self._render_template(item.value, variable_pool)
|
||||||
variable_pool)
|
|
||||||
elif item.type == "file":
|
elif item.type == "file":
|
||||||
content["files"][self._render_template(item.key, variable_pool)] = (
|
file_instance = variable_pool.get_instance(item.value)
|
||||||
uuid.uuid4().hex,
|
if isinstance(file_instance, ArrayVariable):
|
||||||
await variable_pool.get_instance(item.value).get_content()
|
for v in file_instance.value:
|
||||||
)
|
if isinstance(v, FileVariable):
|
||||||
|
files.append((key, (uuid.uuid4().hex, await v.get_content())))
|
||||||
|
elif isinstance(file_instance, FileVariable):
|
||||||
|
files.append((key, (uuid.uuid4().hex, await file_instance.get_content())))
|
||||||
content["data"] = data
|
content["data"] = data
|
||||||
|
if files:
|
||||||
|
content["files"] = files
|
||||||
case HttpContentType.BINARY:
|
case HttpContentType.BINARY:
|
||||||
content["files"] = []
|
content["files"] = []
|
||||||
file_instence = variable_pool.get_instance(self.typed_config.body.data)
|
file_instence = variable_pool.get_instance(self.typed_config.body.data)
|
||||||
|
|||||||
@@ -6,6 +6,30 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig
|
|||||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||||
|
|
||||||
|
|
||||||
|
class SubVariableConditionItem(BaseModel):
|
||||||
|
"""A single condition on a file object's field, used inside sub_variable_condition."""
|
||||||
|
key: str = Field(..., description="Field name of the file object, e.g. type, size, name")
|
||||||
|
operator: ComparisonOperator = Field(..., description="Comparison operator")
|
||||||
|
value: Any = Field(default=None, description="Value to compare with, or variable selector when input_type=variable")
|
||||||
|
input_type: ValueInputType = Field(default=ValueInputType.CONSTANT, description="constant or variable")
|
||||||
|
|
||||||
|
@field_validator("input_type", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def lower_input_type(cls, v):
|
||||||
|
if isinstance(v, str):
|
||||||
|
try:
|
||||||
|
return ValueInputType(v.lower())
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"Invalid input_type: {v}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class SubVariableCondition(BaseModel):
|
||||||
|
"""Sub-conditions applied to each file element in an array[file] variable."""
|
||||||
|
logical_operator: LogicOperator = Field(default=LogicOperator.AND)
|
||||||
|
conditions: list[SubVariableConditionItem] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ConditionDetail(BaseModel):
|
class ConditionDetail(BaseModel):
|
||||||
operator: ComparisonOperator = Field(
|
operator: ComparisonOperator = Field(
|
||||||
...,
|
...,
|
||||||
@@ -14,12 +38,12 @@ class ConditionDetail(BaseModel):
|
|||||||
|
|
||||||
left: str = Field(
|
left: str = Field(
|
||||||
...,
|
...,
|
||||||
description="Value to compare against"
|
description="Variable selector, e.g. {{sys.files}}"
|
||||||
)
|
)
|
||||||
|
|
||||||
right: Any = Field(
|
right: Any = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Value to compare with"
|
description="Value to compare with (unused when sub_variable_condition is set)"
|
||||||
)
|
)
|
||||||
|
|
||||||
input_type: ValueInputType = Field(
|
input_type: ValueInputType = Field(
|
||||||
@@ -27,6 +51,11 @@ class ConditionDetail(BaseModel):
|
|||||||
description="Value input type for comparison"
|
description="Value input type for comparison"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sub_variable_condition: SubVariableCondition | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Sub-conditions for array[file] fields. When set, operator must be contains/not_contains."
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("input_type", mode="before")
|
@field_validator("input_type", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def lower_input_type(cls, v):
|
def lower_input_type(cls, v):
|
||||||
@@ -39,16 +68,19 @@ class ConditionDetail(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ConditionBranchConfig(BaseModel):
|
class ConditionBranchConfig(BaseModel):
|
||||||
"""Configuration for a conditional branch"""
|
"""Configuration for a conditional branch.
|
||||||
|
|
||||||
|
logical_operator controls how all expressions are combined (AND/OR).
|
||||||
|
"""
|
||||||
|
|
||||||
logical_operator: LogicOperator = Field(
|
logical_operator: LogicOperator = Field(
|
||||||
default=LogicOperator.AND,
|
default=LogicOperator.AND,
|
||||||
description="Logical operator used to combine multiple condition expressions"
|
description="Logical operator used to combine all conditions"
|
||||||
)
|
)
|
||||||
|
|
||||||
expressions: list[ConditionDetail] = Field(
|
expressions: list[ConditionDetail] = Field(
|
||||||
...,
|
default_factory=list,
|
||||||
description="List of condition expressions within this branch"
|
description="List of conditions within this branch"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance, ArrayFileContainsOperator
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -90,11 +90,9 @@ class IfElseNode(BaseNode):
|
|||||||
list[str]: A list of Python boolean expression strings,
|
list[str]: A list of Python boolean expression strings,
|
||||||
ordered by branch priority.
|
ordered by branch priority.
|
||||||
"""
|
"""
|
||||||
branch_index = 0
|
|
||||||
conditions = []
|
conditions = []
|
||||||
|
|
||||||
for case_branch in self.typed_config.cases:
|
for case_branch in self.typed_config.cases:
|
||||||
branch_index += 1
|
|
||||||
branch_result = []
|
branch_result = []
|
||||||
for expression in case_branch.expressions:
|
for expression in case_branch.expressions:
|
||||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
@@ -103,13 +101,18 @@ class IfElseNode(BaseNode):
|
|||||||
left_value = self.get_variable(left_string, variable_pool)
|
left_value = self.get_variable(left_string, variable_pool)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
left_value = None
|
left_value = None
|
||||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
|
||||||
variable_pool,
|
if expression.sub_variable_condition is not None and isinstance(left_value, list):
|
||||||
expression.left,
|
evaluator = ArrayFileContainsOperator(left_value, expression.sub_variable_condition, variable_pool)
|
||||||
expression.right,
|
else:
|
||||||
expression.input_type
|
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||||
)
|
variable_pool,
|
||||||
|
expression.left,
|
||||||
|
expression.right,
|
||||||
|
expression.input_type
|
||||||
|
)
|
||||||
branch_result.append(self._evaluate(expression.operator, evaluator))
|
branch_result.append(self._evaluate(expression.operator, evaluator))
|
||||||
|
|
||||||
if case_branch.logical_operator == LogicOperator.AND:
|
if case_branch.logical_operator == LogicOperator.AND:
|
||||||
conditions.append(all(branch_result))
|
conditions.append(all(branch_result))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -116,6 +116,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
description="Top-p 采样参数"
|
description="Top-p 采样参数"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
json_output: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="是否以 JSON 格式输出"
|
||||||
|
)
|
||||||
|
|
||||||
frequency_penalty: float | None = Field(
|
frequency_penalty: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
ge=-2.0,
|
ge=-2.0,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from app.db import get_db_context
|
|||||||
from app.models import ModelType
|
from app.models import ModelType
|
||||||
from app.schemas.model_schema import ModelInfo
|
from app.schemas.model_schema import ModelInfo
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
from app.models.models_model import ModelProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -126,7 +127,11 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||||
extra_params = {"streaming": stream} if stream else {}
|
extra_params: dict[str, Any] = {"streaming": stream} if stream else {}
|
||||||
|
if self.typed_config.temperature is not None:
|
||||||
|
extra_params["temperature"] = self.typed_config.temperature
|
||||||
|
if self.typed_config.max_tokens is not None:
|
||||||
|
extra_params["max_tokens"] = self.typed_config.max_tokens
|
||||||
|
|
||||||
llm = RedBearLLM(
|
llm = RedBearLLM(
|
||||||
RedBearModelConfig(
|
RedBearModelConfig(
|
||||||
@@ -135,7 +140,9 @@ class LLMNode(BaseNode):
|
|||||||
api_key=model_info.api_key,
|
api_key=model_info.api_key,
|
||||||
base_url=model_info.api_base,
|
base_url=model_info.api_base,
|
||||||
extra_params=extra_params,
|
extra_params=extra_params,
|
||||||
is_omni=model_info.is_omni
|
is_omni=model_info.is_omni,
|
||||||
|
capability=model_info.capability,
|
||||||
|
json_output=self.typed_config.json_output,
|
||||||
),
|
),
|
||||||
type=model_info.model_type
|
type=model_info.model_type
|
||||||
)
|
)
|
||||||
@@ -218,6 +225,19 @@ class LLMNode(BaseNode):
|
|||||||
rendered = self._render_template(prompt_template, variable_pool)
|
rendered = self._render_template(prompt_template, variable_pool)
|
||||||
self.messages = [{"role": "user", "content": rendered}]
|
self.messages = [{"role": "user", "content": rendered}]
|
||||||
|
|
||||||
|
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format,在 system prompt 中注入
|
||||||
|
# VOLCANO 模型不支持 response_format,同样需要 system prompt 注入
|
||||||
|
need_json_prompt = self.typed_config.json_output and (
|
||||||
|
(model_info.provider.lower() == ModelProvider.DASHSCOPE and not model_info.is_omni)
|
||||||
|
or model_info.provider.lower() == ModelProvider.VOLCANO
|
||||||
|
)
|
||||||
|
if need_json_prompt:
|
||||||
|
system_msg = next((m for m in self.messages if m["role"] == "system"), None)
|
||||||
|
if system_msg:
|
||||||
|
system_msg["content"] += "\n请以JSON格式输出。"
|
||||||
|
else:
|
||||||
|
self.messages.insert(0, {"role": "system", "content": "请以JSON格式输出。"})
|
||||||
|
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
||||||
|
|||||||
@@ -395,11 +395,73 @@ class NoneObjectComparisonOperator:
|
|||||||
return lambda *args, **kwargs: False
|
return lambda *args, **kwargs: False
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayFileContainsOperator:
|
||||||
|
"""Handles contains/not_contains on array[file] with sub_variable_condition."""
|
||||||
|
|
||||||
|
def __init__(self, left_value: list[dict], sub_variable_condition: Any, pool: VariablePool | None = None):
|
||||||
|
self.left_value = left_value
|
||||||
|
self.sub_variable_condition = sub_variable_condition
|
||||||
|
self.pool = pool
|
||||||
|
|
||||||
|
def _resolve_value(self, cond: Any) -> Any:
|
||||||
|
if cond.input_type == ValueInputType.VARIABLE and self.pool is not None:
|
||||||
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
|
selector = re.sub(pattern, r"\1", str(cond.value)).strip()
|
||||||
|
return self.pool.get_value(selector, default=None, strict=False)
|
||||||
|
return cond.value
|
||||||
|
|
||||||
|
def _match_item(self, file_item: dict) -> bool:
|
||||||
|
results = []
|
||||||
|
for cond in self.sub_variable_condition.conditions:
|
||||||
|
field_val = file_item.get(cond.key)
|
||||||
|
expected = self._resolve_value(cond)
|
||||||
|
result = self._eval_sub(field_val, cond.operator.value, expected)
|
||||||
|
results.append(result)
|
||||||
|
if self.sub_variable_condition.logical_operator.value == "and":
|
||||||
|
return all(results)
|
||||||
|
return any(results)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _eval_sub(field_val: Any, op: str, expected: Any) -> bool:
|
||||||
|
if field_val is None:
|
||||||
|
return op == "empty"
|
||||||
|
match op:
|
||||||
|
case "eq": return str(field_val) == str(expected)
|
||||||
|
case "ne": return str(field_val) != str(expected)
|
||||||
|
case "contains": return isinstance(field_val, str) and str(expected) in field_val
|
||||||
|
case "not_contains": return isinstance(field_val, str) and str(expected) not in field_val
|
||||||
|
case "in": return field_val in (expected if isinstance(expected, list) else [expected])
|
||||||
|
case "not_in": return field_val not in (expected if isinstance(expected, list) else [expected])
|
||||||
|
case "gt": return isinstance(field_val, (int, float)) and field_val > float(expected)
|
||||||
|
case "ge": return isinstance(field_val, (int, float)) and field_val >= float(expected)
|
||||||
|
case "lt": return isinstance(field_val, (int, float)) and field_val < float(expected)
|
||||||
|
case "le": return isinstance(field_val, (int, float)) and field_val <= float(expected)
|
||||||
|
case "empty": return field_val in (None, "", 0)
|
||||||
|
case "not_empty": return field_val not in (None, "", 0)
|
||||||
|
case _: return False
|
||||||
|
|
||||||
|
def contains(self) -> bool:
|
||||||
|
return any(self._match_item(f) for f in self.left_value if isinstance(f, dict))
|
||||||
|
|
||||||
|
def not_contains(self) -> bool:
|
||||||
|
return not self.contains()
|
||||||
|
|
||||||
|
def empty(self) -> bool:
|
||||||
|
return not self.left_value
|
||||||
|
|
||||||
|
def not_empty(self) -> bool:
|
||||||
|
return bool(self.left_value)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return lambda *args, **kwargs: False
|
||||||
|
|
||||||
|
|
||||||
CompareOperatorInstance = Union[
|
CompareOperatorInstance = Union[
|
||||||
StringComparisonOperator,
|
StringComparisonOperator,
|
||||||
NumberComparisonOperator,
|
NumberComparisonOperator,
|
||||||
BooleanComparisonOperator,
|
BooleanComparisonOperator,
|
||||||
ArrayComparisonOperator,
|
ArrayComparisonOperator,
|
||||||
|
ArrayFileContainsOperator,
|
||||||
ObjectComparisonOperator
|
ObjectComparisonOperator
|
||||||
]
|
]
|
||||||
CompareOperatorType = Type[CompareOperatorInstance]
|
CompareOperatorType = Type[CompareOperatorInstance]
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from app.services.tool_service import ToolService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
||||||
|
PURE_VARIABLE_PATTERN = re.compile(r"^\{\{\s*([\w.]+)\s*}}$")
|
||||||
|
|
||||||
|
|
||||||
class ToolNode(BaseNode):
|
class ToolNode(BaseNode):
|
||||||
@@ -52,13 +53,21 @@ class ToolNode(BaseNode):
|
|||||||
# 渲染工具参数
|
# 渲染工具参数
|
||||||
rendered_parameters = {}
|
rendered_parameters = {}
|
||||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||||
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
if isinstance(param_template, str):
|
||||||
try:
|
pure_match = PURE_VARIABLE_PATTERN.match(param_template)
|
||||||
rendered_value = self._render_template(param_template, variable_pool)
|
if pure_match:
|
||||||
except Exception as e:
|
# 纯单变量引用直接取原始值,保留 int/bool/float 等类型
|
||||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
rendered_value = self.get_variable(pure_match.group(1), variable_pool, strict=False)
|
||||||
|
if rendered_value is None:
|
||||||
|
rendered_value = self._render_template(param_template, variable_pool)
|
||||||
|
elif TEMPLATE_PATTERN.search(param_template):
|
||||||
|
try:
|
||||||
|
rendered_value = self._render_template(param_template, variable_pool)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||||
|
else:
|
||||||
|
rendered_value = param_template
|
||||||
else:
|
else:
|
||||||
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
|
||||||
rendered_value = param_template
|
rendered_value = param_template
|
||||||
rendered_parameters[param_name] = rendered_value
|
rendered_parameters[param_name] = rendered_value
|
||||||
|
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ class FileVariable(BaseVariable):
|
|||||||
total_bytes = 0
|
total_bytes = 0
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||||
async with client.stream("GET", self.value.url) as resp:
|
async with client.stream("GET", self.value.url) as resp:
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
async for chunk in resp.aiter_bytes(8192):
|
async for chunk in resp.aiter_bytes(8192):
|
||||||
|
|||||||
@@ -29,11 +29,8 @@ class Tenants(Base):
|
|||||||
contact_email = Column(String(255), nullable=True) # 联系人邮箱
|
contact_email = Column(String(255), nullable=True) # 联系人邮箱
|
||||||
contact_phone = Column(String(50), nullable=True) # 联系人电话
|
contact_phone = Column(String(50), nullable=True) # 联系人电话
|
||||||
|
|
||||||
# 租户套餐信息
|
# 租户套餐信息(只读,从 tenant_subscriptions 动态获取)
|
||||||
plan = Column(String(50), nullable=True) # 套餐类型
|
status = Column(String(50), nullable=True, default='active', server_default='active') # 租户状态
|
||||||
plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间
|
|
||||||
api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制
|
|
||||||
status = Column(String(50), nullable=True, default='active') # 租户状态
|
|
||||||
|
|
||||||
# Relationship to users - one tenant has many users
|
# Relationship to users - one tenant has many users
|
||||||
users = relationship("User", back_populates="tenant")
|
users = relationship("User", back_populates="tenant")
|
||||||
|
|||||||
@@ -5,16 +5,9 @@ Implicit Emotions Storage Repository
|
|||||||
事务由调用方控制,仓储层只使用 flush/refresh
|
事务由调用方控制,仓储层只使用 flush/refresh
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from datetime import date, datetime, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Generator, Optional
|
from typing import Generator, Optional
|
||||||
|
|
||||||
|
|
||||||
class TimeFilterUnavailableError(Exception):
|
|
||||||
"""redis_client 不可用,无法执行时间轴筛选。
|
|
||||||
|
|
||||||
调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
from sqlalchemy import exists, not_, select
|
from sqlalchemy import exists, not_, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -25,6 +18,13 @@ from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TimeFilterUnavailableError(Exception):
|
||||||
|
"""redis_client 不可用,无法执行时间轴筛选。
|
||||||
|
|
||||||
|
调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ImplicitEmotionsStorageRepository:
|
class ImplicitEmotionsStorageRepository:
|
||||||
"""隐性记忆和情绪存储仓储类"""
|
"""隐性记忆和情绪存储仓储类"""
|
||||||
|
|
||||||
@@ -216,9 +216,7 @@ class ImplicitEmotionsStorageRepository:
|
|||||||
"""
|
"""
|
||||||
from sqlalchemy import String as SAString
|
from sqlalchemy import String as SAString
|
||||||
from sqlalchemy import cast
|
from sqlalchemy import cast
|
||||||
CST = timezone(timedelta(hours=8))
|
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
now_cst = datetime.now(CST)
|
|
||||||
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
|
|
||||||
tomorrow_start = today_start + timedelta(days=1)
|
tomorrow_start = today_start + timedelta(days=1)
|
||||||
offset = 0
|
offset = 0
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -328,7 +328,7 @@ class MemoryConfigRepository:
|
|||||||
if not db_config:
|
if not db_config:
|
||||||
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
||||||
return None
|
return None
|
||||||
|
#TODO:部分更新没有用patch请求,是在Repository层中用先查再部分更新的方式实现的,后续可以考虑改成patch请求更符合RESTful设计原则
|
||||||
update_data = update.model_dump(exclude_unset=True)
|
update_data = update.model_dump(exclude_unset=True)
|
||||||
update_data.pop("config_id", None)
|
update_data.pop("config_id", None)
|
||||||
|
|
||||||
|
|||||||
@@ -263,16 +263,15 @@ class ModelConfigRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_by_type(db: Session, model_type: ModelType, tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]:
|
def get_by_type(db: Session, model_types: List[ModelType], tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]:
|
||||||
"""根据类型获取模型配置"""
|
"""根据类型获取模型配置,支持多类型查询"""
|
||||||
db_logger.debug(f"根据类型查询模型配置: type={model_type}, tenant_id={tenant_id}, is_active={is_active}")
|
db_logger.debug(f"根据类型查询模型配置: types={[t.value for t in model_types]}, tenant_id={tenant_id}, is_active={is_active}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = db.query(ModelConfig).options(
|
query = db.query(ModelConfig).options(
|
||||||
joinedload(ModelConfig.api_keys)
|
joinedload(ModelConfig.api_keys)
|
||||||
).filter(ModelConfig.type == model_type)
|
).filter(ModelConfig.type.in_([t.value for t in model_types]))
|
||||||
|
|
||||||
# 添加租户过滤
|
|
||||||
if tenant_id:
|
if tenant_id:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
@@ -284,12 +283,14 @@ class ModelConfigRepository:
|
|||||||
if is_active:
|
if is_active:
|
||||||
query = query.filter(ModelConfig.is_active)
|
query = query.filter(ModelConfig.is_active)
|
||||||
|
|
||||||
models = query.order_by(ModelConfig.name).all()
|
query = query.filter(ModelConfig.is_composite == False)
|
||||||
|
|
||||||
|
models = query.order_by(ModelConfig.created_at.desc()).all()
|
||||||
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
|
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
|
||||||
return models
|
return models
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"根据类型查询模型配置失败: type={model_type} - {str(e)}")
|
db_logger.error(f"根据类型查询模型配置失败: types={model_types} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -93,6 +93,8 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
|||||||
END,
|
END,
|
||||||
e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
|
e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
|
||||||
e.aliases = CASE
|
e.aliases = CASE
|
||||||
|
// 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,知识抽取完全不写入
|
||||||
|
WHEN entity.name IN ['用户', '我', 'User', 'I'] THEN e.aliases
|
||||||
WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
|
WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
|
||||||
THEN CASE
|
THEN CASE
|
||||||
WHEN e.aliases IS NULL THEN entity.aliases
|
WHEN e.aliases IS NULL THEN entity.aliases
|
||||||
|
|||||||
@@ -77,11 +77,11 @@ class Neo4jConnector:
|
|||||||
"""
|
"""
|
||||||
await self.driver.close()
|
await self.driver.close()
|
||||||
|
|
||||||
async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
|
async def execute_query(self, cypher: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
|
||||||
"""执行Cypher查询
|
"""执行Cypher查询
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Cypher查询语句
|
cypher: Cypher查询语句
|
||||||
json_format: json格式化
|
json_format: json格式化
|
||||||
**kwargs: 查询参数,将作为参数传递给Cypher查询
|
**kwargs: 查询参数,将作为参数传递给Cypher查询
|
||||||
|
|
||||||
@@ -92,7 +92,7 @@ class Neo4jConnector:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
result = await self.driver.execute_query(
|
result = await self.driver.execute_query(
|
||||||
query,
|
cypher,
|
||||||
database="neo4j",
|
database="neo4j",
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -297,6 +297,10 @@ def get_user_by_id(db: Session, user_id: uuid.UUID) -> Optional[User]:
|
|||||||
"""根据ID获取用户"""
|
"""根据ID获取用户"""
|
||||||
return UserRepository(db).get_user_by_id(user_id)
|
return UserRepository(db).get_user_by_id(user_id)
|
||||||
|
|
||||||
|
def get_user_by_id_regardless_active(db: Session, user_id: uuid.UUID) -> Optional[User]:
|
||||||
|
"""根据ID获取用户(不过滤 is_active,用于启用/禁用场景)"""
|
||||||
|
return db.query(User).filter(User.id == user_id).first()
|
||||||
|
|
||||||
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
||||||
"""根据邮箱获取用户"""
|
"""根据邮箱获取用户"""
|
||||||
return UserRepository(db).get_user_by_email(email)
|
return UserRepository(db).get_user_by_email(email)
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ class FileInput(BaseModel):
|
|||||||
upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)")
|
upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)")
|
||||||
url: Optional[str] = Field(None, description="远程URL(remote_url时必填)")
|
url: Optional[str] = Field(None, description="远程URL(remote_url时必填)")
|
||||||
file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)")
|
file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)")
|
||||||
|
name: Optional[str] = Field(None, description="文件名")
|
||||||
|
size: Optional[int] = Field(None, description="文件大小(字节)")
|
||||||
|
|
||||||
_content = None
|
_content = None
|
||||||
|
|
||||||
@@ -243,6 +245,7 @@ class ModelParameters(BaseModel):
|
|||||||
stop: Optional[List[str]] = Field(default=None, description="停止序列")
|
stop: Optional[List[str]] = Field(default=None, description="停止序列")
|
||||||
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
|
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
|
||||||
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
|
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
|
||||||
|
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")
|
||||||
|
|
||||||
|
|
||||||
class VariableDefinition(BaseModel):
|
class VariableDefinition(BaseModel):
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ This module defines Pydantic schemas for the Memory API Service endpoints,
|
|||||||
including request validation and response structures for read and write operations.
|
including request validation and response structures for read and write operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
class MemoryWriteRequest(BaseModel):
|
class MemoryWriteRequest(BaseModel):
|
||||||
@@ -110,6 +111,30 @@ class MemoryReadRequest(BaseModel):
|
|||||||
class MemoryWriteResponse(BaseModel):
|
class MemoryWriteResponse(BaseModel):
|
||||||
"""Response schema for memory write operation.
|
"""Response schema for memory write operation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: Celery task ID for status polling
|
||||||
|
status: Initial task status (PENDING)
|
||||||
|
end_user_id: End user ID the write was submitted for
|
||||||
|
"""
|
||||||
|
task_id: str = Field(..., description="Celery task ID for polling")
|
||||||
|
status: str = Field(..., description="Task status: PENDING")
|
||||||
|
end_user_id: str = Field(..., description="End user ID")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatusResponse(BaseModel):
|
||||||
|
"""Response schema for task status check.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
status: Task status (PENDING, STARTED, SUCCESS, FAILURE, SKIPPED)
|
||||||
|
result: Task result data (available when status is SUCCESS or FAILURE)
|
||||||
|
"""
|
||||||
|
status: str = Field(..., description="Task status")
|
||||||
|
result: Optional[Dict[str, Any]] = Field(None, description="Task result when completed")
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryWriteSyncResponse(BaseModel):
|
||||||
|
"""Response schema for synchronous memory write.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
status: Operation status (success or failed)
|
status: Operation status (success or failed)
|
||||||
end_user_id: End user ID that was written to
|
end_user_id: End user ID that was written to
|
||||||
@@ -118,8 +143,8 @@ class MemoryWriteResponse(BaseModel):
|
|||||||
end_user_id: str = Field(..., description="End user ID")
|
end_user_id: str = Field(..., description="End user ID")
|
||||||
|
|
||||||
|
|
||||||
class MemoryReadResponse(BaseModel):
|
class MemoryReadSyncResponse(BaseModel):
|
||||||
"""Response schema for memory read operation.
|
"""Response schema for synchronous memory read.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
answer: Generated answer from memory retrieval
|
answer: Generated answer from memory retrieval
|
||||||
@@ -134,6 +159,19 @@ class MemoryReadResponse(BaseModel):
|
|||||||
end_user_id: str = Field(..., description="End user ID")
|
end_user_id: str = Field(..., description="End user ID")
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryReadResponse(BaseModel):
|
||||||
|
"""Response schema for memory read operation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: Celery task ID for status polling
|
||||||
|
status: Initial task status (PENDING)
|
||||||
|
end_user_id: End user ID the read was submitted for
|
||||||
|
"""
|
||||||
|
task_id: str = Field(..., description="Celery task ID for polling")
|
||||||
|
status: str = Field(..., description="Task status: PENDING")
|
||||||
|
end_user_id: str = Field(..., description="End user ID")
|
||||||
|
|
||||||
|
|
||||||
class CreateEndUserRequest(BaseModel):
|
class CreateEndUserRequest(BaseModel):
|
||||||
"""Request schema for creating an end user.
|
"""Request schema for creating an end user.
|
||||||
|
|
||||||
@@ -141,10 +179,12 @@ class CreateEndUserRequest(BaseModel):
|
|||||||
other_id: External user identifier (required)
|
other_id: External user identifier (required)
|
||||||
other_name: Display name for the end user
|
other_name: Display name for the end user
|
||||||
memory_config_id: Optional memory config ID. If not provided, uses workspace default.
|
memory_config_id: Optional memory config ID. If not provided, uses workspace default.
|
||||||
|
app_id: Optional app ID to bind the end user to.
|
||||||
"""
|
"""
|
||||||
other_id: str = Field(..., description="External user identifier (required)")
|
other_id: str = Field(..., description="External user identifier (required)")
|
||||||
other_name: Optional[str] = Field("", description="Display name")
|
other_name: Optional[str] = Field("", description="Display name")
|
||||||
memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.")
|
memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.")
|
||||||
|
app_id: Optional[str] = Field(None, description="App ID to bind the end user to")
|
||||||
|
|
||||||
@field_validator("other_id")
|
@field_validator("other_id")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -192,6 +232,7 @@ class MemoryConfigItem(BaseModel):
|
|||||||
created_at: Optional[str] = Field(None, description="Creation timestamp")
|
created_at: Optional[str] = Field(None, description="Creation timestamp")
|
||||||
updated_at: Optional[str] = Field(None, description="Last update timestamp")
|
updated_at: Optional[str] = Field(None, description="Last update timestamp")
|
||||||
|
|
||||||
|
# ========== V1 记忆配置管理接口 Schema ==========
|
||||||
|
|
||||||
class ListConfigsResponse(BaseModel):
|
class ListConfigsResponse(BaseModel):
|
||||||
"""Response schema for listing memory configs.
|
"""Response schema for listing memory configs.
|
||||||
@@ -202,3 +243,203 @@ class ListConfigsResponse(BaseModel):
|
|||||||
"""
|
"""
|
||||||
configs: List[MemoryConfigItem] = Field(default_factory=list, description="List of configs")
|
configs: List[MemoryConfigItem] = Field(default_factory=list, description="List of configs")
|
||||||
total: int = Field(0, description="Total number of configs")
|
total: int = Field(0, description="Total number of configs")
|
||||||
|
|
||||||
|
class ConfigCreateRequest(BaseModel):
|
||||||
|
"""Request schema for creating a new memory config."""
|
||||||
|
config_name: str = Field(..., description="Configuration name")
|
||||||
|
config_desc: Optional[str] = Field("", description="Configuration description")
|
||||||
|
scene_id: uuid.UUID = Field(..., description="Associated ontology scene ID (UUID, required)")
|
||||||
|
|
||||||
|
llm_id: Optional[str] = Field(None, description="LLM model configuration ID")
|
||||||
|
embedding_id: Optional[str] = Field(None, description="Embedding model configuration ID")
|
||||||
|
rerank_id: Optional[str] = Field(None, description="Reranking model configuration ID")
|
||||||
|
reflection_model_id: Optional[str] = Field(None, description="Reflection model ID")
|
||||||
|
emotion_model_id: Optional[str] = Field(None, description="Emotion analysis model ID")
|
||||||
|
|
||||||
|
@field_validator("config_name")
|
||||||
|
@classmethod
|
||||||
|
def validate_config_name(cls, v: str) -> str:
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("config_name is required and cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
class ConfigUpdateRequest(BaseModel):
|
||||||
|
"""Request schema for updating memory config basic info.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config_id: Configuration UUID to update (required)
|
||||||
|
config_name: New configuration name
|
||||||
|
config_desc: New configuration description
|
||||||
|
scene_id: New associated ontology scene ID
|
||||||
|
"""
|
||||||
|
config_id: str = Field(..., description="Configuration ID to update")
|
||||||
|
config_name: Optional[str] = Field(None, description="Configuration name")
|
||||||
|
config_desc: Optional[str] = Field(None, description="Configuration description")
|
||||||
|
scene_id: Optional[uuid.UUID] = Field(None, description="Associated ontology scene ID")
|
||||||
|
|
||||||
|
@field_validator("config_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_config_id(cls, v: str) -> str:
|
||||||
|
"""Validate that config_id is not empty."""
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("config_id is required and cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
class ConfigUpdateExtractedRequest(BaseModel):
|
||||||
|
"""Request schema for updating memory config extracted parameters.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config_id: Configuration UUID to update (required)
|
||||||
|
llm_id: Optional LLM model configuration ID
|
||||||
|
audio_id: Optional audio model configuration ID
|
||||||
|
vision_id: Optional vision model configuration ID
|
||||||
|
video_id: Optional video model configuration ID
|
||||||
|
embedding_id: Optional embedding model configuration ID
|
||||||
|
rerank_id: Optional reranking model configuration ID
|
||||||
|
enable_llm_dedup_blockwise: Optional toggle for LLM decision deduplication
|
||||||
|
enable_llm_disambiguation: Optional toggle for LLM decision disambiguation
|
||||||
|
deep_retrieval: Optional toggle for deep retrieval
|
||||||
|
|
||||||
|
t_type_strict: Optional float (0-1) for type strictness threshold
|
||||||
|
t_name_strict: Optional float (0-1) for name strictness threshold
|
||||||
|
t_overall: Optional float (0-1) for overall strictness threshold
|
||||||
|
state: Optional boolean for config active state
|
||||||
|
chunker_strategy: Optional string for memory chunking strategy
|
||||||
|
statement_granularity: Optional int (1-3) for statement extraction granularity
|
||||||
|
include_dialogue_context: Optional boolean for including dialogue context in retrieval
|
||||||
|
max_context: Optional int for maximum dialogue context length in characters
|
||||||
|
pruning_enabled: Optional boolean to enable intelligent semantic pruning
|
||||||
|
pruning_scene: Optional string for semantic pruning scene
|
||||||
|
pruning_threshold: Optional float (0-0.9) for semantic pruning threshold
|
||||||
|
enable_self_reflexion: Optional boolean to enable self-reflexion
|
||||||
|
iteration_period: Optional string for reflexion iteration period in hours (1, 3, 6, 12, 24)
|
||||||
|
reflexion_range: Optional string for reflexion range (partial or all)
|
||||||
|
baseline: Optional string for baseline (TIME/FACT/TIME-FACT)
|
||||||
|
|
||||||
|
"""
|
||||||
|
config_id: str = Field(..., description="Configuration ID (UUID)")
|
||||||
|
llm_id: Optional[str] = Field(None, description="LLM model configuration ID")
|
||||||
|
audio_id: Optional[str] = Field(None, description="Audio model ID")
|
||||||
|
vision_id: Optional[str] = Field(None, description="Vision model ID")
|
||||||
|
video_id: Optional[str] = Field(None, description="Video model ID")
|
||||||
|
embedding_id: Optional[str] = Field(None, description="Embedding model configuration ID")
|
||||||
|
rerank_id: Optional[str] = Field(None, description="Reranking model configuration ID")
|
||||||
|
enable_llm_dedup_blockwise: Optional[bool] = Field(None, description="Enable LLM decision deduplication")
|
||||||
|
enable_llm_disambiguation: Optional[bool] = Field(None, description="Enable LLM decision disambiguation")
|
||||||
|
deep_retrieval: Optional[bool] = Field(None, description="Deep retrieval toggle")
|
||||||
|
|
||||||
|
t_type_strict: Optional[float] = Field(None, ge=0.0, le=1.0, description="type strictness threshold")
|
||||||
|
t_name_strict: Optional[float] = Field(None, ge=0.0, le=1.0, description="name strictness threshold")
|
||||||
|
t_overall: Optional[float] = Field(None, ge=0.0, le=1.0, description="overall strictness threshold")
|
||||||
|
state: Optional[bool] = Field(None, description="config active state")
|
||||||
|
# 句子提取
|
||||||
|
chunker_strategy: Optional[str] = Field(None, description="memory chunking strategy")
|
||||||
|
statement_granularity: Optional[int] = Field(None, ge=1, le=3, description="statement extraction granularity")
|
||||||
|
include_dialogue_context: Optional[bool] = Field(None, description="whether to include dialogue context in retrieval")
|
||||||
|
max_context: Optional[int] = Field(None, gt=100, description="maximum dialogue context length in characters")
|
||||||
|
# 剪枝配置:与 runtime.json 中 pruning 段对应
|
||||||
|
pruning_enabled: Optional[bool] = Field(None, description="whether to enable intelligent semantic pruning")
|
||||||
|
pruning_scene: Optional[str] = Field(None, description="semantic pruning scene")
|
||||||
|
pruning_threshold: Optional[float] = Field(None, ge=0.0, le=0.9, description="semantic pruning threshold (0-0.9)")
|
||||||
|
enable_self_reflexion: Optional[bool] = Field(None, description="whether to enable self-reflexion")
|
||||||
|
iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field(None, description="reflexion iteration period in hours (1, 3, 6, 12, 24)")
|
||||||
|
reflexion_range: Optional[Literal["partial", "all"]] = Field(None, description="reflexion range: partial/all")
|
||||||
|
baseline: Optional[Literal["TIME", "FACT", "TIME-FACT"]] = Field(None, description="baseline: TIME/FACT/TIME-FACT")
|
||||||
|
|
||||||
|
@field_validator("config_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_config_id(cls, v: str) -> str:
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("config_id is required and cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
class ConfigUpdateForgettingRequest(BaseModel):
|
||||||
|
"""Request schema for updating memory config forgetting parameters.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config_id: Configuration UUID to update (required)
|
||||||
|
decay_constant: Decay constant for forgetting
|
||||||
|
lambda_time: Time decay parameter
|
||||||
|
lambda_mem: Memory decay parameter
|
||||||
|
offset: Offset for forgetting curve
|
||||||
|
max_history_length: Maximum history length to consider for forgetting
|
||||||
|
forgetting_threshold: Threshold for forgetting
|
||||||
|
min_days_since_access: Minimum days since last access to trigger forgetting
|
||||||
|
enable_llm_summary: Whether to use LLM-generated summaries for forgetting
|
||||||
|
max_merge_batch_size: Maximum batch size for merging nodes during forgetting
|
||||||
|
forgetting_interval_hours: Interval in hours for periodic forgetting
|
||||||
|
|
||||||
|
"""
|
||||||
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
config_id: str = Field(..., description="Configuration ID (UUID)")
|
||||||
|
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="Decay constant for forgetting")
|
||||||
|
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="Time decay parameter")
|
||||||
|
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="Memory decay parameter")
|
||||||
|
offset: Optional[float] = Field(None, ge=0.0, le=1.0, description="Offset for forgetting curve")
|
||||||
|
max_history_length: Optional[int] = Field(None, ge=10, le=1000, description="Maximum history length to consider for forgetting")
|
||||||
|
forgetting_threshold: Optional[float] = Field(None, ge=0.0, le=1.0, description="Forgetting threshold")
|
||||||
|
min_days_since_access: Optional[int] = Field(None, ge=1, le=365, description="Minimum days since last access to trigger forgetting")
|
||||||
|
enable_llm_summary: Optional[bool] = Field(None, description="Whether to use LLM-generated summaries for forgetting")
|
||||||
|
max_merge_batch_size: Optional[int] = Field(None, ge=1, le=1000, description="Maximum batch size for merging nodes during forgetting")
|
||||||
|
forgetting_interval_hours: Optional[int] = Field(None, ge=1, le=168, description="Interval in hours for periodic forgetting")
|
||||||
|
|
||||||
|
@field_validator("config_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_config_id(cls, v: str) -> str:
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("config_id is required and cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
class EmotionConfigUpdateRequest(BaseModel):
|
||||||
|
"""Request schema for updating memory config emotion parameters.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config_id: Configuration UUID to update (required)
|
||||||
|
emotion_enabled: Whether to enable emotion extraction
|
||||||
|
emotion_model_id: Emotion analysis model ID
|
||||||
|
emotion_extract_keywords: Whether to extract emotion keywords
|
||||||
|
emotion_min_intensity: Minimum emotion intensity threshold (0.0-1.0)
|
||||||
|
emotion_enable_subject: Whether to enable subject classification for emotions
|
||||||
|
"""
|
||||||
|
config_id: str = Field(..., description="Configuration ID (UUID)")
|
||||||
|
emotion_enabled: bool = Field(..., description="Whether to enable emotion extraction")
|
||||||
|
emotion_model_id: Optional[str] = Field(None, description="Emotion analysis model ID")
|
||||||
|
emotion_extract_keywords: bool = Field(..., description="Whether to extract emotion keywords")
|
||||||
|
emotion_min_intensity: float = Field(..., ge=0.0, le=1.0, description="Minimum emotion intensity threshold")
|
||||||
|
emotion_enable_subject: bool = Field(..., description="Whether to enable subject classification for emotions")
|
||||||
|
|
||||||
|
@field_validator("config_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_config_id(cls, v: str) -> str:
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("config_id is required and cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
class ReflectionConfigUpdateRequest(BaseModel):
|
||||||
|
"""Request schema for updating memory config reflection parameters.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config_id: Configuration UUID to update (required)
|
||||||
|
reflection_enabled: Whether to enable self-reflection
|
||||||
|
reflection_period_in_hours: Reflection iteration period in hours
|
||||||
|
reflexion_range: Reflection range (partial or all)
|
||||||
|
baseline: Baseline for reflection (TIME/FACT/TIME-FACT)
|
||||||
|
reflection_model_id: Reflection model ID
|
||||||
|
memory_verify: Whether to enable memory verification
|
||||||
|
quality_assessment: Whether to enable quality assessment
|
||||||
|
"""
|
||||||
|
config_id: str = Field(..., description="Configuration ID (UUID)")
|
||||||
|
reflection_enabled: bool = Field(..., description="Whether to enable self-reflection")
|
||||||
|
reflection_period_in_hours: str = Field(..., description="Reflection iteration period in hours")
|
||||||
|
reflexion_range: Literal["partial", "all"] = Field(..., description="Reflection range: partial/all")
|
||||||
|
baseline: Literal["TIME", "FACT", "TIME-FACT"] = Field(..., description="Baseline: TIME/FACT/TIME-FACT")
|
||||||
|
reflection_model_id: str = Field(..., description="Reflection model ID")
|
||||||
|
memory_verify: bool = Field(..., description="Whether to enable memory verification")
|
||||||
|
quality_assessment: bool = Field(..., description="Whether to enable quality assessment")
|
||||||
|
|
||||||
|
@field_validator("config_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_config_id(cls, v: str) -> str:
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("config_id is required and cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|||||||
@@ -291,7 +291,7 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
|
|||||||
pruning_threshold: Optional[float] = Field(
|
pruning_threshold: Optional[float] = Field(
|
||||||
None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)"
|
None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)"
|
||||||
)
|
)
|
||||||
|
#TODO:萃取引擎的更新的更新会带有反思引擎的参数,需判断业务是否需要,不需要可以重构
|
||||||
# 反思配置
|
# 反思配置
|
||||||
enable_self_reflexion: Optional[bool] = Field(None, description="是否启用自我反思")
|
enable_self_reflexion: Optional[bool] = Field(None, description="是否启用自我反思")
|
||||||
iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field(
|
iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field(
|
||||||
|
|||||||
@@ -248,6 +248,35 @@ class RateLimiterService:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.redis = aio_redis
|
self.redis = aio_redis
|
||||||
|
|
||||||
|
async def check_tenant_rate_limit(self, tenant_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
|
||||||
|
"""
|
||||||
|
按 tenant_id 做 1 秒滑动窗口限速,限制值来自套餐配额 api_ops_rate_limit
|
||||||
|
"""
|
||||||
|
now = time.time()
|
||||||
|
window_start = now - 1 # 1 秒窗口
|
||||||
|
key = f"rate_limit:tenant_qps:{tenant_id}"
|
||||||
|
|
||||||
|
async with self.redis.pipeline() as pipe:
|
||||||
|
# 清理 1 秒前的旧记录
|
||||||
|
pipe.zremrangebyscore(key, 0, window_start)
|
||||||
|
# 加入当前请求(score=时间戳,member=时间戳+随机数保证唯一)
|
||||||
|
pipe.zadd(key, {f"{now}:{uuid.uuid4().hex}": now})
|
||||||
|
# 统计窗口内请求数
|
||||||
|
pipe.zcard(key)
|
||||||
|
# 设置 key 过期(2 秒后自动清理)
|
||||||
|
pipe.expire(key, 2)
|
||||||
|
results = await pipe.execute()
|
||||||
|
|
||||||
|
current = results[2]
|
||||||
|
remaining = max(0, limit - current)
|
||||||
|
reset_time = int(now) + 1
|
||||||
|
|
||||||
|
return current <= limit, {
|
||||||
|
"limit": limit,
|
||||||
|
"remaining": remaining,
|
||||||
|
"reset": reset_time,
|
||||||
|
}
|
||||||
|
|
||||||
async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
|
async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
|
||||||
"""
|
"""
|
||||||
检查QPS限制
|
检查QPS限制
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from app.services.model_service import ModelApiKeyService
|
|||||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
|
from app.models.file_metadata_model import FileMetadata
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -119,6 +120,7 @@ class AppChatService:
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||||
|
json_output=model_parameters.get("json_output", False),
|
||||||
capability=api_key_obj.capability or [],
|
capability=api_key_obj.capability or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -218,11 +220,29 @@ class AppChatService:
|
|||||||
"reasoning_content": result.get("reasoning_content")
|
"reasoning_content": result.get("reasoning_content")
|
||||||
}
|
}
|
||||||
if files:
|
if files:
|
||||||
|
local_ids = [f.upload_file_id for f in files
|
||||||
|
if f.transfer_method.value == "local_file" and f.upload_file_id
|
||||||
|
and (not f.name or not f.size)]
|
||||||
|
meta_map = {}
|
||||||
|
if local_ids:
|
||||||
|
rows = self.db.query(FileMetadata).filter(
|
||||||
|
FileMetadata.id.in_(local_ids),
|
||||||
|
FileMetadata.status == "completed"
|
||||||
|
).all()
|
||||||
|
meta_map = {str(r.id): r for r in rows}
|
||||||
for f in files:
|
for f in files:
|
||||||
# url = await MultimodalService(self.db).get_file_url(f)
|
name, size = f.name, f.size
|
||||||
|
if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size):
|
||||||
|
meta = meta_map.get(str(f.upload_file_id))
|
||||||
|
if meta:
|
||||||
|
name = name or meta.file_name
|
||||||
|
size = size or meta.file_size
|
||||||
human_meta["files"].append({
|
human_meta["files"].append({
|
||||||
"type": f.type,
|
"type": f.type,
|
||||||
"url": f.url
|
"url": f.url,
|
||||||
|
"name": name,
|
||||||
|
"size": size,
|
||||||
|
"file_type": f.file_type,
|
||||||
})
|
})
|
||||||
|
|
||||||
if processed_files:
|
if processed_files:
|
||||||
@@ -373,6 +393,7 @@ class AppChatService:
|
|||||||
streaming=True,
|
streaming=True,
|
||||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||||
|
json_output=model_parameters.get("json_output", False),
|
||||||
capability=api_key_obj.capability or [],
|
capability=api_key_obj.capability or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -509,10 +530,29 @@ class AppChatService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if files:
|
if files:
|
||||||
|
local_ids = [f.upload_file_id for f in files
|
||||||
|
if f.transfer_method.value == "local_file" and f.upload_file_id
|
||||||
|
and (not f.name or not f.size)]
|
||||||
|
meta_map = {}
|
||||||
|
if local_ids:
|
||||||
|
rows = self.db.query(FileMetadata).filter(
|
||||||
|
FileMetadata.id.in_(local_ids),
|
||||||
|
FileMetadata.status == "completed"
|
||||||
|
).all()
|
||||||
|
meta_map = {str(r.id): r for r in rows}
|
||||||
for f in files:
|
for f in files:
|
||||||
|
name, size = f.name, f.size
|
||||||
|
if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size):
|
||||||
|
meta = meta_map.get(str(f.upload_file_id))
|
||||||
|
if meta:
|
||||||
|
name = name or meta.file_name
|
||||||
|
size = size or meta.file_size
|
||||||
human_meta["files"].append({
|
human_meta["files"].append({
|
||||||
"type": f.type,
|
"type": f.type,
|
||||||
"url": f.url
|
"url": f.url,
|
||||||
|
"name": name,
|
||||||
|
"size": size,
|
||||||
|
"file_type": f.file_type,
|
||||||
})
|
})
|
||||||
if processed_files:
|
if processed_files:
|
||||||
human_meta["history_files"] = {
|
human_meta["history_files"] = {
|
||||||
|
|||||||
@@ -14,12 +14,14 @@ from app.models.app_model import App, AppType
|
|||||||
from app.models.appshare_model import AppShare
|
from app.models.appshare_model import AppShare
|
||||||
from app.models.app_release_model import AppRelease
|
from app.models.app_release_model import AppRelease
|
||||||
from app.models.knowledge_model import Knowledge
|
from app.models.knowledge_model import Knowledge
|
||||||
|
from app.models.knowledgeshare_model import KnowledgeShare
|
||||||
from app.models.models_model import ModelConfig
|
from app.models.models_model import ModelConfig
|
||||||
from app.models.tool_model import ToolConfig as ToolConfigModel
|
from app.models.tool_model import ToolConfig as ToolConfigModel
|
||||||
from app.models.skill_model import Skill
|
from app.models.skill_model import Skill
|
||||||
from app.models.workflow_model import WorkflowConfig
|
from app.models.workflow_model import WorkflowConfig
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||||
|
|
||||||
|
|
||||||
@@ -73,15 +75,14 @@ class AppDslService:
|
|||||||
AppType.MULTI_AGENT: "multi_agent_config",
|
AppType.MULTI_AGENT: "multi_agent_config",
|
||||||
AppType.WORKFLOW: "workflow"
|
AppType.WORKFLOW: "workflow"
|
||||||
}.get(app.type, "config")
|
}.get(app.type, "config")
|
||||||
config_data = self._enrich_release_config(app.type, release.config or {})
|
config_data = self._enrich_release_config(app.type, release.config or {}, release.default_model_config_id)
|
||||||
dsl = {**meta, "app": app_meta, config_key: config_data}
|
dsl = {**meta, "app": app_meta, config_key: config_data}
|
||||||
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{release.name}_v{release.version_name}.yaml"
|
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{release.name}_v{release.version_name}.yaml"
|
||||||
|
|
||||||
def _enrich_release_config(self, app_type: str, cfg: dict) -> dict:
|
def _enrich_release_config(self, app_type: str, cfg: dict, default_model_config_id=None) -> dict:
|
||||||
if app_type == AppType.AGENT:
|
if app_type == AppType.AGENT:
|
||||||
enriched = {**cfg}
|
enriched = {**cfg}
|
||||||
if "default_model_config_id" in cfg:
|
enriched["default_model_config_ref"] = self._model_ref(default_model_config_id)
|
||||||
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
|
|
||||||
if "knowledge_retrieval" in cfg:
|
if "knowledge_retrieval" in cfg:
|
||||||
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
|
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
|
||||||
if "tools" in cfg:
|
if "tools" in cfg:
|
||||||
@@ -91,8 +92,7 @@ class AppDslService:
|
|||||||
return enriched
|
return enriched
|
||||||
if app_type == AppType.MULTI_AGENT:
|
if app_type == AppType.MULTI_AGENT:
|
||||||
enriched = {**cfg}
|
enriched = {**cfg}
|
||||||
if "default_model_config_id" in cfg:
|
enriched["default_model_config_ref"] = self._model_ref(default_model_config_id)
|
||||||
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
|
|
||||||
if "master_agent_id" in cfg:
|
if "master_agent_id" in cfg:
|
||||||
enriched["master_agent_ref"] = self._release_ref(cfg["master_agent_id"])
|
enriched["master_agent_ref"] = self._release_ref(cfg["master_agent_id"])
|
||||||
if "sub_agents" in cfg:
|
if "sub_agents" in cfg:
|
||||||
@@ -229,8 +229,11 @@ class AppDslService:
|
|||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
tenant_id: uuid.UUID,
|
tenant_id: uuid.UUID,
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
|
app_id: Optional[uuid.UUID] = None,
|
||||||
) -> tuple[App, list[str]]:
|
) -> tuple[App, list[str]]:
|
||||||
"""解析 DSL,创建应用及配置,返回 (new_app, warnings)"""
|
"""解析 DSL,创建或覆盖应用配置,返回 (app, warnings)。
|
||||||
|
app_id 不为空时:校验类型一致后覆盖配置;为空时创建新应用。
|
||||||
|
"""
|
||||||
app_meta = dsl.get("app", {})
|
app_meta = dsl.get("app", {})
|
||||||
app_type = app_meta.get("type")
|
app_type = app_meta.get("type")
|
||||||
if app_type not in (AppType.AGENT, AppType.MULTI_AGENT, AppType.WORKFLOW):
|
if app_type not in (AppType.AGENT, AppType.MULTI_AGENT, AppType.WORKFLOW):
|
||||||
@@ -239,6 +242,9 @@ class AppDslService:
|
|||||||
warnings: list[str] = []
|
warnings: list[str] = []
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
|
|
||||||
|
if app_id is not None:
|
||||||
|
return self._overwrite_dsl(dsl, app_id, app_type, workspace_id, tenant_id, warnings, now)
|
||||||
|
|
||||||
new_app = App(
|
new_app = App(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
@@ -258,11 +264,57 @@ class AppDslService:
|
|||||||
self.db.add(new_app)
|
self.db.add(new_app)
|
||||||
self.db.flush()
|
self.db.flush()
|
||||||
|
|
||||||
|
self._write_config(new_app.id, app_type, dsl, workspace_id, tenant_id, warnings, now, create=True)
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(new_app)
|
||||||
|
return new_app, warnings
|
||||||
|
|
||||||
|
def _overwrite_dsl(
|
||||||
|
self,
|
||||||
|
dsl: dict,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
app_type: str,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
warnings: list,
|
||||||
|
now: datetime.datetime,
|
||||||
|
) -> tuple[App, list[str]]:
|
||||||
|
"""覆盖已有应用的配置,类型不一致时抛出异常"""
|
||||||
|
app = self.db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == workspace_id,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
if not app:
|
||||||
|
raise ResourceNotFoundException("应用", str(app_id))
|
||||||
|
if app.type != app_type:
|
||||||
|
raise BusinessException(
|
||||||
|
f"YAML 类型 '{app_type}' 与应用类型 '{app.type}' 不一致,无法导入",
|
||||||
|
BizCode.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
self._write_config(app_id, app_type, dsl, workspace_id, tenant_id, warnings, now, create=False)
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(app)
|
||||||
|
return app, warnings
|
||||||
|
|
||||||
|
def _write_config(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
app_type: str,
|
||||||
|
dsl: dict,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
warnings: list,
|
||||||
|
now: datetime.datetime,
|
||||||
|
create: bool,
|
||||||
|
) -> None:
|
||||||
|
"""写入(新建或覆盖)应用配置"""
|
||||||
if app_type == AppType.AGENT:
|
if app_type == AppType.AGENT:
|
||||||
cfg = dsl.get("agent_config") or {}
|
cfg = dsl.get("agent_config") or {}
|
||||||
self.db.add(AgentConfig(
|
fields = dict(
|
||||||
id=uuid.uuid4(),
|
|
||||||
app_id=new_app.id,
|
|
||||||
system_prompt=cfg.get("system_prompt"),
|
system_prompt=cfg.get("system_prompt"),
|
||||||
model_parameters=cfg.get("model_parameters"),
|
model_parameters=cfg.get("model_parameters"),
|
||||||
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
|
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
|
||||||
@@ -272,16 +324,21 @@ class AppDslService:
|
|||||||
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
|
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
|
||||||
skills=self._resolve_skills(cfg.get("skills", {}), tenant_id, warnings),
|
skills=self._resolve_skills(cfg.get("skills", {}), tenant_id, warnings),
|
||||||
features=cfg.get("features", {}),
|
features=cfg.get("features", {}),
|
||||||
is_active=True,
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
))
|
)
|
||||||
|
if create:
|
||||||
|
self.db.add(AgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields))
|
||||||
|
else:
|
||||||
|
existing = self.db.query(AgentConfig).filter(AgentConfig.app_id == app_id).first()
|
||||||
|
if existing:
|
||||||
|
for k, v in fields.items():
|
||||||
|
setattr(existing, k, v)
|
||||||
|
else:
|
||||||
|
self.db.add(AgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields))
|
||||||
|
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
cfg = dsl.get("multi_agent_config") or {}
|
cfg = dsl.get("multi_agent_config") or {}
|
||||||
self.db.add(MultiAgentConfig(
|
fields = dict(
|
||||||
id=uuid.uuid4(),
|
|
||||||
app_id=new_app.id,
|
|
||||||
orchestration_mode=cfg.get("orchestration_mode", "collaboration"),
|
orchestration_mode=cfg.get("orchestration_mode", "collaboration"),
|
||||||
master_agent_name=cfg.get("master_agent_name"),
|
master_agent_name=cfg.get("master_agent_name"),
|
||||||
model_parameters=cfg.get("model_parameters"),
|
model_parameters=cfg.get("model_parameters"),
|
||||||
@@ -291,13 +348,24 @@ class AppDslService:
|
|||||||
routing_rules=self._resolve_routing_rules(cfg.get("routing_rules"), warnings),
|
routing_rules=self._resolve_routing_rules(cfg.get("routing_rules"), warnings),
|
||||||
execution_config=cfg.get("execution_config", {}),
|
execution_config=cfg.get("execution_config", {}),
|
||||||
aggregation_strategy=cfg.get("aggregation_strategy", "merge"),
|
aggregation_strategy=cfg.get("aggregation_strategy", "merge"),
|
||||||
is_active=True,
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
))
|
)
|
||||||
|
if create:
|
||||||
|
self.db.add(MultiAgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields))
|
||||||
|
else:
|
||||||
|
existing = self.db.query(MultiAgentConfig).filter(MultiAgentConfig.app_id == app_id).first()
|
||||||
|
if existing:
|
||||||
|
for k, v in fields.items():
|
||||||
|
setattr(existing, k, v)
|
||||||
|
else:
|
||||||
|
self.db.add(MultiAgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields))
|
||||||
|
|
||||||
elif app_type == AppType.WORKFLOW:
|
elif app_type == AppType.WORKFLOW:
|
||||||
adapter = MemoryBearAdapter(dsl)
|
raw_wf = dsl.get("workflow") or {}
|
||||||
|
raw_nodes = raw_wf.get("nodes") or []
|
||||||
|
resolved_nodes = self._resolve_workflow_nodes(raw_nodes, tenant_id, workspace_id, warnings)
|
||||||
|
resolved_dsl = {**dsl, "workflow": {**raw_wf, "nodes": resolved_nodes}}
|
||||||
|
adapter = MemoryBearAdapter(resolved_dsl)
|
||||||
if not adapter.validate_config():
|
if not adapter.validate_config():
|
||||||
raise BusinessException("工作流配置格式无效", BizCode.BAD_REQUEST)
|
raise BusinessException("工作流配置格式无效", BizCode.BAD_REQUEST)
|
||||||
result = adapter.parse_workflow()
|
result = adapter.parse_workflow()
|
||||||
@@ -305,21 +373,39 @@ class AppDslService:
|
|||||||
warnings.append(f"[节点错误] {e.node_name or e.node_id}: {e.detail}")
|
warnings.append(f"[节点错误] {e.node_name or e.node_id}: {e.detail}")
|
||||||
for w in result.warnings:
|
for w in result.warnings:
|
||||||
warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}")
|
warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}")
|
||||||
wf = dsl.get("workflow") or {}
|
wf_service = WorkflowService(self.db)
|
||||||
WorkflowService(self.db).create_workflow_config(
|
if create:
|
||||||
app_id=new_app.id,
|
wf_service.create_workflow_config(
|
||||||
nodes=[n.model_dump() for n in result.nodes],
|
app_id=app_id,
|
||||||
edges=[e.model_dump() for e in result.edges],
|
nodes=[n.model_dump() for n in result.nodes],
|
||||||
variables=[v.model_dump() for v in result.variables],
|
edges=[e.model_dump() for e in result.edges],
|
||||||
execution_config=wf.get("execution_config", {}),
|
variables=[v.model_dump() for v in result.variables],
|
||||||
features=wf.get("features", {}),
|
execution_config=raw_wf.get("execution_config", {}),
|
||||||
triggers=wf.get("triggers", []),
|
features=raw_wf.get("features", {}),
|
||||||
validate=False,
|
triggers=raw_wf.get("triggers", []),
|
||||||
)
|
validate=False,
|
||||||
|
)
|
||||||
self.db.commit()
|
else:
|
||||||
self.db.refresh(new_app)
|
existing = self.db.query(WorkflowConfig).filter(WorkflowConfig.app_id == app_id).first()
|
||||||
return new_app, warnings
|
if existing:
|
||||||
|
existing.nodes = [n.model_dump() for n in result.nodes]
|
||||||
|
existing.edges = [e.model_dump() for e in result.edges]
|
||||||
|
existing.variables = [v.model_dump() for v in result.variables]
|
||||||
|
existing.execution_config = raw_wf.get("execution_config", {})
|
||||||
|
existing.features = raw_wf.get("features", {})
|
||||||
|
existing.triggers = raw_wf.get("triggers", [])
|
||||||
|
existing.updated_at = now
|
||||||
|
else:
|
||||||
|
wf_service.create_workflow_config(
|
||||||
|
app_id=app_id,
|
||||||
|
nodes=[n.model_dump() for n in result.nodes],
|
||||||
|
edges=[e.model_dump() for e in result.edges],
|
||||||
|
variables=[v.model_dump() for v in result.variables],
|
||||||
|
execution_config=raw_wf.get("execution_config", {}),
|
||||||
|
features=raw_wf.get("features", {}),
|
||||||
|
triggers=raw_wf.get("triggers", []),
|
||||||
|
validate=False,
|
||||||
|
)
|
||||||
|
|
||||||
def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str:
|
def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str:
|
||||||
"""生成唯一应用名称,同时检查本空间自有应用和共享到本空间的应用"""
|
"""生成唯一应用名称,同时检查本空间自有应用和共享到本空间的应用"""
|
||||||
@@ -365,27 +451,63 @@ class AppDslService:
|
|||||||
def _resolve_kb(self, ref: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[str]:
|
def _resolve_kb(self, ref: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[str]:
|
||||||
if not ref:
|
if not ref:
|
||||||
return None
|
return None
|
||||||
kb = self.db.query(Knowledge).filter(
|
kb_id = ref.get("id")
|
||||||
Knowledge.workspace_id == workspace_id,
|
if kb_id:
|
||||||
Knowledge.name == ref.get("name")
|
try:
|
||||||
).first()
|
kb_uuid = uuid.UUID(str(kb_id))
|
||||||
if not kb:
|
kb_share = self.db.query(KnowledgeShare).filter(
|
||||||
warnings.append(f"知识库 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
KnowledgeShare.target_workspace_id == workspace_id,
|
||||||
return str(kb.id) if kb else None
|
KnowledgeShare.source_kb_id == kb_uuid
|
||||||
|
).first()
|
||||||
|
if kb_share:
|
||||||
|
kb = self.db.query(Knowledge).filter(
|
||||||
|
Knowledge.id == kb_share.target_kb_id
|
||||||
|
).first()
|
||||||
|
if kb and kb.status == 1:
|
||||||
|
return str(kb_share.target_kb_id)
|
||||||
|
kb = self.db.query(Knowledge).filter(
|
||||||
|
Knowledge.workspace_id == workspace_id,
|
||||||
|
Knowledge.id == kb_uuid,
|
||||||
|
Knowledge.status == 1
|
||||||
|
).first()
|
||||||
|
if kb:
|
||||||
|
return str(kb.id)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
pass
|
||||||
|
warnings.append(f"知识库 '{kb_id}' 未匹配,已置空,请导入后手动配置")
|
||||||
|
return None
|
||||||
|
|
||||||
def _resolve_tool(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
|
def _resolve_tool(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
|
||||||
if not ref:
|
if not ref:
|
||||||
return None
|
return None
|
||||||
q = self.db.query(ToolConfigModel).filter(
|
tool_id = ref.get("id")
|
||||||
ToolConfigModel.tenant_id == tenant_id,
|
tool_name = ref.get("name")
|
||||||
ToolConfigModel.name == ref.get("name")
|
if tool_id:
|
||||||
)
|
try:
|
||||||
if ref.get("tool_type"):
|
tool_uuid = uuid.UUID(str(tool_id))
|
||||||
q = q.filter(ToolConfigModel.tool_type == ref["tool_type"])
|
t = self.db.query(ToolConfigModel).filter(
|
||||||
t = q.first()
|
ToolConfigModel.id == tool_uuid,
|
||||||
if not t:
|
ToolConfigModel.tenant_id == tenant_id,
|
||||||
warnings.append(f"工具 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
ToolConfigModel.is_active.is_(True)
|
||||||
return str(t.id) if t else None
|
).first()
|
||||||
|
if t:
|
||||||
|
return str(t.id)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
pass
|
||||||
|
if tool_name:
|
||||||
|
q = self.db.query(ToolConfigModel).filter(
|
||||||
|
ToolConfigModel.tenant_id == tenant_id,
|
||||||
|
ToolConfigModel.name == tool_name
|
||||||
|
)
|
||||||
|
if ref.get("tool_type"):
|
||||||
|
q = q.filter(ToolConfigModel.tool_type == ref["tool_type"])
|
||||||
|
t = q.first()
|
||||||
|
if t:
|
||||||
|
return str(t.id)
|
||||||
|
warnings.append(f"工具 '{tool_name}' 未匹配,已置空,请导入后手动配置")
|
||||||
|
else:
|
||||||
|
warnings.append(f"工具 '{tool_id}' 未匹配,已置空,请导入后手动配置")
|
||||||
|
return None
|
||||||
|
|
||||||
def _resolve_release(self, ref: Optional[dict], warnings: list) -> Optional[uuid.UUID]:
|
def _resolve_release(self, ref: Optional[dict], warnings: list) -> Optional[uuid.UUID]:
|
||||||
if not ref:
|
if not ref:
|
||||||
@@ -427,6 +549,61 @@ class AppDslService:
|
|||||||
result.append(entry)
|
result.append(entry)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _resolve_workflow_nodes(self, nodes: list, tenant_id: uuid.UUID, workspace_id: uuid.UUID, warnings: list) -> list:
|
||||||
|
"""解析工作流节点中的工具ID和知识库ID,匹配不到则清空配置"""
|
||||||
|
resolved_nodes = []
|
||||||
|
for node in nodes:
|
||||||
|
node_type = node.get("type")
|
||||||
|
config = dict(node.get("config") or {})
|
||||||
|
node_label = node.get("name") or node.get("id")
|
||||||
|
if node_type == NodeType.TOOL.value:
|
||||||
|
tool_id = config.get("tool_id")
|
||||||
|
if not tool_id:
|
||||||
|
# tool_id 本身就是空,直接置空不重复 warning
|
||||||
|
config["tool_id"] = None
|
||||||
|
config["tool_parameters"] = {}
|
||||||
|
else:
|
||||||
|
tool_ref = {}
|
||||||
|
if isinstance(tool_id, str) and len(tool_id) >= 36:
|
||||||
|
try:
|
||||||
|
uuid.UUID(tool_id)
|
||||||
|
tool_ref["id"] = tool_id
|
||||||
|
except ValueError:
|
||||||
|
tool_ref["name"] = tool_id
|
||||||
|
else:
|
||||||
|
tool_ref["name"] = tool_id
|
||||||
|
resolved_tool_id = self._resolve_tool(tool_ref, tenant_id, [])
|
||||||
|
if resolved_tool_id:
|
||||||
|
config["tool_id"] = resolved_tool_id
|
||||||
|
else:
|
||||||
|
warnings.append(f"[{node_label}] 工具 '{tool_id}' 未匹配,已置空,请导入后手动配置")
|
||||||
|
config["tool_id"] = None
|
||||||
|
config["tool_parameters"] = {}
|
||||||
|
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||||
|
knowledge_bases = config.get("knowledge_bases") or []
|
||||||
|
resolved_kbs = []
|
||||||
|
for kb in knowledge_bases:
|
||||||
|
kb_id = kb.get("kb_id")
|
||||||
|
if not kb_id:
|
||||||
|
continue
|
||||||
|
kb_ref = {}
|
||||||
|
if isinstance(kb_id, str) and len(kb_id) >= 36:
|
||||||
|
try:
|
||||||
|
uuid.UUID(kb_id)
|
||||||
|
kb_ref["id"] = kb_id
|
||||||
|
except ValueError:
|
||||||
|
kb_ref["name"] = kb_id
|
||||||
|
else:
|
||||||
|
kb_ref["name"] = kb_id
|
||||||
|
resolved_id = self._resolve_kb(kb_ref, workspace_id, [])
|
||||||
|
if resolved_id:
|
||||||
|
resolved_kbs.append({**kb, "kb_id": resolved_id})
|
||||||
|
else:
|
||||||
|
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
|
||||||
|
config["knowledge_bases"] = resolved_kbs
|
||||||
|
resolved_nodes.append({**node, "config": config})
|
||||||
|
return resolved_nodes
|
||||||
|
|
||||||
def _resolve_knowledge_retrieval(self, kr: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
|
def _resolve_knowledge_retrieval(self, kr: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
|
||||||
if not kr:
|
if not kr:
|
||||||
return kr
|
return kr
|
||||||
|
|||||||
@@ -1452,6 +1452,32 @@ class AppService:
|
|||||||
logger.debug("配置不存在,返回默认模板", extra={"app_id": str(app_id)})
|
logger.debug("配置不存在,返回默认模板", extra={"app_id": str(app_id)})
|
||||||
return self._create_default_agent_config(app_id)
|
return self._create_default_agent_config(app_id)
|
||||||
|
|
||||||
|
def get_default_model_parameters(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
) -> "ModelParameters":
|
||||||
|
"""获取 Agent 默认模型参数(不修改数据库)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelParameters: 默认模型参数
|
||||||
|
"""
|
||||||
|
logger.info("获取 Agent 默认模型参数", extra={"app_id": str(app_id)})
|
||||||
|
|
||||||
|
app = self._get_app_or_404(app_id)
|
||||||
|
|
||||||
|
if app.type != "agent":
|
||||||
|
raise BusinessException("只有 Agent 类型应用支持 Agent 配置", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|
||||||
|
from app.schemas.app_schema import ModelParameters
|
||||||
|
default_model_parameters = ModelParameters()
|
||||||
|
|
||||||
|
logger.info("获取 Agent 默认模型参数成功", extra={"app_id": str(app_id)})
|
||||||
|
return default_model_parameters
|
||||||
|
|
||||||
def _create_default_agent_config(self, app_id: uuid.UUID) -> AgentConfig:
|
def _create_default_agent_config(self, app_id: uuid.UUID) -> AgentConfig:
|
||||||
"""创建默认的 Agent 配置模板(不保存到数据库)
|
"""创建默认的 Agent 配置模板(不保存到数据库)
|
||||||
|
|
||||||
|
|||||||
@@ -544,7 +544,7 @@ class ConversationService:
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
is_omni=is_omni,
|
is_omni=is_omni,
|
||||||
support_thinking="thinking" in (capability or []),
|
capability=capability,
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -597,6 +597,7 @@ class AgentRunService:
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
deep_thinking=effective_params.get("deep_thinking", False),
|
deep_thinking=effective_params.get("deep_thinking", False),
|
||||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||||
|
json_output=effective_params.get("json_output", False),
|
||||||
capability=api_key_config.get("capability", []),
|
capability=api_key_config.get("capability", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -853,6 +854,7 @@ class AgentRunService:
|
|||||||
streaming=True,
|
streaming=True,
|
||||||
deep_thinking=effective_params.get("deep_thinking", False),
|
deep_thinking=effective_params.get("deep_thinking", False),
|
||||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||||
|
json_output=effective_params.get("json_output", False),
|
||||||
capability=api_key_config.get("capability", []),
|
capability=api_key_config.get("capability", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1299,10 +1301,30 @@ class AgentRunService:
|
|||||||
"history_files": {}
|
"history_files": {}
|
||||||
}
|
}
|
||||||
if files:
|
if files:
|
||||||
|
from app.models.file_metadata_model import FileMetadata
|
||||||
|
local_ids = [f.upload_file_id for f in files
|
||||||
|
if f.transfer_method.value == "local_file" and f.upload_file_id
|
||||||
|
and (not f.name or not f.size)]
|
||||||
|
meta_map = {}
|
||||||
|
if local_ids:
|
||||||
|
rows = self.db.query(FileMetadata).filter(
|
||||||
|
FileMetadata.id.in_(local_ids),
|
||||||
|
FileMetadata.status == "completed"
|
||||||
|
).all()
|
||||||
|
meta_map = {str(r.id): r for r in rows}
|
||||||
for f in files:
|
for f in files:
|
||||||
|
name, size = f.name, f.size
|
||||||
|
if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size):
|
||||||
|
meta = meta_map.get(str(f.upload_file_id))
|
||||||
|
if meta:
|
||||||
|
name = name or meta.file_name
|
||||||
|
size = size or meta.file_size
|
||||||
human_meta["files"].append({
|
human_meta["files"].append({
|
||||||
"type": f.type,
|
"type": f.type,
|
||||||
"url": f.url
|
"url": f.url,
|
||||||
|
"file_type": f.file_type,
|
||||||
|
"name": name,
|
||||||
|
"size": size
|
||||||
})
|
})
|
||||||
|
|
||||||
# 保存 history_files,包含 provider 和 is_omni 信息
|
# 保存 history_files,包含 provider 和 is_omni 信息
|
||||||
|
|||||||
@@ -679,9 +679,9 @@ class EmotionAnalyticsService:
|
|||||||
|
|
||||||
# 查询用户的实体和标签
|
# 查询用户的实体和标签
|
||||||
query = """
|
query = """
|
||||||
MATCH (e:Entity)
|
MATCH (e:ExtractedEntity)
|
||||||
WHERE e.end_user_id = $end_user_id
|
WHERE e.end_user_id = $end_user_id
|
||||||
RETURN e.name as name, e.type as type
|
RETURN e.name as name, e.entity_type as type
|
||||||
ORDER BY e.created_at DESC
|
ORDER BY e.created_at DESC
|
||||||
LIMIT 20
|
LIMIT 20
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from app.schemas.implicit_memory_schema import (
|
|||||||
UserMemorySummary,
|
UserMemorySummary,
|
||||||
)
|
)
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
from app.services.memory_base_service import MIN_MEMORY_SUMMARY_COUNT
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -379,12 +380,59 @@ class ImplicitMemoryService:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _build_empty_profile(self) -> dict:
|
||||||
|
"""构建 MemorySummary 不足时返回的固定空白画像数据"""
|
||||||
|
now_ms = int(datetime.utcnow().timestamp() * 1000)
|
||||||
|
insufficient = "Insufficient data for analysis"
|
||||||
|
|
||||||
|
def _empty_dimension(name: str) -> dict:
|
||||||
|
return {
|
||||||
|
"evidence": [insufficient],
|
||||||
|
"reasoning": f"No clear evidence found for {name} dimension",
|
||||||
|
"percentage": 0.0,
|
||||||
|
"dimension_name": name,
|
||||||
|
"confidence_level": 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _empty_category(name: str) -> dict:
|
||||||
|
return {
|
||||||
|
"evidence": [insufficient],
|
||||||
|
"percentage": 25.0,
|
||||||
|
"category_name": name,
|
||||||
|
"trending_direction": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"habits": [],
|
||||||
|
"portrait": {
|
||||||
|
"aesthetic": _empty_dimension("aesthetic"),
|
||||||
|
"creativity": _empty_dimension("creativity"),
|
||||||
|
"literature": _empty_dimension("literature"),
|
||||||
|
"technology": _empty_dimension("technology"),
|
||||||
|
"historical_trends": None,
|
||||||
|
"analysis_timestamp": now_ms,
|
||||||
|
"total_summaries_analyzed": 0,
|
||||||
|
},
|
||||||
|
"preferences": [],
|
||||||
|
"interest_areas": {
|
||||||
|
"art": _empty_category("art"),
|
||||||
|
"tech": _empty_category("tech"),
|
||||||
|
"music": _empty_category("music"),
|
||||||
|
"lifestyle": _empty_category("lifestyle"),
|
||||||
|
"analysis_timestamp": now_ms,
|
||||||
|
"total_summaries_analyzed": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
async def generate_complete_profile(
|
async def generate_complete_profile(
|
||||||
self,
|
self,
|
||||||
user_id: str
|
user_id: str
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""生成完整的用户画像(包含所有4个模块)
|
"""生成完整的用户画像(包含所有4个模块)
|
||||||
|
|
||||||
|
需要该用户的 MemorySummary 节点数量 >= 5 才会真正调用 LLM 生成画像,
|
||||||
|
否则返回固定的空白画像数据。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
|
|
||||||
@@ -394,6 +442,16 @@ class ImplicitMemoryService:
|
|||||||
logger.info(f"生成完整用户画像: user={user_id}")
|
logger.info(f"生成完整用户画像: user={user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 前置检查:查询该用户有效的 MemorySummary 节点数量(排除孤立节点)
|
||||||
|
from app.services.memory_base_service import MemoryBaseService
|
||||||
|
base_service = MemoryBaseService()
|
||||||
|
memory_summary_count = await base_service.get_valid_memory_summary_count(user_id)
|
||||||
|
logger.info(f"用户 MemorySummary 节点数量: {memory_summary_count} (user={user_id})")
|
||||||
|
|
||||||
|
if memory_summary_count < MIN_MEMORY_SUMMARY_COUNT:
|
||||||
|
logger.info(f"MemorySummary 数量不足 {MIN_MEMORY_SUMMARY_COUNT}(当前 {memory_summary_count}),返回空白画像: user={user_id}")
|
||||||
|
return self._build_empty_profile()
|
||||||
|
|
||||||
# 并行调用4个分析方法
|
# 并行调用4个分析方法
|
||||||
preferences, portrait, interest_areas, habits = await asyncio.gather(
|
preferences, portrait, interest_areas, habits = await asyncio.gather(
|
||||||
self.get_preference_tags(user_id=user_id),
|
self.get_preference_tags(user_id=user_id),
|
||||||
|
|||||||
@@ -2,11 +2,14 @@ import uuid
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.models.knowledge_model import Knowledge
|
from app.models.knowledge_model import Knowledge
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
from app.models.models_model import ModelConfig
|
||||||
from app.schemas.knowledge_schema import KnowledgeCreate, KnowledgeUpdate
|
from app.schemas.knowledge_schema import KnowledgeCreate, KnowledgeUpdate
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.repositories.model_repository import ModelConfigRepository
|
||||||
|
from app.models.models_model import ModelType
|
||||||
|
|
||||||
# Obtain a dedicated logger for business logic
|
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -67,6 +70,50 @@ def create_knowledge(
|
|||||||
knowledge.workspace_id = current_user.current_workspace_id
|
knowledge.workspace_id = current_user.current_workspace_id
|
||||||
if knowledge.parent_id is None:
|
if knowledge.parent_id is None:
|
||||||
knowledge.parent_id = knowledge.workspace_id
|
knowledge.parent_id = knowledge.workspace_id
|
||||||
|
|
||||||
|
workspace = db.query(Workspace).filter(Workspace.id == knowledge.workspace_id).first()
|
||||||
|
if not workspace:
|
||||||
|
raise Exception(f"Workspace {knowledge.workspace_id} not found")
|
||||||
|
|
||||||
|
tenant_id = workspace.tenant_id
|
||||||
|
|
||||||
|
if not knowledge.embedding_id:
|
||||||
|
embedding_models = ModelConfigRepository.get_by_type(
|
||||||
|
db=db, model_types=[ModelType.EMBEDDING], tenant_id=tenant_id, is_active=True
|
||||||
|
)
|
||||||
|
if embedding_models:
|
||||||
|
knowledge.embedding_id = embedding_models[0].id
|
||||||
|
business_logger.debug(f"Auto-bind embedding model: {embedding_models[0].id}")
|
||||||
|
|
||||||
|
if not knowledge.reranker_id:
|
||||||
|
rerank_models = ModelConfigRepository.get_by_type(
|
||||||
|
db=db, model_types=[ModelType.RERANK], tenant_id=tenant_id, is_active=True
|
||||||
|
)
|
||||||
|
if rerank_models:
|
||||||
|
knowledge.reranker_id = rerank_models[0].id
|
||||||
|
business_logger.debug(f"Auto-bind rerank model: {rerank_models[0].id}")
|
||||||
|
|
||||||
|
if not knowledge.llm_id:
|
||||||
|
llm_models = ModelConfigRepository.get_by_type(
|
||||||
|
db=db, model_types=[ModelType.LLM, ModelType.CHAT], tenant_id=tenant_id, is_active=True
|
||||||
|
)
|
||||||
|
if llm_models:
|
||||||
|
knowledge.llm_id = llm_models[0].id
|
||||||
|
business_logger.debug(f"Auto-bind llm model: {llm_models[0].id}")
|
||||||
|
|
||||||
|
if not knowledge.image2text_id:
|
||||||
|
image2text_models = db.query(ModelConfig).filter(
|
||||||
|
ModelConfig.tenant_id == tenant_id,
|
||||||
|
ModelConfig.type.in_([ModelType.CHAT.value]),
|
||||||
|
ModelConfig.capability.contains(["vision"]),
|
||||||
|
ModelConfig.is_active == True,
|
||||||
|
ModelConfig.is_composite == False
|
||||||
|
).order_by(ModelConfig.created_at.desc()).all()
|
||||||
|
if not image2text_models:
|
||||||
|
raise Exception("租户下没有可用的视觉模型,创建知识库失败")
|
||||||
|
knowledge.image2text_id = image2text_models[0].id
|
||||||
|
business_logger.debug(f"Auto-bind image2text model: {image2text_models[0].id}")
|
||||||
|
|
||||||
business_logger.debug(f"Start creating the knowledge base: {knowledge.name}")
|
business_logger.debug(f"Start creating the knowledge base: {knowledge.name}")
|
||||||
db_knowledge = knowledge_repository.create_knowledge(
|
db_knowledge = knowledge_repository.create_knowledge(
|
||||||
db=db, knowledge=knowledge
|
db=db, knowledge=knowledge
|
||||||
|
|||||||
@@ -415,9 +415,11 @@ class LLMRouter:
|
|||||||
api_key=api_key_config.api_key,
|
api_key=api_key_config.api_key,
|
||||||
base_url=api_key_config.api_base,
|
base_url=api_key_config.api_base,
|
||||||
is_omni=api_key_config.is_omni,
|
is_omni=api_key_config.is_omni,
|
||||||
support_thinking="thinking" in (api_key_config.capability or []),
|
capability=api_key_config.capability,
|
||||||
temperature=0.3,
|
extra_params={
|
||||||
max_tokens=500
|
"temperature": 0.3,
|
||||||
|
"max_tokens": 500
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"创建 LLM 实例 - Provider: {api_key_config.provider}, Model: {api_key_config.model_name}")
|
logger.debug(f"创建 LLM 实例 - Provider: {api_key_config.provider}, Model: {api_key_config.model_name}")
|
||||||
|
|||||||
@@ -393,7 +393,7 @@ class MasterAgentRouter:
|
|||||||
api_key=api_key_config.api_key,
|
api_key=api_key_config.api_key,
|
||||||
base_url=api_key_config.api_base,
|
base_url=api_key_config.api_base,
|
||||||
is_omni=api_key_config.is_omni,
|
is_omni=api_key_config.is_omni,
|
||||||
support_thinking="thinking" in (api_key_config.capability or []),
|
capability=api_key_config.capability,
|
||||||
extra_params = extra_params
|
extra_params = extra_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ This service validates inputs and delegates to MemoryAgentService for core memor
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
@@ -15,7 +17,6 @@ from app.models.app_model import App
|
|||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -124,7 +125,7 @@ class MemoryAPIService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
||||||
|
|
||||||
async def write_memory(
|
def write_memory(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
@@ -133,14 +134,131 @@ class MemoryAPIService:
|
|||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Write memory with validation.
|
"""Submit a memory write task via Celery.
|
||||||
|
|
||||||
Validates end_user exists and belongs to workspace, updates the end user's
|
Validates end_user exists and belongs to workspace, updates the end user's
|
||||||
memory_config_id, then delegates to MemoryAgentService.write_memory.
|
memory_config_id, then dispatches write_message_task to Celery for async
|
||||||
|
processing with per-user fair locking.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workspace_id: Workspace ID for resource validation
|
workspace_id: Workspace ID for resource validation
|
||||||
end_user_id: End user identifier (used as end_user_id)
|
end_user_id: End user identifier
|
||||||
|
message: Message content to store
|
||||||
|
config_id: Memory configuration ID (required)
|
||||||
|
storage_type: Storage backend (neo4j or rag)
|
||||||
|
user_rag_memory_id: Optional RAG memory ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with task_id, status, and end_user_id
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ResourceNotFoundException: If end_user not found
|
||||||
|
BusinessException: If validation fails
|
||||||
|
"""
|
||||||
|
logger.info(f"Submitting memory write for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||||
|
|
||||||
|
# Validate end_user exists and belongs to workspace
|
||||||
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
|
# Update end user's memory_config_id
|
||||||
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
|
# Convert to message list format expected by write_message_task
|
||||||
|
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
||||||
|
|
||||||
|
from app.tasks import write_message_task
|
||||||
|
task = write_message_task.delay(
|
||||||
|
end_user_id,
|
||||||
|
messages,
|
||||||
|
config_id,
|
||||||
|
storage_type,
|
||||||
|
user_rag_memory_id or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"task_id": task.id,
|
||||||
|
"status": "PENDING",
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def read_memory(
|
||||||
|
self,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
end_user_id: str,
|
||||||
|
message: str,
|
||||||
|
search_switch: str = "0",
|
||||||
|
config_id: str = "",
|
||||||
|
storage_type: str = "neo4j",
|
||||||
|
user_rag_memory_id: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Submit a memory read task via Celery.
|
||||||
|
|
||||||
|
Validates end_user exists and belongs to workspace, updates the end user's
|
||||||
|
memory_config_id, then dispatches read_message_task to Celery for async processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace ID for resource validation
|
||||||
|
end_user_id: End user identifier
|
||||||
|
message: Query message
|
||||||
|
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
|
||||||
|
config_id: Memory configuration ID (required)
|
||||||
|
storage_type: Storage backend (neo4j or rag)
|
||||||
|
user_rag_memory_id: Optional RAG memory ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with task_id, status, and end_user_id
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ResourceNotFoundException: If end_user not found
|
||||||
|
BusinessException: If validation fails
|
||||||
|
"""
|
||||||
|
logger.info(f"Submitting memory read for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||||
|
|
||||||
|
# Validate end_user exists and belongs to workspace
|
||||||
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
|
# Update end user's memory_config_id
|
||||||
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
|
from app.tasks import read_message_task
|
||||||
|
task = read_message_task.delay(
|
||||||
|
end_user_id,
|
||||||
|
message,
|
||||||
|
[], # history
|
||||||
|
search_switch,
|
||||||
|
config_id,
|
||||||
|
storage_type,
|
||||||
|
user_rag_memory_id or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Memory read task submitted: task_id={task.id}, end_user_id={end_user_id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"task_id": task.id,
|
||||||
|
"status": "PENDING",
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def write_memory_sync(
|
||||||
|
self,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
end_user_id: str,
|
||||||
|
message: str,
|
||||||
|
config_id: str,
|
||||||
|
storage_type: str = "neo4j",
|
||||||
|
user_rag_memory_id: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Write memory synchronously (inline, no Celery).
|
||||||
|
|
||||||
|
Validates end_user, then calls MemoryAgentService.write_memory directly.
|
||||||
|
Blocks until the write completes. Use for cases where the caller needs
|
||||||
|
immediate confirmation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace ID for resource validation
|
||||||
|
end_user_id: End user identifier
|
||||||
message: Message content to store
|
message: Message content to store
|
||||||
config_id: Memory configuration ID (required)
|
config_id: Memory configuration ID (required)
|
||||||
storage_type: Storage backend (neo4j or rag)
|
storage_type: Storage backend (neo4j or rag)
|
||||||
@@ -151,19 +269,14 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ResourceNotFoundException: If end_user not found
|
ResourceNotFoundException: If end_user not found
|
||||||
BusinessException: If end_user not in authorized workspace or write fails
|
BusinessException: If write fails
|
||||||
"""
|
"""
|
||||||
logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
logger.info(f"Writing memory (sync) for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||||
|
|
||||||
# Validate end_user exists and belongs to workspace
|
|
||||||
self.validate_end_user(end_user_id, workspace_id)
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
# Update end user's memory_config_id
|
|
||||||
self._update_end_user_config(end_user_id, config_id)
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
|
||||||
# Convert string message to list[dict] format expected by MemoryAgentService
|
|
||||||
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
||||||
result = await MemoryAgentService().write_memory(
|
result = await MemoryAgentService().write_memory(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -174,11 +287,8 @@ class MemoryAPIService:
|
|||||||
user_rag_memory_id=user_rag_memory_id or "",
|
user_rag_memory_id=user_rag_memory_id or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
logger.info(f"Memory write (sync) successful for end_user: {end_user_id}")
|
||||||
|
|
||||||
# result may be a string "success" or a dict with a "status" key
|
|
||||||
# Preserve the full dict so callers don't silently lose extra fields
|
|
||||||
# (e.g. error codes, metadata) returned by MemoryAgentService.
|
|
||||||
if isinstance(result, dict):
|
if isinstance(result, dict):
|
||||||
return {
|
return {
|
||||||
**result,
|
**result,
|
||||||
@@ -192,20 +302,17 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
except ConfigurationError as e:
|
except ConfigurationError as e:
|
||||||
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
||||||
raise BusinessException(
|
raise BusinessException(message=str(e), code=BizCode.MEMORY_CONFIG_NOT_FOUND)
|
||||||
message=str(e),
|
|
||||||
code=BizCode.MEMORY_CONFIG_NOT_FOUND
|
|
||||||
)
|
|
||||||
except BusinessException:
|
except BusinessException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Memory write failed for end_user {end_user_id}: {e}")
|
logger.error(f"Memory write (sync) failed for end_user {end_user_id}: {e}")
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
message=f"Memory write failed: {str(e)}",
|
message=f"Memory write failed: {str(e)}",
|
||||||
code=BizCode.MEMORY_WRITE_FAILED
|
code=BizCode.MEMORY_WRITE_FAILED
|
||||||
)
|
)
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory_sync(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
@@ -215,14 +322,15 @@ class MemoryAPIService:
|
|||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Read memory with validation.
|
"""Read memory synchronously (inline, no Celery).
|
||||||
|
|
||||||
Validates end_user exists and belongs to workspace, updates the end user's
|
Validates end_user, then calls MemoryAgentService.read_memory directly.
|
||||||
memory_config_id, then delegates to MemoryAgentService.read_memory.
|
Blocks until the read completes. Use for cases where the caller needs
|
||||||
|
the answer immediately.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workspace_id: Workspace ID for resource validation
|
workspace_id: Workspace ID for resource validation
|
||||||
end_user_id: End user identifier (used as end_user_id)
|
end_user_id: End user identifier
|
||||||
message: Query message
|
message: Query message
|
||||||
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
|
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
|
||||||
config_id: Memory configuration ID (required)
|
config_id: Memory configuration ID (required)
|
||||||
@@ -234,18 +342,14 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ResourceNotFoundException: If end_user not found
|
ResourceNotFoundException: If end_user not found
|
||||||
BusinessException: If end_user not in authorized workspace or read fails
|
BusinessException: If read fails
|
||||||
"""
|
"""
|
||||||
logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
logger.info(f"Reading memory (sync) for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||||
|
|
||||||
# Validate end_user exists and belongs to workspace
|
|
||||||
self.validate_end_user(end_user_id, workspace_id)
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
# Update end user's memory_config_id
|
|
||||||
self._update_end_user_config(end_user_id, config_id)
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
|
||||||
result = await MemoryAgentService().read_memory(
|
result = await MemoryAgentService().read_memory(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
message=message,
|
message=message,
|
||||||
@@ -257,7 +361,7 @@ class MemoryAPIService:
|
|||||||
user_rag_memory_id=user_rag_memory_id or ""
|
user_rag_memory_id=user_rag_memory_id or ""
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory read successful for end_user: {end_user_id}")
|
logger.info(f"Memory read (sync) successful for end_user: {end_user_id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"answer": result.get("answer", ""),
|
"answer": result.get("answer", ""),
|
||||||
@@ -267,14 +371,11 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
except ConfigurationError as e:
|
except ConfigurationError as e:
|
||||||
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
||||||
raise BusinessException(
|
raise BusinessException(message=str(e), code=BizCode.MEMORY_CONFIG_NOT_FOUND)
|
||||||
message=str(e),
|
|
||||||
code=BizCode.MEMORY_CONFIG_NOT_FOUND
|
|
||||||
)
|
|
||||||
except BusinessException:
|
except BusinessException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Memory read failed for end_user {end_user_id}: {e}")
|
logger.error(f"Memory read (sync) failed for end_user {end_user_id}: {e}")
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
message=f"Memory read failed: {str(e)}",
|
message=f"Memory read failed: {str(e)}",
|
||||||
code=BizCode.MEMORY_READ_FAILED
|
code=BizCode.MEMORY_READ_FAILED
|
||||||
|
|||||||
@@ -265,12 +265,50 @@ async def Translation_English(modid, text, fields=None):
|
|||||||
# 其他类型(数字、布尔值、None等):原样返回
|
# 其他类型(数字、布尔值、None等):原样返回
|
||||||
else:
|
else:
|
||||||
return text
|
return text
|
||||||
|
# 隐性记忆画像生成所需的最低 MemorySummary 节点数量
|
||||||
|
MIN_MEMORY_SUMMARY_COUNT = 5
|
||||||
|
|
||||||
|
|
||||||
class MemoryBaseService:
|
class MemoryBaseService:
|
||||||
"""记忆服务基类,提供共享的辅助方法"""
|
"""记忆服务基类,提供共享的辅助方法"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.neo4j_connector = Neo4jConnector()
|
self.neo4j_connector = Neo4jConnector()
|
||||||
|
|
||||||
|
async def get_valid_memory_summary_count(
|
||||||
|
self,
|
||||||
|
end_user_id: str
|
||||||
|
) -> int:
|
||||||
|
"""获取用户有效的 MemorySummary 节点数量(排除孤立节点)。
|
||||||
|
|
||||||
|
只统计存在 DERIVED_FROM_STATEMENT 关系的 MemorySummary 节点。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
有效 MemorySummary 节点数量
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = """
|
||||||
|
MATCH (n:MemorySummary)-[:DERIVED_FROM_STATEMENT]->(:Statement)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN count(DISTINCT n) as count
|
||||||
|
"""
|
||||||
|
result = await self.neo4j_connector.execute_query(
|
||||||
|
query, end_user_id=end_user_id
|
||||||
|
)
|
||||||
|
count = result[0]["count"] if result and len(result) > 0 else 0
|
||||||
|
logger.debug(
|
||||||
|
f"有效 MemorySummary 节点数量: {count} (end_user_id={end_user_id})"
|
||||||
|
)
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"获取有效 MemorySummary 数量失败: {str(e)}", exc_info=True
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_timestamp(timestamp_value) -> Optional[int]:
|
def parse_timestamp(timestamp_value) -> Optional[int]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ class MemoryPerceptualService:
|
|||||||
api_key=model_config.api_key,
|
api_key=model_config.api_key,
|
||||||
base_url=model_config.api_base,
|
base_url=model_config.api_base,
|
||||||
is_omni=model_config.is_omni,
|
is_omni=model_config.is_omni,
|
||||||
support_thinking="thinking" in (model_config.capability or []),
|
capability=model_config.capability,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return llm, model_config
|
return llm, model_config
|
||||||
|
|||||||
@@ -47,7 +47,8 @@ class ModelParameterMerger:
|
|||||||
"n": 1,
|
"n": 1,
|
||||||
"stop": None,
|
"stop": None,
|
||||||
"deep_thinking": False,
|
"deep_thinking": False,
|
||||||
"thinking_budget_tokens": None
|
"thinking_budget_tokens": None,
|
||||||
|
"json_output": False
|
||||||
}
|
}
|
||||||
|
|
||||||
# 合并参数:默认值 -> 模型配置 -> Agent 配置
|
# 合并参数:默认值 -> 模型配置 -> Agent 配置
|
||||||
|
|||||||
@@ -125,9 +125,11 @@ class ModelConfigService:
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
is_omni=is_omni,
|
is_omni=is_omni,
|
||||||
support_thinking="thinking" in (capability or []),
|
capability=capability,
|
||||||
temperature=0.7,
|
extra_params={
|
||||||
max_tokens=100
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 100
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 根据模型类型选择不同的验证方式
|
# 根据模型类型选择不同的验证方式
|
||||||
@@ -729,10 +731,21 @@ class ModelApiKeyService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_api_key(db: Session, api_key_id: uuid.UUID) -> bool:
|
def delete_api_key(db: Session, api_key_id: uuid.UUID) -> bool:
|
||||||
"""删除API Key"""
|
"""删除API Key"""
|
||||||
if not ModelApiKeyRepository.get_by_id(db, api_key_id):
|
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||||
|
if not api_key:
|
||||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
model_config_ids = [mc.id for mc in api_key.model_configs]
|
||||||
|
|
||||||
success = ModelApiKeyRepository.delete(db, api_key_id)
|
success = ModelApiKeyRepository.delete(db, api_key_id)
|
||||||
|
|
||||||
|
for model_config_id in model_config_ids:
|
||||||
|
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||||
|
if model_config:
|
||||||
|
has_active_key = any(key.is_active for key in model_config.api_keys)
|
||||||
|
if not has_active_key and model_config.is_active:
|
||||||
|
model_config.is_active = False
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return success
|
return success
|
||||||
|
|
||||||
|
|||||||
@@ -2616,9 +2616,11 @@ class MultiAgentOrchestrator:
|
|||||||
api_key=api_key_config.api_key,
|
api_key=api_key_config.api_key,
|
||||||
base_url=api_key_config.api_base,
|
base_url=api_key_config.api_base,
|
||||||
is_omni=api_key_config.is_omni,
|
is_omni=api_key_config.is_omni,
|
||||||
support_thinking="thinking" in (api_key_config.capability or []),
|
capability=api_key_config.capability,
|
||||||
temperature=0.7, # 整合任务使用中等温度
|
extra_params={
|
||||||
max_tokens=2000
|
"temperature": 0.7, # 整合任务使用中等温度
|
||||||
|
"max_tokens": 2000
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LLM 实例
|
# 创建 LLM 实例
|
||||||
@@ -2795,10 +2797,12 @@ class MultiAgentOrchestrator:
|
|||||||
api_key=api_key_config.api_key,
|
api_key=api_key_config.api_key,
|
||||||
base_url=api_key_config.api_base,
|
base_url=api_key_config.api_base,
|
||||||
is_omni=api_key_config.is_omni,
|
is_omni=api_key_config.is_omni,
|
||||||
support_thinking="thinking" in (api_key_config.capability or []),
|
capability=api_key_config.capability,
|
||||||
temperature=0.7,
|
extra_params={
|
||||||
max_tokens=2000,
|
"temperature": 0.7,
|
||||||
extra_params={"streaming": True} # 启用流式输出
|
"max_tokens": 2000,
|
||||||
|
"streaming": True # 启用流式输出
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LLM 实例
|
# 创建 LLM 实例
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ class PromptOptimizerService:
|
|||||||
api_key=api_config.api_key,
|
api_key=api_config.api_key,
|
||||||
base_url=api_config.api_base,
|
base_url=api_config.api_base,
|
||||||
is_omni=api_config.is_omni,
|
is_omni=api_config.is_omni,
|
||||||
support_thinking="thinking" in (api_config.capability or []),
|
capability=api_config.capability,
|
||||||
), type=ModelType(model_config.type))
|
), type=ModelType(model_config.type))
|
||||||
try:
|
try:
|
||||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||||
@@ -227,10 +227,20 @@ class PromptOptimizerService:
|
|||||||
content = getattr(chunk, "content", chunk)
|
content = getattr(chunk, "content", chunk)
|
||||||
if not content:
|
if not content:
|
||||||
continue
|
continue
|
||||||
buffer += content
|
if isinstance(content, str):
|
||||||
|
buffer += content
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for _ in content:
|
||||||
|
buffer += _["text"]
|
||||||
|
else:
|
||||||
|
logger.error(f"Unsupported content type - {content}")
|
||||||
|
raise Exception("Unsupported content type")
|
||||||
cache = buffer[:-20]
|
cache = buffer[:-20]
|
||||||
|
last_idx = 19
|
||||||
|
while cache and cache[-1] == '\\' and last_idx > 0:
|
||||||
|
cache = buffer[:-last_idx]
|
||||||
|
last_idx -= 1
|
||||||
|
|
||||||
# 尝试找到 "prompt": " 开始位置
|
|
||||||
if prompt_finished:
|
if prompt_finished:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -272,7 +282,7 @@ class PromptOptimizerService:
|
|||||||
def parser_prompt_variables(prompt: str):
|
def parser_prompt_variables(prompt: str):
|
||||||
try:
|
try:
|
||||||
pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}'
|
pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}'
|
||||||
matches = re.findall(pattern, prompt)
|
matches = re.findall(pattern, str(prompt))
|
||||||
variables = list(set(matches))
|
variables = list(set(matches))
|
||||||
return variables
|
return variables
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -250,7 +250,8 @@ class SharedChatService:
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||||
capability=api_key_obj.capability or [],
|
json_output=model_parameters.get("json_output", False),
|
||||||
|
capability=api_key_obj.capability,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 加载历史消息
|
# 加载历史消息
|
||||||
@@ -455,6 +456,7 @@ class SharedChatService:
|
|||||||
streaming=True,
|
streaming=True,
|
||||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||||
|
json_output=model_parameters.get("json_output", False),
|
||||||
capability=api_key_obj.capability or [],
|
capability=api_key_obj.capability or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
|
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.conversation_repository import ConversationRepository
|
from app.repositories.conversation_repository import ConversationRepository
|
||||||
@@ -21,7 +22,7 @@ from app.repositories.end_user_repository import EndUserRepository
|
|||||||
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
||||||
from app.services.memory_base_service import MemoryBaseService
|
from app.services.memory_base_service import MemoryBaseService, MIN_MEMORY_SUMMARY_COUNT
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
from app.services.memory_short_service import ShortService
|
from app.services.memory_short_service import ShortService
|
||||||
@@ -400,12 +401,21 @@ class UserMemoryService:
|
|||||||
# 构建响应数据(转换时间为毫秒时间戳)
|
# 构建响应数据(转换时间为毫秒时间戳)
|
||||||
# 将 meta_data 中的 profile、knowledge_tags、behavioral_hints 平铺到顶层
|
# 将 meta_data 中的 profile、knowledge_tags、behavioral_hints 平铺到顶层
|
||||||
meta = end_user_info_record.meta_data or {}
|
meta = end_user_info_record.meta_data or {}
|
||||||
|
|
||||||
|
# profile 列表字段截断:只返回前 MAX_PROFILE_LIST_SIZE 条(按时间从新到旧)
|
||||||
|
MAX_PROFILE_LIST_SIZE = 5
|
||||||
|
profile = meta.get("profile")
|
||||||
|
if isinstance(profile, dict):
|
||||||
|
for key in ("role", "domain", "expertise", "interests"):
|
||||||
|
if isinstance(profile.get(key), list):
|
||||||
|
profile[key] = profile[key][:MAX_PROFILE_LIST_SIZE]
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"end_user_info_id": str(end_user_info_record.id),
|
"end_user_info_id": str(end_user_info_record.id),
|
||||||
"end_user_id": str(end_user_info_record.end_user_id),
|
"end_user_id": str(end_user_info_record.end_user_id),
|
||||||
"other_name": end_user_info_record.other_name,
|
"other_name": end_user_info_record.other_name,
|
||||||
"aliases": end_user_info_record.aliases,
|
"aliases": end_user_info_record.aliases,
|
||||||
"profile": meta.get("profile"),
|
"profile": profile,
|
||||||
"knowledge_tags": meta.get("knowledge_tags"),
|
"knowledge_tags": meta.get("knowledge_tags"),
|
||||||
"behavioral_hints": meta.get("behavioral_hints"),
|
"behavioral_hints": meta.get("behavioral_hints"),
|
||||||
"created_at": datetime_to_timestamp(end_user_info_record.created_at),
|
"created_at": datetime_to_timestamp(end_user_info_record.created_at),
|
||||||
@@ -477,7 +487,7 @@ class UserMemoryService:
|
|||||||
allowed_fields = {'other_name', 'aliases', 'meta_data'}
|
allowed_fields = {'other_name', 'aliases', 'meta_data'}
|
||||||
|
|
||||||
# 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中
|
# 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中
|
||||||
_user_placeholder_names = {'用户', '我', 'User', 'I'}
|
_user_placeholder_names = _USER_PLACEHOLDER_NAMES
|
||||||
|
|
||||||
# 过滤 other_name:不允许设置为占位名称
|
# 过滤 other_name:不允许设置为占位名称
|
||||||
if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names:
|
if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names:
|
||||||
@@ -1504,7 +1514,7 @@ async def analytics_memory_types(
|
|||||||
2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取)
|
2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取)
|
||||||
3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量
|
3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量
|
||||||
4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
||||||
5. 隐性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一
|
5. 隐性记忆 (IMPLICIT_MEMORY) = MemorySummary 节点数量(需 >= MIN_MEMORY_SUMMARY_COUNT 才显示,否则为 0)
|
||||||
6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
||||||
7. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
7. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
||||||
8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
||||||
@@ -1561,23 +1571,15 @@ async def analytics_memory_types(
|
|||||||
logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}")
|
logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}")
|
||||||
work_count = 0
|
work_count = 0
|
||||||
|
|
||||||
# 获取隐性记忆数量(基于 Statement 节点数量的三分之一)
|
# 获取隐性记忆数量(基于有关联关系的 MemorySummary 节点数量,需 >= MIN_MEMORY_SUMMARY_COUNT 才计入)
|
||||||
implicit_count = 0
|
implicit_count = 0
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
try:
|
try:
|
||||||
# 查询 Statement 节点数量
|
memory_summary_count = await base_service.get_valid_memory_summary_count(end_user_id)
|
||||||
query = """
|
implicit_count = memory_summary_count if memory_summary_count >= MIN_MEMORY_SUMMARY_COUNT else 0
|
||||||
MATCH (n:Statement)
|
logger.debug(f"隐性记忆数量(有效MemorySummary节点数): {implicit_count} (有效MemorySummary总数={memory_summary_count}, end_user_id={end_user_id})")
|
||||||
WHERE n.end_user_id = $end_user_id
|
|
||||||
RETURN count(n) as count
|
|
||||||
"""
|
|
||||||
result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id)
|
|
||||||
statement_count = result[0]["count"] if result and len(result) > 0 else 0
|
|
||||||
# 取三分之一作为隐性记忆数量
|
|
||||||
implicit_count = round(statement_count / 3)
|
|
||||||
logger.debug(f"隐性记忆数量(Statement数量的1/3): {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取Statement数量失败,隐性记忆数量设为0: {str(e)}")
|
logger.warning(f"获取MemorySummary数量失败,隐性记忆数量设为0: {str(e)}")
|
||||||
implicit_count = 0
|
implicit_count = 0
|
||||||
|
|
||||||
# 原有的基于行为习惯的统计方式(已注释)
|
# 原有的基于行为习惯的统计方式(已注释)
|
||||||
@@ -1643,7 +1645,7 @@ async def analytics_memory_types(
|
|||||||
"WORKING_MEMORY": work_count, # 工作记忆(基于会话数量)
|
"WORKING_MEMORY": work_count, # 工作记忆(基于会话数量)
|
||||||
"SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量)
|
"SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量)
|
||||||
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
|
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
|
||||||
"IMPLICIT_MEMORY": implicit_count, # 隐性记忆(Statement数量的1/3)
|
"IMPLICIT_MEMORY": implicit_count, # 隐性记忆(MemorySummary节点数,需>=MIN_MEMORY_SUMMARY_COUNT)
|
||||||
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
|
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
|
||||||
"EPISODIC_MEMORY": episodic_count, # 情景记忆
|
"EPISODIC_MEMORY": episodic_count, # 情景记忆
|
||||||
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
|
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
|
||||||
|
|||||||
@@ -285,7 +285,7 @@ def activate_user(db: Session, user_id_to_activate: uuid.UUID, current_user: Use
|
|||||||
try:
|
try:
|
||||||
# 查找用户
|
# 查找用户
|
||||||
business_logger.debug(f"查找待激活用户: {user_id_to_activate}")
|
business_logger.debug(f"查找待激活用户: {user_id_to_activate}")
|
||||||
db_user = user_repository.get_user_by_id(db, user_id=user_id_to_activate)
|
db_user = user_repository.get_user_by_id_regardless_active(db, user_id=user_id_to_activate)
|
||||||
if not db_user:
|
if not db_user:
|
||||||
business_logger.warning(f"用户不存在: {user_id_to_activate}")
|
business_logger.warning(f"用户不存在: {user_id_to_activate}")
|
||||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||||
|
|||||||
@@ -957,7 +957,10 @@ class WorkflowService:
|
|||||||
for file in message["content"]:
|
for file in message["content"]:
|
||||||
human_meta["files"].append({
|
human_meta["files"].append({
|
||||||
"type": file.get("type"),
|
"type": file.get("type"),
|
||||||
"url": file.get("url")
|
"url": file.get("url"),
|
||||||
|
"file_type": file.get("origin_file_type"),
|
||||||
|
"name": file.get("name"),
|
||||||
|
"size": file.get("size")
|
||||||
})
|
})
|
||||||
if message["role"] == "assistant":
|
if message["role"] == "assistant":
|
||||||
assistant_message = message["content"]
|
assistant_message = message["content"]
|
||||||
|
|||||||
104
api/app/tasks.py
@@ -455,7 +455,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
|||||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||||
if db_knowledge is None:
|
if db_knowledge is None:
|
||||||
logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found")
|
logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found")
|
||||||
return f"build knowledge graph failed: knowledge not found"
|
return "build knowledge graph failed: knowledge not found"
|
||||||
|
|
||||||
if not (db_knowledge.parser_config and
|
if not (db_knowledge.parser_config and
|
||||||
db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False)):
|
db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False)):
|
||||||
@@ -538,7 +538,7 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str):
|
|||||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(knowledge_id)).first()
|
db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(knowledge_id)).first()
|
||||||
if db_document is None or db_knowledge is None:
|
if db_document is None or db_knowledge is None:
|
||||||
logger.error(f"[GraphRAG] document={document_id} or knowledge={knowledge_id} not found")
|
logger.error(f"[GraphRAG] document={document_id} or knowledge={knowledge_id} not found")
|
||||||
return f"build_graphrag_for_document failed: record not found"
|
return "build_graphrag_for_document failed: record not found"
|
||||||
|
|
||||||
graphrag_conf = db_knowledge.parser_config.get("graphrag", {})
|
graphrag_conf = db_knowledge.parser_config.get("graphrag", {})
|
||||||
with_resolution = graphrag_conf.get("resolution", False)
|
with_resolution = graphrag_conf.get("resolution", False)
|
||||||
@@ -617,7 +617,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||||
if db_knowledge is None:
|
if db_knowledge is None:
|
||||||
logger.error(f"[SyncKB] knowledge={kb_id} not found")
|
logger.error(f"[SyncKB] knowledge={kb_id} not found")
|
||||||
return f"sync knowledge failed: knowledge not found"
|
return "sync knowledge failed: knowledge not found"
|
||||||
|
|
||||||
# 1. get vector_service
|
# 1. get vector_service
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
@@ -3102,29 +3102,11 @@ def extract_user_metadata_task(
|
|||||||
logger.info(f"[CELERY METADATA] No metadata extracted for end_user_id={end_user_id}")
|
logger.info(f"[CELERY METADATA] No metadata extracted for end_user_id={end_user_id}")
|
||||||
return {"status": "SUCCESS", "result": "no_metadata_extracted"}
|
return {"status": "SUCCESS", "result": "no_metadata_extracted"}
|
||||||
|
|
||||||
user_metadata, aliases_to_add, aliases_to_remove = extract_result
|
metadata_changes, aliases_to_add, aliases_to_remove = extract_result
|
||||||
logger.info(f"[CELERY METADATA] LLM 别名新增: {aliases_to_add}, 移除: {aliases_to_remove}")
|
logger.info(
|
||||||
|
f"[CELERY METADATA] LLM 元数据变更: {[c.model_dump() for c in metadata_changes]}, "
|
||||||
# 4. 清洗元数据、覆盖写入元数据和别名
|
f"别名新增: {aliases_to_add}, 移除: {aliases_to_remove}"
|
||||||
def clean_metadata(raw: dict) -> dict:
|
)
|
||||||
"""递归移除空字符串、空列表、空字典。"""
|
|
||||||
result = {}
|
|
||||||
for k, v in raw.items():
|
|
||||||
if v == "" or v == []:
|
|
||||||
continue
|
|
||||||
if isinstance(v, dict):
|
|
||||||
cleaned = clean_metadata(v)
|
|
||||||
if cleaned:
|
|
||||||
result[k] = cleaned
|
|
||||||
else:
|
|
||||||
result[k] = v
|
|
||||||
return result
|
|
||||||
|
|
||||||
raw_dict = user_metadata.model_dump(exclude_none=True) if user_metadata else {}
|
|
||||||
logger.info(f"[CELERY METADATA] LLM 输出完整元数据: {json.dumps(raw_dict, ensure_ascii=False)}")
|
|
||||||
|
|
||||||
cleaned = clean_metadata(raw_dict) if raw_dict else {}
|
|
||||||
logger.info(f"[CELERY METADATA] 清洗后元数据: {json.dumps(cleaned, ensure_ascii=False)}")
|
|
||||||
|
|
||||||
from datetime import datetime as dt, timezone as tz
|
from datetime import datetime as dt, timezone as tz
|
||||||
now = dt.now(tz.utc).isoformat()
|
now = dt.now(tz.utc).isoformat()
|
||||||
@@ -3152,15 +3134,49 @@ def extract_user_metadata_task(
|
|||||||
end_user = EndUserRepository(db).get_by_id(end_user_uuid)
|
end_user = EndUserRepository(db).get_by_id(end_user_uuid)
|
||||||
|
|
||||||
if info:
|
if info:
|
||||||
# 元数据覆盖写入
|
# 4. 元数据增量更新(按 LLM 输出的变更操作逐条执行,所有字段均为列表类型)
|
||||||
if cleaned:
|
if metadata_changes:
|
||||||
existing_meta = info.meta_data if info.meta_data else {}
|
# 深拷贝,确保 SQLAlchemy 能检测到变更
|
||||||
|
import copy
|
||||||
|
existing_meta = copy.deepcopy(info.meta_data) if info.meta_data else {}
|
||||||
updated_at = dict(existing_meta.get("_updated_at", {}))
|
updated_at = dict(existing_meta.get("_updated_at", {}))
|
||||||
_update_timestamps(existing_meta, cleaned, updated_at, now)
|
|
||||||
final = dict(cleaned)
|
for change in metadata_changes:
|
||||||
final["_updated_at"] = updated_at
|
field_path = change.field_path
|
||||||
info.meta_data = final
|
action = change.action
|
||||||
logger.info("[CELERY METADATA] 覆盖写入元数据")
|
value = change.value
|
||||||
|
|
||||||
|
if not value or not value.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 定位到目标字段的父级节点
|
||||||
|
parts = field_path.split(".")
|
||||||
|
target = existing_meta
|
||||||
|
for part in parts[:-1]:
|
||||||
|
target = target.setdefault(part, {})
|
||||||
|
leaf = parts[-1]
|
||||||
|
|
||||||
|
current_list = target.get(leaf, [])
|
||||||
|
|
||||||
|
if action == "set":
|
||||||
|
if value not in current_list:
|
||||||
|
# 新值插入列表头部,保证按时间从新到旧排序
|
||||||
|
current_list.insert(0, value)
|
||||||
|
target[leaf] = current_list
|
||||||
|
logger.info(f"[CELERY METADATA] set {field_path} = {value}")
|
||||||
|
|
||||||
|
elif action == "remove":
|
||||||
|
if value in current_list:
|
||||||
|
current_list.remove(value)
|
||||||
|
target[leaf] = current_list
|
||||||
|
logger.info(f"[CELERY METADATA] remove {value} from {field_path}")
|
||||||
|
|
||||||
|
updated_at[field_path] = now
|
||||||
|
|
||||||
|
existing_meta["_updated_at"] = updated_at
|
||||||
|
# 赋值深拷贝后的新对象,SQLAlchemy 会检测到字段变更并写入
|
||||||
|
info.meta_data = existing_meta
|
||||||
|
logger.info(f"[CELERY METADATA] 增量更新元数据完成: {json.dumps(existing_meta, ensure_ascii=False)}")
|
||||||
|
|
||||||
# 别名增量增删:(已有 - remove) + add
|
# 别名增量增删:(已有 - remove) + add
|
||||||
old_aliases = info.aliases if info.aliases else []
|
old_aliases = info.aliases if info.aliases else []
|
||||||
@@ -3196,12 +3212,28 @@ def extract_user_metadata_task(
|
|||||||
from app.models.end_user_info_model import EndUserInfo
|
from app.models.end_user_info_model import EndUserInfo
|
||||||
initial_aliases = filtered_add # 新记录只有 add,没有 remove
|
initial_aliases = filtered_add # 新记录只有 add,没有 remove
|
||||||
first_alias = initial_aliases[0] if initial_aliases else ""
|
first_alias = initial_aliases[0] if initial_aliases else ""
|
||||||
if first_alias or cleaned:
|
|
||||||
|
# 从变更操作构建初始元数据(所有字段均为列表类型)
|
||||||
|
initial_meta = {}
|
||||||
|
for change in metadata_changes:
|
||||||
|
if change.action == "set" and change.value is not None and change.value.strip():
|
||||||
|
parts = change.field_path.split(".")
|
||||||
|
target = initial_meta
|
||||||
|
for part in parts[:-1]:
|
||||||
|
target = target.setdefault(part, {})
|
||||||
|
leaf = parts[-1]
|
||||||
|
current_list = target.get(leaf, [])
|
||||||
|
if change.value not in current_list:
|
||||||
|
# 新值插入列表头部,保证按时间从新到旧排序
|
||||||
|
current_list.insert(0, change.value)
|
||||||
|
target[leaf] = current_list
|
||||||
|
|
||||||
|
if first_alias or initial_meta:
|
||||||
new_info = EndUserInfo(
|
new_info = EndUserInfo(
|
||||||
end_user_id=end_user_uuid,
|
end_user_id=end_user_uuid,
|
||||||
other_name=first_alias or "",
|
other_name=first_alias or "",
|
||||||
aliases=initial_aliases,
|
aliases=initial_aliases,
|
||||||
meta_data=cleaned if cleaned else None,
|
meta_data=initial_meta if initial_meta else None,
|
||||||
)
|
)
|
||||||
db.add(new_info)
|
db.add(new_info)
|
||||||
if end_user and first_alias and (
|
if end_user and first_alias and (
|
||||||
|
|||||||
@@ -1,4 +1,40 @@
|
|||||||
{
|
{
|
||||||
|
"v0.3.0": {
|
||||||
|
"introduction": {
|
||||||
|
"codeName": "破晓",
|
||||||
|
"releaseDate": "2026-4-15",
|
||||||
|
"upgradePosition": "🐻 全面升级应用工作流、记忆智能与系统稳健性,引入版本化API、多模态记忆感知及大量工作流增强,打造更可靠、精准的 MemoryBear",
|
||||||
|
"coreUpgrades": [
|
||||||
|
"1. 应用与API增强<br>* 版本化API调用支持:对外服务API支持指定版本调用<br>* 工作流检查清单:新增结构化验证步骤<br>* 深度思考参数精准控制:仅向支持深度推理的模型发送思考参数<br>* 提示器模型返回优化:优化提示器模型响应处理",
|
||||||
|
"2. 记忆智能 🧠<br>* 多模态记忆感知Agent:支持多模态记忆读取与写入<br>* OpenClaw内置工具:新增内置工具扩展Agent工具集",
|
||||||
|
"3. 用户体验 🎨<br>* 流式渲染稳定性优化:解决LLM流式输出页面抖动问题<br>* 记忆中枢更名:「记忆相关」更名为「记忆中枢」",
|
||||||
|
"4. 工作流改进 ⚙️<br>* 三级变量模板转换:支持三级变量解析<br>* VL模型Token统计:修复模型组合中VL模型Token未统计问题<br>* 导入工作流功能特性同步:正确同步开场白、引用等属性<br>* 会话变量名称唯一性校验:防止变量名冲突<br>* 文件类型提取修复:正确提取file.type信息<br>* 条件分支显示修复:值为0或会话变量时正确渲染<br>* Object/Array校验规则:防止JSON序列化错误<br>* HTTP请求Body字段修正:body字段从name改为key",
|
||||||
|
"5. 知识库 📚<br>* Embedding Token截断安全边界:统一添加8000 token截断,优化Excel独立chunk处理",
|
||||||
|
"6. 稳健性与缺陷修复 🔧<br>* 原子性更新与批量访问失败修复<br>* 对话别名提取错误修复<br>* 工作流别名提取修正(区分用户和AI回复)<br>* RAG记忆分页数据修复<br>* 隐式记忆详情显示修复<br>* 向量查询驱动关闭异常修复<br>* 用户管理启停异常修复<br>* 模型列表筛选不一致修复",
|
||||||
|
"<br>",
|
||||||
|
"v0.3.0 标志着 MemoryBear 向生产成熟度迈出坚实一步。后续版本将持续深化工作流表达力、记忆检索精度和跨模态理解能力,强化复杂Agent编排支持,稳固大规模生产部署基础。",
|
||||||
|
"<br>",
|
||||||
|
"MemoryBear — 破晓 🐻✨"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"introduction_en": {
|
||||||
|
"codeName": "PoXiao",
|
||||||
|
"releaseDate": "2026-4-15",
|
||||||
|
"upgradePosition": "🐻 Comprehensive upgrades across application workflows, memory intelligence, and system robustness — introducing versioned APIs, multimodal memory perception, and extensive workflow enhancements for a more reliable MemoryBear",
|
||||||
|
"coreUpgrades": [
|
||||||
|
"1. Application & API Enhancements<br>* Versioned API Support: External APIs now support version-specific calls<br>* Workflow Checklist: Structured validation steps before deployment<br>* Deep Thinking Parameter Control: Only send thinking params to supported models<br>* Prompt Optimizer Return Optimization: Improved prompt optimizer response handling",
|
||||||
|
"2. Memory Intelligence 🧠<br>* Multimodal Memory Perception Agent: Read/write multimodal memory<br>* OpenClaw Built-in Tool: New built-in tool for agent operations",
|
||||||
|
"3. User Experience 🎨<br>* Streaming Render Stabilization: Eliminated page jitter during LLM output<br>* Memory Hub Renaming: Renamed to better reflect central memory role",
|
||||||
|
"4. Workflow Improvements ⚙️<br>* Three-Level Variable Template Conversion: Support for three-level variable resolution<br>* VL Model Token Tracking: Fixed token tracking for VL models in model groups<br>* Imported Workflow Feature Sync: Properly sync opening messages, citations, etc.<br>* Session Variable Name Uniqueness: Prevent variable name conflicts<br>* File Type Extraction Fix: Correctly extract file.type information<br>* Condition Branch Display Fix: Correct rendering for value 0 or session variables<br>* Object/Array Validation Rules: Prevent JSON serialization save errors<br>* HTTP Request Body Key Fix: Body field uses key instead of name",
|
||||||
|
"5. Knowledge Base 📚<br>* Embedding Token Truncation Safety: Unified 8000-token boundary, optimized Excel chunk processing",
|
||||||
|
"6. Robustness & Bug Fixes 🔧<br>* Atomic update & batch access failure fixes<br>* Conversation alias extraction fix<br>* Workflow alias extraction correction (user vs AI distinction)<br>* RAG memory pagination fix<br>* Implicit memory detail display fix<br>* Vector query driver closed exception fix<br>* User management enable/disable fix<br>* Model list filter inconsistency fix",
|
||||||
|
"<br>",
|
||||||
|
"v0.3.0 marks a meaningful step toward production maturity for MemoryBear. Upcoming releases will deepen workflow expressiveness, memory retrieval precision, and cross-modal understanding while strengthening complex agent orchestration and large-scale deployment foundations.",
|
||||||
|
"<br>",
|
||||||
|
"MemoryBear — Daybreak 🐻✨"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
"v0.2.10": {
|
"v0.2.10": {
|
||||||
"introduction": {
|
"introduction": {
|
||||||
"codeName": "炼剑",
|
"codeName": "炼剑",
|
||||||
|
|||||||
@@ -93,7 +93,8 @@
|
|||||||
"typescript-eslint": "^8.45.0",
|
"typescript-eslint": "^8.45.0",
|
||||||
"unplugin-auto-import": "^20.2.0",
|
"unplugin-auto-import": "^20.2.0",
|
||||||
"unplugin-vue-components": "^29.1.0",
|
"unplugin-vue-components": "^29.1.0",
|
||||||
"vite": "npm:rolldown-vite@7.1.14"
|
"vite": "npm:rolldown-vite@7.1.14",
|
||||||
|
"vite-plugin-svgr": "^5.2.0"
|
||||||
},
|
},
|
||||||
"overrides": {
|
"overrides": {
|
||||||
"vite": "npm:rolldown-vite@7.1.14"
|
"vite": "npm:rolldown-vite@7.1.14"
|
||||||
|
|||||||
@@ -175,3 +175,7 @@ export const getAppLogsUrl = (app_id: string) => `/apps/${app_id}/logs`
|
|||||||
export const getAppLogDetail = (app_id: string, conversation_id: string) => {
|
export const getAppLogDetail = (app_id: string, conversation_id: string) => {
|
||||||
return request.get(`/apps/${app_id}/logs/${conversation_id}`)
|
return request.get(`/apps/${app_id}/logs/${conversation_id}`)
|
||||||
}
|
}
|
||||||
|
// Reset agent model config to default
|
||||||
|
export const resetAppModelConfig = (app_id: string) => {
|
||||||
|
return request.get(`/apps/${app_id}/model/parameters/default`)
|
||||||
|
}
|
||||||
8
web/src/api/package.ts
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
import { request } from '@/utils/request'
|
||||||
|
|
||||||
|
import type { Package } from '@/views/Package/types'
|
||||||
|
// 套餐列表
|
||||||
|
export const getPackageListUrl = `/package-plans`
|
||||||
|
export const getPackageList = (query?: { category?: Package['category']; status?: boolean; }) => {
|
||||||
|
return request.get(getPackageListUrl, query)
|
||||||
|
}
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:00:23
|
* @Date: 2026-02-03 14:00:23
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-25 11:17:44
|
* @Last Modified time: 2026-04-14 18:36:01
|
||||||
*/
|
*/
|
||||||
import { request } from '@/utils/request'
|
import { request } from '@/utils/request'
|
||||||
import type { CreateModalData, ChangeEmailModalForm } from '@/views/UserManagement/types'
|
import type { CreateModalData, ChangeEmailModalForm } from '@/views/UserManagement/types'
|
||||||
@@ -57,3 +57,8 @@ export const sendEmailCode = (data: { email: string }) => {
|
|||||||
export const changeEmail = (data: ChangeEmailModalForm) => {
|
export const changeEmail = (data: ChangeEmailModalForm) => {
|
||||||
return request.put('/users/change-email', data)
|
return request.put('/users/change-email', data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取租户套餐信息
|
||||||
|
export const getTenantSubscription = () => {
|
||||||
|
return request.get('/tenant/subscription')
|
||||||
|
}
|
||||||
17
web/src/assets/images/application/export.svg
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>导出</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linecap="round">
|
||||||
|
<g id="记忆库-个人记忆-感知记忆-文本" transform="translate(-573, -158)" stroke="#171719">
|
||||||
|
<g id="导出" transform="translate(573, 158)">
|
||||||
|
<g id="编组-54" transform="translate(3, 3)">
|
||||||
|
<path d="M10,6 L10,7.5 C10,8.88071187 8.88071187,10 7.5,10 L2.5,10 C1.11928813,10 0,8.88071187 0,7.5 L0,6 L0,6" id="路径"></path>
|
||||||
|
<g id="编组-11" transform="translate(2, 0)">
|
||||||
|
<line x1="3" y1="0.08499952" x2="3" y2="6.99635859" id="路径-24"></line>
|
||||||
|
<polyline id="路径-25" stroke-linejoin="round" points="0 3 2.98005548 6.08298138e-18 6 3"></polyline>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.1 KiB |
17
web/src/assets/images/application/import.svg
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>导入</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linecap="round">
|
||||||
|
<g id="记忆库-个人记忆-感知记忆-文本" transform="translate(-555, -158)" stroke="#171719">
|
||||||
|
<g id="导入" transform="translate(555, 158)">
|
||||||
|
<g id="编组-54" transform="translate(3, 3)">
|
||||||
|
<path d="M10,6 L10,7.5 C10,8.88071187 8.88071187,10 7.5,10 L2.5,10 C1.11928813,10 0,8.88071187 0,7.5 L0,6 L0,6" id="路径"></path>
|
||||||
|
<g id="编组-11" transform="translate(5, 3.4982) scale(1, -1) translate(-5, -3.4982)translate(2, 0)">
|
||||||
|
<line x1="3" y1="0.08499952" x2="3" y2="6.99635859" id="路径-24"></line>
|
||||||
|
<polyline id="路径-25" stroke-linejoin="round" points="0 3 2.98005548 6.08298138e-18 6 3"></polyline>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.1 KiB |
15
web/src/assets/images/common/close_grey.svg
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>关闭</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="应用管理-My-Shares" transform="translate(-1396, -127)" fill="#5B6167" fill-rule="nonzero">
|
||||||
|
<g id="卡片1备份-2" transform="translate(1044, 108)">
|
||||||
|
<g id="编组-12" transform="translate(349, 16)">
|
||||||
|
<g id="关闭" transform="translate(3, 3)">
|
||||||
|
<polygon id="路径" points="9.00000098 8 13.3333333 12.3333324 12.3333324 13.3333333 8 9.00000098 3.66666764 13.3333333 2.66666667 12.3333324 6.99999902 8 2.66666667 3.66666764 3.66666764 2.66666667 8 6.99999902 12.3333324 2.66666667 13.3333333 3.66666764"></polygon>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1005 B |
16
web/src/assets/images/index/arrow_right_dark.svg
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编组 5</title>
|
||||||
|
<g id="V1.1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="首页" transform="translate(-1229, -446)" stroke="#212332">
|
||||||
|
<g id="编组-13" transform="translate(1120, 300)">
|
||||||
|
<g id="编组-6" transform="translate(16, 138)">
|
||||||
|
<g id="编组-5" transform="translate(93, 8)">
|
||||||
|
<polyline id="路径" points="10 6 12 8 10 10"></polyline>
|
||||||
|
<line x1="12" y1="8" x2="2" y2="8" id="路径-2"></line>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 820 B |
@@ -1,17 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
|
||||||
<title>退出</title>
|
|
||||||
<g id="V1.0版" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round">
|
|
||||||
<g id="应用管理-编排-默认状态" transform="translate(-1262, -24)" stroke="#5B6167">
|
|
||||||
<g id="返回空间" transform="translate(1262, 24)">
|
|
||||||
<g id="退出" transform="translate(8, 8) scale(-1, 1) translate(-8, -8)">
|
|
||||||
<g id="编组-7" transform="translate(2.5, 2)">
|
|
||||||
<path d="M6,12 L1,12 C0.44771525,12 0,11.5522847 0,11 L0,1 C0,0.44771525 0.44771525,1.11022302e-16 1,0 L6,0 L6,0" id="路径"></path>
|
|
||||||
<line x1="11" y1="6" x2="3" y2="6" id="路径-6"></line>
|
|
||||||
<polyline id="路径" points="8 3 11 6 8 9"></polyline>
|
|
||||||
</g>
|
|
||||||
</g>
|
|
||||||
</g>
|
|
||||||
</g>
|
|
||||||
</g>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 1.1 KiB |
19
web/src/assets/images/logout_grey.svg
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>退出</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<g id="空间配置" transform="translate(-22, -763)" stroke="#5B6167" stroke-width="1.2">
|
||||||
|
<g id="退出" transform="translate(0, 742)">
|
||||||
|
<g id="返回空间" transform="translate(12, 10)">
|
||||||
|
<g id="退出" transform="translate(10, 11)">
|
||||||
|
<g id="编组-7" transform="translate(2.5, 2)">
|
||||||
|
<path d="M6,12 L1,12 C0.44771525,12 0,11.5522847 0,11 L0,1 C0,0.44771525 0.44771525,1.11022302e-16 1,0 L6,0 L6,0" id="路径"></path>
|
||||||
|
<line x1="11" y1="6" x2="3" y2="6" id="路径-6"></line>
|
||||||
|
<polyline id="路径" points="8 3 11 6 8 9"></polyline>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.1 KiB |
@@ -1,17 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
|
||||||
<title>退出</title>
|
|
||||||
<g id="V1.0版" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round">
|
|
||||||
<g id="应用管理-编排-默认状态" transform="translate(-1262, -24)" stroke="#155EEF">
|
|
||||||
<g id="返回空间" transform="translate(1262, 24)">
|
|
||||||
<g id="退出" transform="translate(8, 8) scale(-1, 1) translate(-8, -8)">
|
|
||||||
<g id="编组-7" transform="translate(2.5, 2)">
|
|
||||||
<path d="M6,12 L1,12 C0.44771525,12 0,11.5522847 0,11 L0,1 C0,0.44771525 0.44771525,1.11022302e-16 1,0 L6,0 L6,0" id="路径"></path>
|
|
||||||
<line x1="11" y1="6" x2="3" y2="6" id="路径-6"></line>
|
|
||||||
<polyline id="路径" points="8 3 11 6 8 9"></polyline>
|
|
||||||
</g>
|
|
||||||
</g>
|
|
||||||
</g>
|
|
||||||
</g>
|
|
||||||
</g>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 1.1 KiB |