Merge branch 'develop' into refactor/memory_search
# Conflicts: # api/app/core/memory/storage_services/search/__init__.py
This commit is contained in:
@@ -12,7 +12,7 @@ import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.errors import GraphRecursionError
|
||||
|
||||
@@ -41,6 +41,7 @@ class LangChainAgent:
|
||||
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
|
||||
deep_thinking: bool = False, # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
|
||||
json_output: bool = False, # 是否强制 JSON 输出
|
||||
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
@@ -64,7 +65,6 @@ class LangChainAgent:
|
||||
self.streaming = streaming
|
||||
self.is_omni = is_omni
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
self.deep_thinking = deep_thinking and ("thinking" in (capability or []))
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
@@ -80,6 +80,17 @@ class LangChainAgent:
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format
|
||||
# 在 system prompt 中注入 JSON 要求
|
||||
from app.models.models_model import ModelProvider
|
||||
if json_output and (
|
||||
(provider.lower() == ModelProvider.DASHSCOPE and not is_omni)
|
||||
or provider.lower() == ModelProvider.VOLCANO
|
||||
# 有工具时 response_format 会被移除,所有 provider 都需要 system prompt 注入保证 JSON 输出
|
||||
or bool(tools)
|
||||
):
|
||||
self.system_prompt += "\n请以JSON格式输出。"
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
@@ -87,23 +98,17 @@ class LangChainAgent:
|
||||
f"auto_calculated={max_iterations is None}"
|
||||
)
|
||||
|
||||
# 根据 capability 校验是否真正支持深度思考
|
||||
actual_deep_thinking = self.deep_thinking
|
||||
if deep_thinking and not actual_deep_thinking:
|
||||
logger.warning(
|
||||
f"模型 {model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking"
|
||||
)
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
# 创建 RedBearLLM,capability 校验由 RedBearModelConfig 统一处理
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
deep_thinking=actual_deep_thinking,
|
||||
thinking_budget_tokens=thinking_budget_tokens if actual_deep_thinking else None,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
capability=capability,
|
||||
deep_thinking=deep_thinking,
|
||||
thinking_budget_tokens=thinking_budget_tokens,
|
||||
json_output=json_output,
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -112,6 +117,9 @@ class LangChainAgent:
|
||||
)
|
||||
|
||||
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
||||
# 从经过校验的 config 读取实际生效的能力开关
|
||||
self.deep_thinking = model_config.deep_thinking
|
||||
self.json_output = model_config.json_output
|
||||
|
||||
# 获取底层模型用于真正的流式调用
|
||||
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
|
||||
@@ -237,9 +245,7 @@ class LangChainAgent:
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
"""
|
||||
messages:list = [SystemMessage(content=self.system_prompt)]
|
||||
|
||||
# 添加系统提示词
|
||||
messages: list = []
|
||||
|
||||
# 添加历史消息
|
||||
if history:
|
||||
|
||||
@@ -96,6 +96,38 @@ def require_api_key(
|
||||
resource_id=api_key_obj.resource_id,
|
||||
)
|
||||
|
||||
# ── Tenant 级别限速(来自套餐配额 api_ops_rate_limit)──────────
|
||||
try:
|
||||
from app.models.workspace_model import Workspace
|
||||
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||
|
||||
workspace = db.query(Workspace).filter(
|
||||
Workspace.id == api_key_obj.workspace_id
|
||||
).first()
|
||||
if workspace:
|
||||
quota = TenantSubscriptionService(db).get_effective_quota(workspace.tenant_id)
|
||||
tenant_qps_limit = quota.get("api_ops_rate_limit") if quota else None
|
||||
if tenant_qps_limit:
|
||||
rate_limiter = RateLimiterService()
|
||||
tenant_ok, tenant_info = await rate_limiter.check_tenant_rate_limit(
|
||||
workspace.tenant_id, tenant_qps_limit
|
||||
)
|
||||
if not tenant_ok:
|
||||
raise RateLimitException(
|
||||
"租户 API 调用速率超限",
|
||||
BizCode.API_KEY_QPS_LIMIT_EXCEEDED,
|
||||
rate_headers={
|
||||
"X-RateLimit-Tenant-Limit": str(tenant_info["limit"]),
|
||||
"X-RateLimit-Tenant-Remaining": str(tenant_info["remaining"]),
|
||||
"X-RateLimit-Tenant-Reset": str(tenant_info["reset"]),
|
||||
}
|
||||
)
|
||||
except RateLimitException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Tenant 限速检查异常,跳过: {e}")
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
|
||||
rate_limiter = RateLimiterService()
|
||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
||||
if not is_allowed:
|
||||
|
||||
@@ -14,6 +14,7 @@ from dotenv import load_dotenv
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||
memory_summary_generation
|
||||
@@ -191,15 +192,37 @@ async def write(
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
|
||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||
if all_entity_nodes:
|
||||
end_user_id = all_entity_nodes[0].end_user_id
|
||||
|
||||
# Neo4j 写入完成后,用 PgSQL 权威 aliases 覆盖 Neo4j 用户实体
|
||||
try:
|
||||
from app.repositories.end_user_info_repository import EndUserInfoRepository
|
||||
if end_user_id:
|
||||
with get_db_context() as db_session:
|
||||
info = EndUserInfoRepository(db_session).get_by_end_user_id(uuid.UUID(end_user_id))
|
||||
pg_aliases = info.aliases if info and info.aliases else []
|
||||
if info is not None:
|
||||
# 将 Python 侧占位名集合作为参数传入,避免 Cypher 硬编码
|
||||
placeholder_names = list(_USER_PLACEHOLDER_NAMES)
|
||||
await neo4j_connector.execute_query(
|
||||
"""
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $placeholder_names
|
||||
SET e.aliases = $aliases
|
||||
""",
|
||||
end_user_id=end_user_id, aliases=pg_aliases,
|
||||
placeholder_names=placeholder_names,
|
||||
)
|
||||
logger.info(f"[AliasSync] Neo4j 用户实体 aliases 已用 PgSQL 权威源覆盖: {pg_aliases}")
|
||||
except Exception as sync_err:
|
||||
logger.warning(f"[AliasSync] PgSQL→Neo4j aliases 同步失败(不影响主流程): {sync_err}")
|
||||
|
||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||
try:
|
||||
from app.tasks import run_incremental_clustering
|
||||
|
||||
end_user_id = all_entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in all_entity_nodes]
|
||||
|
||||
# 异步提交 Celery 任务
|
||||
task = run_incremental_clustering.apply_async(
|
||||
kwargs={
|
||||
"end_user_id": end_user_id,
|
||||
@@ -207,7 +230,6 @@ async def write(
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
},
|
||||
# 设置任务优先级(低优先级,不影响主业务)
|
||||
priority=3,
|
||||
)
|
||||
logger.info(
|
||||
@@ -215,7 +237,6 @@ async def write(
|
||||
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
||||
)
|
||||
except Exception as e:
|
||||
# 聚类任务提交失败不影响主流程
|
||||
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
||||
|
||||
break
|
||||
|
||||
@@ -61,9 +61,9 @@ from app.core.memory.models.triplet_models import (
|
||||
# User metadata models
|
||||
from app.core.memory.models.metadata_models import (
|
||||
UserMetadata,
|
||||
UserMetadataBehavioralHints,
|
||||
UserMetadataProfile,
|
||||
MetadataExtractionResponse,
|
||||
MetadataFieldChange,
|
||||
)
|
||||
|
||||
# Ontology scenario models (LLM extracted from scenarios)
|
||||
@@ -133,9 +133,9 @@ __all__ = [
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
"UserMetadata",
|
||||
"UserMetadataBehavioralHints",
|
||||
"UserMetadataProfile",
|
||||
"MetadataExtractionResponse",
|
||||
"MetadataFieldChange",
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
|
||||
@@ -4,7 +4,7 @@ Independent from triplet_models.py - these models are used by the
|
||||
standalone metadata extraction pipeline (post-dedup async Celery task).
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@@ -13,8 +13,8 @@ class UserMetadataProfile(BaseModel):
|
||||
"""用户画像信息"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
role: str = Field(default="", description="用户职业或角色")
|
||||
domain: str = Field(default="", description="用户所在领域")
|
||||
role: List[str] = Field(default_factory=list, description="用户职业或角色")
|
||||
domain: List[str] = Field(default_factory=list, description="用户所在领域")
|
||||
expertise: List[str] = Field(
|
||||
default_factory=list, description="用户擅长的技能或工具"
|
||||
)
|
||||
@@ -23,31 +23,37 @@ class UserMetadataProfile(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class UserMetadataBehavioralHints(BaseModel):
|
||||
"""行为偏好"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
learning_stage: str = Field(default="", description="学习阶段")
|
||||
preferred_depth: str = Field(default="", description="偏好深度")
|
||||
tone_preference: str = Field(default="", description="语气偏好")
|
||||
|
||||
|
||||
class UserMetadata(BaseModel):
|
||||
"""用户元数据顶层结构"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile)
|
||||
behavioral_hints: UserMetadataBehavioralHints = Field(
|
||||
default_factory=UserMetadataBehavioralHints
|
||||
|
||||
|
||||
class MetadataFieldChange(BaseModel):
|
||||
"""单个元数据字段的变更操作"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
field_path: str = Field(
|
||||
description="字段路径,用点号分隔,如 'profile.role'、'profile.expertise'"
|
||||
)
|
||||
action: Literal["set", "remove"] = Field(
|
||||
description="操作类型:'set' 表示新增或修改,'remove' 表示移除"
|
||||
)
|
||||
value: Optional[str] = Field(
|
||||
default=None,
|
||||
description="字段的新值(action='set' 时必填)。标量字段直接填值,列表字段填单个要新增的元素"
|
||||
)
|
||||
knowledge_tags: List[str] = Field(default_factory=list, description="知识标签")
|
||||
|
||||
|
||||
class MetadataExtractionResponse(BaseModel):
|
||||
"""元数据提取 LLM 响应结构"""
|
||||
"""元数据提取 LLM 响应结构(增量模式)"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
user_metadata: UserMetadata = Field(default_factory=UserMetadata)
|
||||
metadata_changes: List[MetadataFieldChange] = Field(
|
||||
default_factory=list,
|
||||
description="元数据的增量变更列表,每项描述一个字段的新增、修改或移除操作",
|
||||
)
|
||||
aliases_to_add: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="本次新发现的用户别名(用户自我介绍或他人对用户的称呼)",
|
||||
|
||||
@@ -82,51 +82,38 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
canonical.connect_strength = next(iter(pair))
|
||||
|
||||
# 别名合并(去重保序,使用标准化工具)
|
||||
# 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,去重合并时不修改
|
||||
try:
|
||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||
incoming_name = (getattr(ent, "name", "") or "").strip()
|
||||
|
||||
# 收集所有需要合并的别名
|
||||
all_aliases = []
|
||||
|
||||
# 1. 添加canonical现有的别名
|
||||
existing = getattr(canonical, "aliases", []) or []
|
||||
all_aliases.extend(existing)
|
||||
|
||||
# 2. 添加incoming实体的名称(如果不同于canonical的名称)
|
||||
if incoming_name and incoming_name != canonical_name:
|
||||
all_aliases.append(incoming_name)
|
||||
|
||||
# 3. 添加incoming实体的所有别名
|
||||
incoming = getattr(ent, "aliases", []) or []
|
||||
all_aliases.extend(incoming)
|
||||
|
||||
# 4. 标准化并去重(优先使用alias_utils工具函数)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
# 如果导入失败,使用增强的去重逻辑
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
if canonical_name.lower() not in _USER_PLACEHOLDER_NAMES:
|
||||
incoming_name = (getattr(ent, "name", "") or "").strip()
|
||||
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
|
||||
# 标准化:转小写用于去重判断
|
||||
alias_normalized = alias_stripped.lower()
|
||||
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
# 收集所有需要合并的别名,过滤掉用户占位名避免污染非用户实体
|
||||
all_aliases = list(getattr(canonical, "aliases", []) or [])
|
||||
if incoming_name and incoming_name != canonical_name and incoming_name.lower() not in _USER_PLACEHOLDER_NAMES:
|
||||
all_aliases.append(incoming_name)
|
||||
all_aliases.extend(
|
||||
a for a in (getattr(ent, "aliases", []) or [])
|
||||
if a and a.strip().lower() not in _USER_PLACEHOLDER_NAMES
|
||||
)
|
||||
|
||||
# 排序并赋值
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
alias_normalized = alias_stripped.lower()
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -733,66 +720,37 @@ def fuzzy_match(
|
||||
|
||||
|
||||
def _merge_entities_with_aliases(canonical: ExtractedEntityNode, losing: ExtractedEntityNode):
|
||||
""" 模糊匹配中的实体合并。
|
||||
"""模糊匹配中的实体合并(别名部分)。
|
||||
|
||||
合并策略:
|
||||
1. 保留canonical的主名称不变
|
||||
2. 将losing的主名称添加为alias(如果不同)
|
||||
3. 合并两个实体的所有aliases
|
||||
4. 自动去重(case-insensitive)并排序
|
||||
|
||||
Args:
|
||||
canonical: 规范实体(保留)
|
||||
losing: 被合并实体(删除)
|
||||
|
||||
Note:
|
||||
使用alias_utils.normalize_aliases进行标准化去重
|
||||
用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,跳过合并。
|
||||
"""
|
||||
# 获取规范实体的名称
|
||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||
if canonical_name.lower() in _USER_PLACEHOLDER_NAMES:
|
||||
return
|
||||
|
||||
losing_name = (getattr(losing, "name", "") or "").strip()
|
||||
|
||||
# 收集所有需要合并的别名
|
||||
all_aliases = []
|
||||
|
||||
# 1. 添加canonical现有的别名
|
||||
current_aliases = getattr(canonical, "aliases", []) or []
|
||||
all_aliases.extend(current_aliases)
|
||||
|
||||
# 2. 添加losing实体的名称(如果不同于canonical的名称)
|
||||
all_aliases = list(getattr(canonical, "aliases", []) or [])
|
||||
if losing_name and losing_name != canonical_name:
|
||||
all_aliases.append(losing_name)
|
||||
all_aliases.extend(getattr(losing, "aliases", []) or [])
|
||||
|
||||
# 3. 添加losing实体的所有别名
|
||||
losing_aliases = getattr(losing, "aliases", []) or []
|
||||
all_aliases.extend(losing_aliases)
|
||||
|
||||
# 4. 标准化并去重(使用标准化后的字符串进行去重)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
# 如果导入失败,使用增强的去重逻辑
|
||||
# 使用标准化后的字符串作为key进行去重
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
|
||||
# 标准化:转小写用于去重判断
|
||||
alias_normalized = alias_stripped.lower()
|
||||
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
|
||||
# 排序并赋值
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
|
||||
# ========== 主循环:遍历所有实体对进行模糊匹配 ==========
|
||||
|
||||
@@ -1391,18 +1391,18 @@ class ExtractionOrchestrator:
|
||||
"""
|
||||
将本轮提取的用户别名同步到 end_user 和 end_user_info 表。
|
||||
|
||||
注意:此方法在 Neo4j 写入之前调用,因此不能依赖 Neo4j 作为别名的权威数据源。
|
||||
改为直接使用内存中去重后的 entity_nodes 的 aliases,与 PgSQL 已有的 aliases 合并。
|
||||
PgSQL end_user_info.aliases 是用户别名的唯一权威源。
|
||||
此方法仅将本轮 LLM 从对话中新提取的别名增量追加到 PgSQL,
|
||||
不再从 Neo4j 二层去重合并历史别名,避免脏数据反向污染 PgSQL。
|
||||
|
||||
策略:
|
||||
1. 从内存中的 entity_nodes 提取本轮用户别名(current_aliases)
|
||||
2. 从去重后的 entity_nodes 中提取完整别名(含 Neo4j 二层去重合并的历史别名)
|
||||
3. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases)
|
||||
4. 合并 db_aliases + deduped_aliases + current_aliases,去重保序
|
||||
5. 写回 PgSQL
|
||||
1. 从本轮对话原始发言中提取用户别名(current_aliases)
|
||||
2. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases)
|
||||
3. 合并 db_aliases + current_aliases,去重保序
|
||||
4. 写回 PgSQL
|
||||
|
||||
Args:
|
||||
entity_nodes: 去重后的实体节点列表(内存中,含二层去重合并结果)
|
||||
entity_nodes: 去重后的实体节点列表(内存中)
|
||||
dialog_data_list: 对话数据列表
|
||||
"""
|
||||
try:
|
||||
@@ -1418,11 +1418,6 @@ class ExtractionOrchestrator:
|
||||
# 1. 提取本轮对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
||||
current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list)
|
||||
|
||||
# 1.5 从去重后的 entity_nodes 中提取完整别名
|
||||
# 二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 中,
|
||||
# 这里提取出来确保 PgSQL 与 Neo4j 的别名保持同步
|
||||
deduped_aliases = self._extract_deduped_entity_aliases(entity_nodes)
|
||||
|
||||
# 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源
|
||||
# (防止 LLM 未提取出 AI 助手实体时,AI 别名泄漏到用户别名中)
|
||||
neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id)
|
||||
@@ -1434,19 +1429,12 @@ class ExtractionOrchestrator:
|
||||
]
|
||||
if len(current_aliases) < before_count:
|
||||
logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名")
|
||||
# 同样过滤 deduped_aliases
|
||||
deduped_aliases = [
|
||||
a for a in deduped_aliases
|
||||
if a.strip().lower() not in neo4j_assistant_aliases
|
||||
]
|
||||
|
||||
if not current_aliases and not deduped_aliases:
|
||||
if not current_aliases:
|
||||
logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}")
|
||||
return
|
||||
|
||||
logger.info(f"本轮对话提取的 aliases: {current_aliases}")
|
||||
if deduped_aliases:
|
||||
logger.info(f"去重后实体的完整 aliases(含历史): {deduped_aliases}")
|
||||
|
||||
# 2. 同步到数据库
|
||||
end_user_uuid = uuid.UUID(end_user_id)
|
||||
@@ -1457,21 +1445,15 @@ class ExtractionOrchestrator:
|
||||
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
|
||||
return
|
||||
|
||||
# 3. 从 PgSQL 读取已有 aliases 并与本轮合并
|
||||
# 3. 从 PgSQL 读取已有 aliases 并与本轮新增合并
|
||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
db_aliases = (info.aliases if info and info.aliases else [])
|
||||
# 过滤掉占位名称
|
||||
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
|
||||
|
||||
# 合并:已有 + 去重后完整别名 + 本轮新增,去重保序
|
||||
# 合并:PgSQL 已有 + 本轮新增,去重保序(不再合并 Neo4j 历史别名)
|
||||
merged_aliases = list(db_aliases)
|
||||
seen_lower = {a.strip().lower() for a in merged_aliases}
|
||||
# 先合并去重后实体的完整别名(含 Neo4j 历史别名)
|
||||
for alias in deduped_aliases:
|
||||
if alias.strip().lower() not in seen_lower:
|
||||
merged_aliases.append(alias)
|
||||
seen_lower.add(alias.strip().lower())
|
||||
# 再合并本轮新提取的别名
|
||||
for alias in current_aliases:
|
||||
if alias.strip().lower() not in seen_lower:
|
||||
merged_aliases.append(alias)
|
||||
@@ -1505,9 +1487,7 @@ class ExtractionOrchestrator:
|
||||
info.aliases = merged_aliases
|
||||
logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}")
|
||||
else:
|
||||
first_alias = current_aliases[0].strip() if current_aliases else (
|
||||
deduped_aliases[0].strip() if deduped_aliases else ""
|
||||
)
|
||||
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||
# 确保 first_alias 不是占位名称
|
||||
if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
|
||||
db.add(EndUserInfo(
|
||||
|
||||
@@ -118,7 +118,7 @@ class MetadataExtractor:
|
||||
existing_aliases: Optional[List[str]] = None,
|
||||
) -> Optional[tuple]:
|
||||
"""
|
||||
对筛选后的 statement 列表调用 LLM 提取元数据和用户别名。
|
||||
对筛选后的 statement 列表调用 LLM 提取元数据增量变更和用户别名。
|
||||
|
||||
Args:
|
||||
statements: 用户发言的 statement 文本列表
|
||||
@@ -126,7 +126,8 @@ class MetadataExtractor:
|
||||
existing_aliases: 数据库已有的用户别名列表(可选)
|
||||
|
||||
Returns:
|
||||
(UserMetadata, List[str], List[str]) tuple: (metadata, aliases_to_add, aliases_to_remove) on success, None on failure
|
||||
(List[MetadataFieldChange], List[str], List[str]) tuple:
|
||||
(metadata_changes, aliases_to_add, aliases_to_remove) on success, None on failure
|
||||
"""
|
||||
if not statements:
|
||||
return None
|
||||
@@ -160,12 +161,12 @@ class MetadataExtractor:
|
||||
)
|
||||
|
||||
if response:
|
||||
metadata = response.user_metadata if response.user_metadata else None
|
||||
changes = response.metadata_changes if response.metadata_changes else []
|
||||
to_add = response.aliases_to_add if response.aliases_to_add else []
|
||||
to_remove = (
|
||||
response.aliases_to_remove if response.aliases_to_remove else []
|
||||
)
|
||||
return metadata, to_add, to_remove
|
||||
return changes, to_add, to_remove
|
||||
|
||||
logger.warning("LLM 返回的响应为空")
|
||||
return None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
===Task===
|
||||
Extract user metadata from the following conversation statements spoken by the user.
|
||||
Extract user metadata changes from the following conversation statements spoken by the user.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**"三度原则"判断标准:**
|
||||
@@ -10,28 +10,36 @@ Extract user metadata from the following conversation statements spoken by the u
|
||||
**提取规则:**
|
||||
- **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息
|
||||
- 仅提取文本中明确提到的信息,不要推测
|
||||
- 如果文本中没有可提取的用户画像信息,返回空的 user_metadata 对象
|
||||
- **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值)
|
||||
|
||||
**增量模式(重要):**
|
||||
你只需要输出**本次对话引起的变更操作**,不要输出完整的元数据。每个变更是一个对象,包含:
|
||||
- `field_path`:字段路径,用点号分隔(如 `profile.role`、`profile.expertise`)
|
||||
- `action`:操作类型
|
||||
* `set`:新增或修改一个字段的值
|
||||
* `remove`:移除一个字段的值
|
||||
- `value`:字段的新值(`action="set"` 时必填,`action="remove"` 时填要移除的元素值)
|
||||
* 所有字段均为列表类型,每个元素一条变更记录
|
||||
|
||||
**判断规则:**
|
||||
- 用户提到新信息 → `action="set"`,填入新值
|
||||
- 用户明确否定已有信息(如"我不再做老师了"、"我已经不学Python了")→ `action="remove"`,`value` 填要移除的元素值
|
||||
- 如果本次对话没有任何可提取的变更,返回空的 `metadata_changes` 数组 `[]`
|
||||
- **不要为未被提及的字段生成任何变更操作**
|
||||
|
||||
{% if existing_metadata %}
|
||||
**重要:合并已有元数据**
|
||||
下方提供了数据库中已有的用户元数据。请结合用户最新发言,输出**合并后的完整元数据**:
|
||||
- 如果用户明确否定了已有信息(如"我不再教高中物理了"),在输出中**移除**该信息
|
||||
- 如果用户提到了新信息,**添加**到对应字段中
|
||||
- 如果已有信息未被用户否定,**保留**在输出中
|
||||
- 标量字段(如 role、domain):如果用户提到了新值,用新值替换;否则保留已有值
|
||||
- 最终输出应该是完整的、合并后的元数据,不是增量
|
||||
**已有元数据(仅供参考,用于判断是否需要变更):**
|
||||
请对比已有数据和用户最新发言,只输出差异部分的变更操作。
|
||||
- 如果用户说的信息和已有数据一致,不需要输出变更
|
||||
- 如果用户否定了已有数据中的某个值,输出 `remove` 操作
|
||||
- 如果用户提到了新信息,输出 `set` 操作
|
||||
{% endif %}
|
||||
|
||||
**字段说明:**
|
||||
- profile.role:用户的职业或角色,如 教师、医生、后端工程师
|
||||
- profile.domain:用户所在领域,如 教育、医疗、软件开发
|
||||
- profile.expertise:用户擅长的技能或工具(通用,不限于编程),如 Python、心理咨询、高中物理
|
||||
- profile.interests:用户主动表达兴趣的话题或领域标签
|
||||
- behavioral_hints.learning_stage:学习阶段(初学者/中级/高级)
|
||||
- behavioral_hints.preferred_depth:偏好深度(概览/技术细节/深入探讨)
|
||||
- behavioral_hints.tone_preference:语气偏好(轻松随意/专业简洁/学术严谨)
|
||||
- knowledge_tags:用户涉及的知识领域标签
|
||||
- profile.role:用户的职业或角色(列表),如 教师、医生、后端工程师,一个人可以有多个角色
|
||||
- profile.domain:用户所在领域(列表),如 教育、医疗、软件开发,一个人可以涉及多个领域
|
||||
- profile.expertise:用户擅长的技能或工具(列表),如 Python、心理咨询、高中物理
|
||||
- profile.interests:用户主动表达兴趣的话题或领域标签(列表)
|
||||
|
||||
**用户别名变更(增量模式):**
|
||||
- **aliases_to_add**:本次新发现的用户别名,包括:
|
||||
@@ -43,7 +51,6 @@ Extract user metadata from the following conversation statements spoken by the u
|
||||
- **aliases_to_remove**:用户明确否认的别名,包括:
|
||||
* 用户说"我不叫XX了"、"别叫我XX"、"我改名了,不叫XX" → 将 XX 放入此数组
|
||||
* **严格限制**:只将用户原文中**逐字提到**的被否认名字放入,不要推断关联的其他别名
|
||||
* 例如:用户说"我不叫陈小刀了" → 只移除"陈小刀",不要移除"陈哥"、"老陈"等未被提及的别名
|
||||
* 如果没有要移除的别名,返回空数组 `[]`
|
||||
{% if existing_aliases %}
|
||||
- 已有别名:{{ existing_aliases | tojson }}(仅供参考,不需要在输出中重复)
|
||||
@@ -57,28 +64,36 @@ Extract user metadata from the following conversation statements spoken by the u
|
||||
**Extraction rules:**
|
||||
- **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user
|
||||
- Only extract information explicitly mentioned in the text, do not speculate
|
||||
- If no user profile information can be extracted, return an empty user_metadata object
|
||||
- **Output language must match the input text language**
|
||||
|
||||
**Incremental mode (important):**
|
||||
You should only output **the change operations caused by this conversation**, not the complete metadata. Each change is an object containing:
|
||||
- `field_path`: Field path separated by dots (e.g. `profile.role`, `profile.expertise`)
|
||||
- `action`: Operation type
|
||||
* `set`: Add or update a field value
|
||||
* `remove`: Remove a field value
|
||||
- `value`: The new value for the field (required when `action="set"`, for `action="remove"` fill in the element value to remove)
|
||||
* All fields are list types, one change record per element
|
||||
|
||||
**Decision rules:**
|
||||
- User mentions new information → `action="set"`, fill in the new value
|
||||
- User explicitly negates existing info (e.g. "I'm no longer a teacher", "I stopped learning Python") → `action="remove"`, `value` is the element to remove
|
||||
- If this conversation has no extractable changes, return an empty `metadata_changes` array `[]`
|
||||
- **Do NOT generate any change operations for fields not mentioned in the conversation**
|
||||
|
||||
{% if existing_metadata %}
|
||||
**Important: Merge with existing metadata**
|
||||
Existing user metadata from the database is provided below. Combine with the user's latest statements to output the **complete merged metadata**:
|
||||
- If the user explicitly negates existing info (e.g. "I no longer teach high school physics"), **remove** it from output
|
||||
- If the user mentions new info, **add** it to the corresponding field
|
||||
- If existing info is not negated by the user, **keep** it in the output
|
||||
- Scalar fields (e.g. role, domain): replace with new value if user mentions one; otherwise keep existing
|
||||
- The final output should be the complete, merged metadata — not an incremental update
|
||||
**Existing metadata (for reference only, to determine if changes are needed):**
|
||||
Compare existing data with the user's latest statements, and only output change operations for the differences.
|
||||
- If the user's statement matches existing data, no change is needed
|
||||
- If the user negates a value in existing data, output a `remove` operation
|
||||
- If the user mentions new information, output a `set` operation
|
||||
{% endif %}
|
||||
|
||||
**Field descriptions:**
|
||||
- profile.role: User's occupation or role, e.g. teacher, doctor, software engineer
|
||||
- profile.domain: User's domain, e.g. education, healthcare, software development
|
||||
- profile.expertise: User's skills or tools (general, not limited to programming)
|
||||
- profile.interests: Topics or domain tags the user actively expressed interest in
|
||||
- behavioral_hints.learning_stage: Learning stage (beginner/intermediate/advanced)
|
||||
- behavioral_hints.preferred_depth: Preferred depth (overview/detailed/deep dive)
|
||||
- behavioral_hints.tone_preference: Tone preference (casual/professional/academic)
|
||||
- knowledge_tags: Knowledge domain tags related to the user
|
||||
- profile.role: User's occupation or role (list), e.g. teacher, doctor, software engineer. A person can have multiple roles
|
||||
- profile.domain: User's domain (list), e.g. education, healthcare, software development. A person can span multiple domains
|
||||
- profile.expertise: User's skills or tools (list), e.g. Python, counseling, physics
|
||||
- profile.interests: Topics or domain tags the user actively expressed interest in (list)
|
||||
|
||||
**User alias changes (incremental mode):**
|
||||
- **aliases_to_add**: Newly discovered user aliases from this conversation, including:
|
||||
@@ -90,7 +105,6 @@ Existing user metadata from the database is provided below. Combine with the use
|
||||
- **aliases_to_remove**: Aliases the user explicitly denies, including:
|
||||
* User says "Don't call me XX anymore", "I'm not called XX", "I changed my name from XX" → put XX in this array
|
||||
* **Strict rule**: Only include the exact name the user **verbatim mentions** as denied. Do NOT infer or remove related aliases
|
||||
* Example: User says "I'm not called John anymore" → only remove "John", do NOT remove "Johnny", "J" or other related aliases not mentioned
|
||||
* If no aliases to remove, return empty array `[]`
|
||||
{% if existing_aliases %}
|
||||
- Existing aliases: {{ existing_aliases | tojson }} (for reference only, do not repeat in output)
|
||||
@@ -113,20 +127,11 @@ Existing user metadata from the database is provided below. Combine with the use
|
||||
Return a JSON object with the following structure:
|
||||
```json
|
||||
{
|
||||
"user_metadata": {
|
||||
"profile": {
|
||||
"role": "",
|
||||
"domain": "",
|
||||
"expertise": [],
|
||||
"interests": []
|
||||
},
|
||||
"behavioral_hints": {
|
||||
"learning_stage": "",
|
||||
"preferred_depth": "",
|
||||
"tone_preference": ""
|
||||
},
|
||||
"knowledge_tags": []
|
||||
},
|
||||
"metadata_changes": [
|
||||
{"field_path": "profile.role", "action": "set", "value": "后端工程师"},
|
||||
{"field_path": "profile.expertise", "action": "set", "value": "Python"},
|
||||
{"field_path": "profile.expertise", "action": "remove", "value": "Java"}
|
||||
],
|
||||
"aliases_to_add": [],
|
||||
"aliases_to_remove": []
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
from typing import Any, Dict, List, Optional, TypeVar
|
||||
|
||||
from langchain_aws import ChatBedrock
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
@@ -9,12 +9,12 @@ from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_openai import ChatOpenAI, OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.core.models.volcano_chat import VolcanoChatOpenAI
|
||||
from app.core.models.compatible_chat import CompatibleChatOpenAI
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -25,10 +25,11 @@ class RedBearModelConfig(BaseModel):
|
||||
provider: str
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
capability: List[str] = Field(default_factory=list) # 模型能力列表,驱动所有能力开关
|
||||
is_omni: bool = False # 是否为 Omni 模型
|
||||
deep_thinking: bool = False # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算
|
||||
support_thinking: bool = False # 模型是否支持 enable_thinking 参数(capability 含 thinking)
|
||||
json_output: bool = False # 是否强制 JSON 输出
|
||||
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||
@@ -36,6 +37,23 @@ class RedBearModelConfig(BaseModel):
|
||||
concurrency: int = 5 # 并发限流
|
||||
extra_params: Dict[str, Any] = {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _resolve_capabilities(self) -> "RedBearModelConfig":
    """Switch off requested features that ``capability`` does not list.

    * ``deep_thinking`` is reset to False, and ``thinking_budget_tokens`` to
      None, when it was requested but ``"thinking"`` is not in ``capability``.
    * ``json_output`` is reset to False when it was requested but
      ``"json_output"`` is not in ``capability``.

    A warning is logged for every feature that gets switched off.

    Returns:
        The same config instance, after adjustment.
    """
    from app.core.logging_config import get_business_logger

    biz_logger = get_business_logger()
    declared = self.capability

    thinking_unsupported = self.deep_thinking and "thinking" not in declared
    if thinking_unsupported:
        biz_logger.warning(
            f"模型 {self.model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking"
        )
        # The budget only makes sense while deep thinking is on, so clear both.
        self.deep_thinking = False
        self.thinking_budget_tokens = None

    json_unsupported = self.json_output and "json_output" not in declared
    if json_unsupported:
        biz_logger.warning(
            f"模型 {self.model_name} 不支持 JSON 输出(capability 中无 'json_output'),已自动关闭 json_output"
        )
        self.json_output = False

    return self
|
||||
|
||||
|
||||
class RedBearModelFactory:
|
||||
"""模型工厂类"""
|
||||
@@ -74,18 +92,19 @@ class RedBearModelFactory:
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
if is_streaming:
|
||||
params["stream_usage"] = True
|
||||
# 只有支持 thinking 的模型才传 enable_thinking
|
||||
if config.support_thinking:
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking:
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
else:
|
||||
model_kwargs["enable_thinking"] = False
|
||||
params["model_kwargs"] = model_kwargs
|
||||
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||
if "thinking" in config.capability:
|
||||
extra_body = params.setdefault("extra_body", {})
|
||||
if config.deep_thinking:
|
||||
extra_body["enable_thinking"] = False
|
||||
if is_streaming:
|
||||
extra_body["enable_thinking"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
||||
# JSON 输出模式
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||
@@ -108,26 +127,31 @@ class RedBearModelFactory:
|
||||
**config.extra_params
|
||||
}
|
||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||
if config.extra_params.get("streaming"):
|
||||
params["stream_usage"] = True
|
||||
# 深度思考模式
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
if is_streaming and not config.is_omni:
|
||||
if is_streaming:
|
||||
params["stream_usage"] = True
|
||||
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||
if "thinking" in config.capability:
|
||||
# VOLCANO 深度思考仅流式支持
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
# 火山引擎深度思考仅流式调用支持,非流式时不传 thinking 参数
|
||||
thinking_config: Dict[str, Any] = {
|
||||
"type": "enabled" if config.deep_thinking else "disabled"
|
||||
}
|
||||
thinking_config: Dict[str, Any] = {"type": "enabled" if config.deep_thinking else "disabled"}
|
||||
if config.deep_thinking and config.thinking_budget_tokens:
|
||||
thinking_config["budget_tokens"] = config.thinking_budget_tokens
|
||||
params["extra_body"] = {"thinking": thinking_config}
|
||||
else:
|
||||
# 始终显式传递 enable_thinking,不支持该参数的模型(如 DeepSeek-R1)会直接忽略
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking and config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
params["model_kwargs"] = model_kwargs
|
||||
extra_body = params.setdefault("extra_body", {})
|
||||
if config.deep_thinking:
|
||||
extra_body["enable_thinking"] = False
|
||||
if is_streaming:
|
||||
extra_body["enable_thinking"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
||||
# JSON 输出模式
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
# VOLCANO 模型不支持 response_format,JSON 输出由 system prompt 注入实现
|
||||
if provider != ModelProvider.VOLCANO:
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
params = {
|
||||
@@ -136,19 +160,20 @@ class RedBearModelFactory:
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 只有支持 thinking 的模型才传 enable_thinking
|
||||
if config.support_thinking:
|
||||
# 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考
|
||||
if "thinking" in config.capability:
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking:
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
else:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
if config.deep_thinking:
|
||||
model_kwargs["enable_thinking"] = False
|
||||
params["model_kwargs"] = model_kwargs
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = True
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
# Bedrock 使用 AWS 凭证
|
||||
@@ -195,6 +220,10 @@ class RedBearModelFactory:
|
||||
params["additional_model_request_fields"] = {
|
||||
"thinking": {"type": "enabled", "budget_tokens": budget}
|
||||
}
|
||||
# JSON 输出模式
|
||||
if config.json_output:
|
||||
model_kwargs = params.setdefault("model_kwargs", {})
|
||||
model_kwargs["response_format"] = {"type": "json_object"}
|
||||
return params
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
@@ -223,18 +252,19 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
# DashScope omni models and Volcano models both use the OpenAI-compatible CompatibleChatOpenAI class
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
return ChatOpenAI
|
||||
return CompatibleChatOpenAI
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
return VolcanoChatOpenAI
|
||||
return CompatibleChatOpenAI
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
if type == ModelType.LLM:
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
return ChatOpenAI
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
return CompatibleChatOpenAI
|
||||
# if type == ModelType.LLM:
|
||||
# return OpenAI
|
||||
# elif type == ModelType.CHAT:
|
||||
# return CompatibleChatOpenAI
|
||||
# else:
|
||||
# raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
|
||||
@@ -8,12 +8,33 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class VolcanoChatOpenAI(ChatOpenAI):
|
||||
"""火山引擎 Chat 模型,支持深度思考内容(reasoning_content)的流式和非流式透传。"""
|
||||
class CompatibleChatOpenAI(ChatOpenAI):
|
||||
"""火山和千问的omni兼容模型,支持深度思考内容(reasoning_content)的流式和非流式透传。
|
||||
|
||||
同时修复 json_output + tools 同时使用时 langchain_openai 强制走 .parse()/.stream()
|
||||
导致 strict 校验报错的问题:有工具时从 payload 中移除 response_format,
|
||||
让父类走普通 .create()/.astream() 路径,JSON 输出由 system prompt 指令保证。
|
||||
"""
|
||||
|
||||
def _get_request_payload(
    self,
    input_: list[BaseMessage],
    *,
    stop: list[str] | None = None,
    **kwargs: Any,
) -> dict:
    """Build the request payload, dropping ``response_format`` when tools are present.

    The payload is produced by the parent class and then adjusted: if it
    carries a non-empty ``"tools"`` entry, any ``"response_format"`` key is
    removed before the payload is returned.
    """
    request = super()._get_request_payload(input_, stop=stop, **kwargs)

    # When tools are present, langchain_openai switches to the
    # .parse()/.stream() interface as soon as it sees response_format, and the
    # OpenAI SDK then requires every tool to be strict=True, which dynamically
    # generated tools do not satisfy. Removing response_format keeps the
    # parent on the ordinary code path; JSON output is then enforced by the
    # system-prompt instruction instead.
    if request.get("tools"):
        request.pop("response_format", None)

    return request
|
||||
|
||||
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
|
||||
result = super()._create_chat_result(response, generation_info)
|
||||
@@ -6,7 +6,8 @@ models:
|
||||
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -20,6 +21,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -38,6 +40,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -54,7 +57,8 @@ models:
|
||||
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -72,6 +76,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -87,7 +92,8 @@ models:
|
||||
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -101,7 +107,8 @@ models:
|
||||
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -115,7 +122,8 @@ models:
|
||||
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -130,7 +138,8 @@ models:
|
||||
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
|
||||
@@ -8,6 +8,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -22,6 +23,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -36,6 +38,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -48,7 +51,8 @@ models:
|
||||
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -61,7 +65,8 @@ models:
|
||||
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -74,7 +79,8 @@ models:
|
||||
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -87,7 +93,8 @@ models:
|
||||
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -100,7 +107,8 @@ models:
|
||||
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -115,7 +123,8 @@ models:
|
||||
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -133,6 +142,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -150,6 +160,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -180,6 +191,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -210,7 +222,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -376,6 +388,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -448,6 +461,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -466,6 +480,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -481,7 +496,8 @@ models:
|
||||
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -498,6 +514,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -513,7 +530,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -530,6 +547,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -546,6 +564,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -561,7 +580,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -578,6 +597,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -594,6 +614,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -610,6 +631,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -626,6 +648,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -641,7 +664,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -656,7 +679,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -672,6 +695,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -687,6 +711,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -702,6 +727,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -719,6 +745,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -736,6 +763,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -752,6 +780,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -768,7 +797,7 @@ models:
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -785,6 +814,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -803,6 +833,8 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- audio
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -822,7 +854,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -844,6 +876,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -864,7 +897,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -886,6 +919,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -907,6 +941,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -928,6 +963,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -947,6 +983,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -964,6 +1001,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -979,6 +1017,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -994,6 +1033,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
|
||||
@@ -10,6 +10,7 @@ models:
|
||||
- vision
|
||||
- audio
|
||||
- video
|
||||
- json_output
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -27,7 +28,8 @@ models:
|
||||
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -42,7 +44,8 @@ models:
|
||||
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -57,7 +60,8 @@ models:
|
||||
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -84,7 +88,8 @@ models:
|
||||
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -99,7 +104,8 @@ models:
|
||||
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -114,7 +120,8 @@ models:
|
||||
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -131,6 +138,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -146,7 +154,8 @@ models:
|
||||
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -163,6 +172,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -194,6 +204,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -213,6 +224,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -231,6 +243,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -248,6 +261,7 @@ models:
|
||||
is_official: true
|
||||
capability:
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -266,6 +280,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -284,6 +299,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -302,6 +318,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -321,6 +338,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -340,6 +358,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
|
||||
@@ -11,6 +11,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -26,6 +27,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -41,6 +43,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -56,6 +59,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -72,6 +76,7 @@ models:
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -87,6 +92,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -102,6 +108,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -117,6 +124,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -132,6 +140,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -148,6 +157,7 @@ models:
|
||||
- vision
|
||||
- video
|
||||
- thinking
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -175,7 +185,8 @@ models:
|
||||
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
@@ -187,7 +198,8 @@ models:
|
||||
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
capability:
|
||||
- json_output
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
|
||||
485
api/app/core/quota_manager.py
Normal file
485
api/app/core/quota_manager.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""
|
||||
统一配额管理器 - 社区版和 SaaS 版共用
|
||||
|
||||
配额来源策略:
|
||||
1. 优先从 premium 模块的 tenant_subscriptions 表读取(SaaS 版)
|
||||
2. 降级到 default_free_plan.py 配置文件(社区版兜底)
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Optional, Callable, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_auth_logger
|
||||
from app.i18n.exceptions import QuotaExceededError
|
||||
|
||||
logger = get_auth_logger()
|
||||
|
||||
|
||||
def _get_user_from_kwargs(kwargs: dict):
|
||||
"""从 kwargs 中获取 user 对象"""
|
||||
for key in ["user", "current_user"]:
|
||||
if key in kwargs:
|
||||
return kwargs[key]
|
||||
return None
|
||||
|
||||
|
||||
def _find_workspace(db: Session, workspace_id):
    """Return the Workspace row with the given id, or None if the id is empty or unmatched."""
    if not workspace_id:
        # Nothing to look up; avoid issuing a query that cannot match.
        return None
    from app.models.workspace_model import Workspace
    return db.query(Workspace).filter(Workspace.id == workspace_id).first()


def _candidate_workspace_ids(kwargs: dict):
    """Yield the workspace ids found in ``kwargs``, in lookup-priority order."""
    # 1. An explicit ``workspace_id`` argument.
    yield kwargs.get("workspace_id")

    # 2. The workspace an API key is bound to.
    api_key_auth = kwargs.get("api_key_auth")
    if api_key_auth and hasattr(api_key_auth, "workspace_id"):
        yield api_key_auth.workspace_id

    # 3. A request body object exposing ``workspace_id``.
    data = kwargs.get("data") or kwargs.get("body") or kwargs.get("payload")
    if data and hasattr(data, "workspace_id"):
        yield data.workspace_id


def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
    """Resolve the tenant id for a call from its keyword arguments.

    Sources are tried in this order, and the first one that resolves wins:

    1. ``tenant_id`` of the user found under ``"user"`` / ``"current_user"``.
    2. The workspace named by ``kwargs["workspace_id"]``.
    3. The workspace named by ``kwargs["api_key_auth"].workspace_id``.
    4. The workspace named by ``workspace_id`` on ``"data"``, ``"body"`` or
       ``"payload"`` (first truthy one).
    5. The workspace of the active app that ``kwargs["share_data"].share_token``
       points to through ``ReleaseShare``.

    Args:
        db: Database session used for the workspace / share lookups.
        kwargs: Keyword arguments of the wrapped call.

    Returns:
        The tenant id, or None when no source resolves.
    """
    user = _get_user_from_kwargs(kwargs)
    if user and hasattr(user, "tenant_id"):
        return user.tenant_id

    # The generator is consumed lazily, so a later source is only inspected
    # (and queried) once every earlier one has failed to resolve.
    for workspace_id in _candidate_workspace_ids(kwargs):
        workspace = _find_workspace(db, workspace_id)
        if workspace is not None:
            return workspace.tenant_id

    share_data = kwargs.get("share_data")
    if share_data and hasattr(share_data, "share_token"):
        # NOTE(review): Workspace is not referenced below; the import is kept
        # from the original, presumably so its mapper is loaded before
        # ``app.workspace`` is accessed -- confirm before removing.
        from app.models.workspace_model import Workspace  # noqa: F401
        from app.models.app_model import App
        from app.models.release_share_model import ReleaseShare

        share_record = db.query(ReleaseShare).filter(
            ReleaseShare.share_token == share_data.share_token
        ).first()
        if share_record:
            app = db.query(App).filter(
                App.id == share_record.app_id,
                App.is_active.is_(True),
            ).first()
            # Guard against an app whose workspace relationship is unset
            # instead of raising AttributeError on ``None.tenant_id``.
            if app and app.workspace:
                return app.workspace.tenant_id

    return None
|
||||
|
||||
|
||||
def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]:
    """Return the quota configuration that applies to a tenant.

    Sources, in priority order:

    1. ``TenantSubscriptionService(db).get_effective_quota(tenant_id)`` from
       the ``premium`` package (SaaS edition), when it returns a truthy value.
    2. The ``"quotas"`` entry of ``DEFAULT_FREE_PLAN`` from
       ``app.config.default_free_plan`` (community-edition fallback).

    Args:
        db: Database session handed to the premium subscription service.
        tenant_id: Tenant whose quotas are requested.

    Returns:
        The quota mapping, or None when neither source could be read.
    """
    # Try the premium module first.
    try:
        from premium.platform_admin.package_plan_service import TenantSubscriptionService
        quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id)
        if quota_config:
            logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置")
            return quota_config
    except ImportError as e:
        # Expected on the community edition (ModuleNotFoundError is a
        # subclass of ImportError), so this stays at debug level.
        logger.debug(f"无法从 premium 模块获取配额配置: {e}")
    except Exception as e:
        # The premium module exists but failed at runtime. Still fall back,
        # but make the failure visible: otherwise a subscribed tenant is
        # silently downgraded to the free plan.
        logger.warning(f"无法从 premium 模块获取配额配置: {e}")

    # Fall back to the bundled free-plan configuration.
    try:
        from app.config.default_free_plan import DEFAULT_FREE_PLAN
        logger.info(f"使用配置文件中的免费套餐配额: tenant={tenant_id}")
        return DEFAULT_FREE_PLAN.get("quotas")
    except Exception as e:
        logger.error(f"无法从配置文件获取配额: {e}")
        return None
|
||||
|
||||
|
||||
class QuotaUsageRepository:
    """Data-access layer that measures a tenant's current usage per quota type."""

    def __init__(self, db: Session):
        """Store the database session used by every query in this repository."""
        self.db = db

    def count_workspaces(self, tenant_id: UUID) -> int:
        """Count the tenant's workspaces whose ``is_active`` flag is true."""
        from app.models.workspace_model import Workspace
        return self.db.query(Workspace).filter(
            Workspace.tenant_id == tenant_id,
            Workspace.is_active.is_(True)
        ).count()

    def count_apps(self, tenant_id: UUID) -> int:
        """Count active apps that belong to any workspace of the tenant."""
        from app.models.app_model import App
        from app.models.workspace_model import Workspace
        return self.db.query(App).join(
            Workspace, App.workspace_id == Workspace.id
        ).filter(
            Workspace.tenant_id == tenant_id,
            App.is_active.is_(True)
        ).count()

    def count_skills(self, tenant_id: UUID) -> int:
        """Count the tenant's skills whose ``is_active`` flag is true."""
        from app.models.skill_model import Skill
        return self.db.query(Skill).filter(
            Skill.tenant_id == tenant_id,
            Skill.is_active.is_(True)
        ).count()

    def sum_knowledge_capacity_gb(self, tenant_id: UUID) -> float:
        """Sum ``Document.file_size`` over the tenant's knowledge bases, divided by 1024**3.

        Only documents with ``status == 1`` are included.
        NOTE(review): presumably ``file_size`` is in bytes and status 1 marks
        usable documents -- confirm against the Document model.
        """
        from app.models.document_model import Document
        from app.models.knowledge_model import Knowledge
        from app.models.workspace_model import Workspace
        total_size = self.db.query(func.coalesce(func.sum(Document.file_size), 0)).join(
            Knowledge, Document.kb_id == Knowledge.id
        ).join(
            Workspace, Knowledge.workspace_id == Workspace.id
        ).filter(
            Workspace.tenant_id == tenant_id,
            Document.status == 1,
        ).scalar()
        # A zero/None total short-circuits to 0.0 without the division.
        return float(total_size) / (1024 ** 3) if total_size else 0.0

    def count_memory_engines(self, tenant_id: UUID) -> int:
        """Count ``MemoryConfig`` rows attached to any workspace of the tenant."""
        from app.models.memory_config_model import MemoryConfig
        from app.models.workspace_model import Workspace
        return self.db.query(MemoryConfig).join(
            Workspace, MemoryConfig.workspace_id == Workspace.id
        ).filter(
            Workspace.tenant_id == tenant_id
        ).count()

    def count_end_users(self, tenant_id: UUID) -> int:
        """Count ``EndUser`` rows attached to any workspace of the tenant."""
        from app.models.end_user_model import EndUser
        from app.models.workspace_model import Workspace
        return self.db.query(EndUser).join(
            Workspace, EndUser.workspace_id == Workspace.id
        ).filter(
            Workspace.tenant_id == tenant_id
        ).count()

    def count_models(self, tenant_id: UUID) -> int:
        """Count the tenant's model configurations whose ``is_active`` flag is true."""
        from app.models.models_model import ModelConfig
        return self.db.query(ModelConfig).filter(
            ModelConfig.tenant_id == tenant_id,
            # ``.is_(True)`` matches the other counters in this class and
            # avoids the ``== True`` comparison lint (E712).
            ModelConfig.is_active.is_(True)
        ).count()

    def count_ontology_projects(self, tenant_id: UUID) -> int:
        """Count ``OntologyScene`` rows attached to any workspace of the tenant."""
        from app.models.ontology_scene import OntologyScene
        from app.models.workspace_model import Workspace
        return self.db.query(OntologyScene).join(
            Workspace, OntologyScene.workspace_id == Workspace.id
        ).filter(
            Workspace.tenant_id == tenant_id
        ).count()

    def get_usage_by_quota_type(self, tenant_id: UUID, quota_type: str) -> float:
        """Dispatch on ``quota_type`` and return the tenant's current usage.

        Args:
            tenant_id: Tenant to measure.
            quota_type: One of the keys of the dispatch table below.

        Returns:
            The count (or, for ``knowledge_capacity_quota``, the size computed
            by ``sum_knowledge_capacity_gb``). An unrecognised ``quota_type``
            yields 0 and logs a warning, since a silent 0 would otherwise hide
            a misspelled quota name.
        """
        dispatch = {
            "workspace_quota": self.count_workspaces,
            "app_quota": self.count_apps,
            "skill_quota": self.count_skills,
            "knowledge_capacity_quota": self.sum_knowledge_capacity_gb,
            "memory_engine_quota": self.count_memory_engines,
            "end_user_quota": self.count_end_users,
            "model_quota": self.count_models,
            "ontology_project_quota": self.count_ontology_projects,
        }
        counter = dispatch.get(quota_type)
        if counter is None:
            logger.warning(f"未知的配额类型: {quota_type},使用量按 0 处理")
            return 0
        return counter(tenant_id)
|
||||
|
||||
|
||||
def _check_quota(
    db: Session,
    tenant_id: UUID,
    quota_type: str,
    resource_name: str,
    usage_func: Optional[Callable] = None,
) -> None:
    """Compare a tenant's current usage with its configured limit.

    Args:
        db: Database session handed to the config lookup and usage query.
        tenant_id: Tenant being checked.
        quota_type: Key of the limit inside the quota configuration.
        resource_name: Resource label carried by the raised error.
        usage_func: Optional ``usage_func(db, tenant_id)`` replacing the
            repository lookup of the current usage.

    Raises:
        QuotaExceededError: When the usage is greater than or equal to the limit.
        Exception: Any other failure is logged with its traceback and re-raised.

    The check is skipped (with a warning) when the tenant has no quota
    configuration or the configuration has no entry for ``quota_type``.
    """
    try:
        limits = _get_quota_config(db, tenant_id)
        if not limits:
            logger.warning(f"租户 {tenant_id} 无有效配额配置,跳过配额检查")
            return

        quota_limit = limits.get(quota_type)
        if quota_limit is None:
            logger.warning(f"配额配置未包含 {quota_type},跳过配额检查")
            return

        current_usage = (
            usage_func(db, tenant_id)
            if usage_func
            else QuotaUsageRepository(db).get_usage_by_quota_type(tenant_id, quota_type)
        )

        if not current_usage >= quota_limit:
            logger.debug(
                f"配额检查通过: tenant={tenant_id}, type={quota_type}, "
                f"usage={current_usage}, limit={quota_limit}"
            )
            return

        logger.warning(
            f"配额不足: tenant={tenant_id}, type={quota_type}, "
            f"usage={current_usage}, limit={quota_limit}"
        )
        raise QuotaExceededError(
            resource=resource_name,
            current_usage=current_usage,
            quota_limit=quota_limit,
        )
    except Exception as e:
        # A quota violation is the expected outcome and passes through
        # unlogged here; everything else is reported with its traceback.
        if not isinstance(e, QuotaExceededError):
            logger.error(
                f"配额检查异常: tenant={tenant_id}, type={quota_type}, "
                f"error_type={type(e).__name__}, error={str(e)}",
                exc_info=True,
            )
        raise
|
||||
|
||||
|
||||
# ─── 具名装饰器 ────────────────────────────────────────────────────────────
|
||||
|
||||
def check_workspace_quota(func: Callable) -> Callable:
    """Decorate a sync callable so the workspace quota is checked before it runs.

    The wrapper reads ``db`` from the keyword arguments and resolves the user
    with ``_get_user_from_kwargs``. When either is missing it logs a warning
    and calls ``func`` unchecked; otherwise ``_check_quota`` runs first and
    anything it raises propagates before ``func`` is invoked.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        session: Session = kwargs.get("db")
        current_user = _get_user_from_kwargs(kwargs)
        if session and current_user:
            _check_quota(session, current_user.tenant_id, "workspace_quota", "workspace")
        else:
            logger.warning("配额检查失败:缺少 db 或 user 参数")
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def check_skill_quota(func: Callable) -> Callable:
    """Decorate a sync callable so the skill quota is checked before it runs.

    ``db`` is taken from the keyword arguments and the user from
    ``_get_user_from_kwargs``. If one of them is missing, a warning is logged
    and ``func`` is called without a check.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        session: Session = kwargs.get("db")
        current_user = _get_user_from_kwargs(kwargs)
        if session and current_user:
            _check_quota(session, current_user.tenant_id, "skill_quota", "skill")
        else:
            logger.warning("配额检查失败:缺少 db 或 user 参数")
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def check_app_quota(func: Callable) -> Callable:
    """Decorate a sync callable so the app quota is checked before it runs.

    ``db`` is taken from the keyword arguments and the user from
    ``_get_user_from_kwargs``. If one of them is missing, a warning is logged
    and ``func`` is called without a check.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        session: Session = kwargs.get("db")
        current_user = _get_user_from_kwargs(kwargs)
        if session and current_user:
            _check_quota(session, current_user.tenant_id, "app_quota", "app")
        else:
            logger.warning("配额检查失败:缺少 db 或 user 参数")
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def check_knowledge_capacity_quota(func: Callable) -> Callable:
    """Decorate a callable so the knowledge-capacity quota is checked first.

    Coroutine functions get an async wrapper that resolves the tenant with
    ``_get_tenant_id_from_kwargs``; plain functions get a sync wrapper that
    takes the tenant from the user returned by ``_get_user_from_kwargs``.
    In both variants a missing prerequisite only logs a warning and the
    wrapped callable still runs.
    """
    if asyncio.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            session: Session = kwargs.get("db")
            if session:
                tenant_id = _get_tenant_id_from_kwargs(session, kwargs)
                if tenant_id:
                    _check_quota(session, tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
                else:
                    logger.warning("配额检查失败:无法获取 tenant_id")
            else:
                logger.warning("配额检查失败:缺少 db 参数")
            return await func(*args, **kwargs)

        return async_wrapper

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        session: Session = kwargs.get("db")
        current_user = _get_user_from_kwargs(kwargs)
        if session and current_user:
            _check_quota(session, current_user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
        else:
            logger.warning("配额检查失败:缺少 db 或 user 参数")
        return func(*args, **kwargs)

    return sync_wrapper
|
||||
|
||||
|
||||
def check_memory_engine_quota(func: Callable) -> Callable:
    """Decorate a sync callable so the memory-engine quota is checked before it runs.

    ``db`` is taken from the keyword arguments and the user from
    ``_get_user_from_kwargs``. If one of them is missing, a warning is logged
    and ``func`` is called without a check.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        session: Session = kwargs.get("db")
        current_user = _get_user_from_kwargs(kwargs)
        if session and current_user:
            _check_quota(session, current_user.tenant_id, "memory_engine_quota", "memory_engine")
        else:
            logger.warning("配额检查失败:缺少 db 或 user 参数")
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def check_end_user_quota(func: Callable) -> Callable:
    """Decorate a callable so the end-user quota is checked before it runs.

    Both the async and the sync wrapper read ``db`` from the keyword arguments
    and resolve the tenant with ``_get_tenant_id_from_kwargs``. When ``db`` or
    the tenant id is unavailable, a warning is logged and the wrapped callable
    runs without a check.
    """
    if asyncio.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            session: Session = kwargs.get("db")
            if session:
                tenant_id = _get_tenant_id_from_kwargs(session, kwargs)
                if tenant_id:
                    _check_quota(session, tenant_id, "end_user_quota", "end_user")
                else:
                    logger.warning("配额检查失败:无法获取 tenant_id")
            else:
                logger.warning("配额检查失败:缺少 db 参数")
            return await func(*args, **kwargs)

        return async_wrapper

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        session: Session = kwargs.get("db")
        if session:
            tenant_id = _get_tenant_id_from_kwargs(session, kwargs)
            if tenant_id:
                _check_quota(session, tenant_id, "end_user_quota", "end_user")
            else:
                logger.warning("配额检查失败:无法获取 tenant_id")
        else:
            logger.warning("配额检查失败:缺少 db 参数")
        return func(*args, **kwargs)

    return sync_wrapper
|
||||
|
||||
|
||||
def check_ontology_project_quota(func: Callable) -> Callable:
    """Decorate a sync callable so the ontology-project quota is checked before it runs.

    ``db`` is taken from the keyword arguments and the user from
    ``_get_user_from_kwargs``. If one of them is missing, a warning is logged
    and ``func`` is called without a check.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        session: Session = kwargs.get("db")
        current_user = _get_user_from_kwargs(kwargs)
        if session and current_user:
            _check_quota(session, current_user.tenant_id, "ontology_project_quota", "ontology_project")
        else:
            logger.warning("配额检查失败:缺少 db 或 user 参数")
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def check_model_quota(func: Callable) -> Callable:
    """Decorate a sync callable so the model quota is checked before it runs.

    ``db`` is taken from the keyword arguments and the user from
    ``_get_user_from_kwargs``. If one of them is missing, a warning is logged
    and ``func`` is called without a check.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        session: Session = kwargs.get("db")
        current_user = _get_user_from_kwargs(kwargs)
        if session and current_user:
            _check_quota(session, current_user.tenant_id, "model_quota", "model")
        else:
            logger.warning("配额检查失败:缺少 db 或 user 参数")
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def check_model_activation_quota(func: Callable) -> Callable:
    """Decorate a sync callable so activating a model is checked against the model quota.

    The quota is checked only when the request sets ``model_data.is_active``
    to ``True`` and the stored model is currently inactive, i.e. when the call
    would add one more active model. If ``db``, the user, ``model_id`` or
    ``model_data`` cannot be obtained, a warning is logged and ``func`` is
    called without a check.

    Raises:
        QuotaExceededError: Propagated from ``_check_quota`` when the model
            quota is exhausted.
        Exception: Any other failure during the lookup is logged and re-raised.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        db: Session = kwargs.get("db")
        user = _get_user_from_kwargs(kwargs)
        if not db or not user:
            logger.warning("配额检查失败:缺少 db 或 user 参数")
            return func(*args, **kwargs)

        # model_id may arrive as a keyword or as the second positional argument.
        model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
        model_data = kwargs.get("model_data")

        if not model_id or not model_data:
            logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
            return func(*args, **kwargs)

        if model_data.is_active is True:
            try:
                from app.services.model_service import ModelConfigService

                existing_model = ModelConfigService.get_model_by_id(
                    db=db,
                    model_id=model_id,
                    tenant_id=user.tenant_id
                )

                # Only an inactive -> active transition consumes quota.
                if not existing_model.is_active:
                    logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
                    _check_quota(db, user.tenant_id, "model_quota", "model")
            except QuotaExceededError:
                # Expected outcome that _check_quota has already logged as a
                # warning; do not report it a second time as an error.
                raise
            except Exception as e:
                logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
                raise

        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None):
    """Build a quota-checking decorator for an arbitrary quota type.

    Args:
        quota_type: Key of the limit inside the quota configuration.
        resource_name: Resource label forwarded to ``_check_quota``.
        usage_func: Optional custom usage getter forwarded to ``_check_quota``.

    Returns:
        A decorator for sync callables. Its wrapper reads ``db`` from the
        keyword arguments and the user from ``_get_user_from_kwargs``; when one
        of them is missing it logs a warning and calls the function unchecked.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            session: Session = kwargs.get("db")
            current_user = _get_user_from_kwargs(kwargs)
            if session and current_user:
                _check_quota(session, current_user.tenant_id, quota_type, resource_name, usage_func)
            else:
                logger.warning("配额检查失败:缺少 db 或 user 参数")
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
# ─── 配额使用统计 ────────────────────────────────────────────────────────────
|
||||
|
||||
def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
    """Return usage, limit and percentage for every quota of a tenant.

    Args:
        db: Database session used for the config lookup and usage queries.
        tenant_id: Tenant whose usage is reported.

    Returns:
        ``{}`` when no quota configuration is available. Otherwise a dict keyed
        by resource name whose values hold ``used``, ``limit`` and
        ``percentage`` (``None`` when the limit is missing or zero). The
        ``api_ops_rate_limit`` entry instead reports ``current``, the number of
        members of the tenant's rate-limit sorted set scored within the last
        second, and falls back to ``0`` when Redis cannot be queried.
    """
    quota_config = _get_quota_config(db, tenant_id)
    if not quota_config:
        return {}

    repo = QuotaUsageRepository(db)

    def pct(used, limit):
        # A missing or zero limit has no meaningful percentage.
        return round(used / limit * 100, 1) if limit else None

    def entry(used, quota_key):
        # Look the limit up once and reuse it for both fields.
        limit = quota_config.get(quota_key)
        return {"used": used, "limit": limit, "percentage": pct(used, limit)}

    workspace_count = repo.count_workspaces(tenant_id)
    skill_count = repo.count_skills(tenant_id)
    app_count = repo.count_apps(tenant_id)
    knowledge_gb = repo.sum_knowledge_capacity_gb(tenant_id)
    memory_count = repo.count_memory_engines(tenant_id)
    end_user_count = repo.count_end_users(tenant_id)
    model_count = repo.count_models(tenant_id)
    ontology_count = repo.count_ontology_projects(tenant_id)

    # Best effort: the live request rate is informational, so a Redis problem
    # must not break the report -- but it is logged instead of being hidden.
    api_ops_current = 0
    try:
        from app.core.config import settings
        import redis

        _now = time.time()
        _rk = f"rate_limit:tenant_qps:{tenant_id}"
        _r = redis.StrictRedis(
            host=settings.REDIS_HOST, port=settings.REDIS_PORT,
            db=settings.REDIS_DB, password=settings.REDIS_PASSWORD,
            decode_responses=True
        )
        try:
            api_ops_current = int(_r.zcount(_rk, _now - 1, "+inf"))
        finally:
            # Release the connection opened for this one-off query.
            _r.close()
    except Exception as e:
        logger.warning(f"获取 API 调用速率失败,按 0 处理: tenant={tenant_id}, error={e}")

    knowledge_limit = quota_config.get("knowledge_capacity_quota")
    return {
        "workspace": entry(workspace_count, "workspace_quota"),
        "skill": entry(skill_count, "skill_quota"),
        "app": entry(app_count, "app_quota"),
        # "used" is rounded for display; the percentage uses the exact value.
        "knowledge_capacity": {
            "used": round(knowledge_gb, 2),
            "limit": knowledge_limit,
            "percentage": pct(knowledge_gb, knowledge_limit),
            "unit": "GB",
        },
        "memory_engine": entry(memory_count, "memory_engine_quota"),
        "end_user": entry(end_user_count, "end_user_quota"),
        "ontology_project": entry(ontology_count, "ontology_project_quota"),
        "model": entry(model_count, "model_quota"),
        "api_ops_rate_limit": {
            "current": api_ops_current,
            "limit": quota_config.get("api_ops_rate_limit"),
            "percentage": None,
            "unit": "次/秒",
        },
    }
|
||||
36
api/app/core/quota_stub.py
Normal file
36
api/app/core/quota_stub.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
配额检查 stub - 社区版和 SaaS 版统一使用 core.quota_manager 实现
|
||||
|
||||
所有配额检查逻辑统一在 core 层实现,两个版本共用:
|
||||
- 社区版:从 default_free_plan.py 读取配额限制
|
||||
- SaaS 版:优先从 tenant_subscriptions 表读取,降级到配置文件
|
||||
"""
|
||||
from app.core.quota_manager import (
|
||||
check_workspace_quota,
|
||||
check_skill_quota,
|
||||
check_app_quota,
|
||||
check_knowledge_capacity_quota,
|
||||
check_memory_engine_quota,
|
||||
check_end_user_quota,
|
||||
check_ontology_project_quota,
|
||||
check_model_quota,
|
||||
check_model_activation_quota,
|
||||
get_quota_usage,
|
||||
_check_quota,
|
||||
QuotaUsageRepository,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"check_workspace_quota",
|
||||
"check_skill_quota",
|
||||
"check_app_quota",
|
||||
"check_knowledge_capacity_quota",
|
||||
"check_memory_engine_quota",
|
||||
"check_end_user_quota",
|
||||
"check_ontology_project_quota",
|
||||
"check_model_quota",
|
||||
"check_model_activation_quota",
|
||||
"get_quota_usage",
|
||||
"_check_quota",
|
||||
"QuotaUsageRepository",
|
||||
]
|
||||
@@ -672,10 +672,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
excel_parser = ExcelParser()
|
||||
if parser_config.get("html4excel") and parser_config.get("html4excel").lower() == "true":
|
||||
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
|
||||
parser_config["chunk_token_num"] = 0
|
||||
else:
|
||||
sections = [(_, "") for _ in excel_parser(binary) if _]
|
||||
parser_config["chunk_token_num"] = 12800
|
||||
callback(0.8, "Finish parsing.")
|
||||
# Excel 每行直接作为一个 chunk,不经过 naive_merge 避免被 delimiter 拆分
|
||||
chunks = [s for s, _ in sections]
|
||||
res.extend(tokenize_chunks(chunks, doc, is_english, None))
|
||||
res.extend(embed_res)
|
||||
res.extend(url_res)
|
||||
return res
|
||||
|
||||
elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
@@ -232,14 +232,14 @@ class RAGExcelParser:
|
||||
t = str(ti[i].value) if i < len(ti) else ""
|
||||
t += (":" if t else "") + str(c.value)
|
||||
fields.append(t)
|
||||
line = "; ".join(fields)
|
||||
line = "\n".join(fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
line += " ——" + sheetname
|
||||
line += "\n——" + sheetname
|
||||
res.append(line)
|
||||
else:
|
||||
# 只有表头的情况
|
||||
if header_fields:
|
||||
line = "; ".join(header_fields)
|
||||
line = "\n".join(header_fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
line += " ——" + sheetname
|
||||
res.append(line)
|
||||
|
||||
@@ -50,7 +50,9 @@ class OpenAIEmbed(Base):
|
||||
def encode(self, texts: list):
|
||||
# OpenAI requires batch size <=16
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8191) for t in texts]
|
||||
# Use 8000 instead of 8191 to leave safety margin for tokenizer differences
|
||||
# between cl100k_base (used by truncate) and the actual embedding model
|
||||
texts = [truncate(t, 8000) for t in texts]
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
@@ -63,7 +65,7 @@ class OpenAIEmbed(Base):
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
|
||||
|
||||
@@ -79,6 +81,7 @@ class LocalAIEmbed(Base):
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8000) for t in texts]
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
@@ -173,6 +176,7 @@ class XinferenceEmbed(Base):
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8000) for t in texts]
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
@@ -188,7 +192,7 @@ class XinferenceEmbed(Base):
|
||||
def encode_queries(self, text):
|
||||
res = None
|
||||
try:
|
||||
res = self.client.embeddings.create(input=[text], model=self.model_name)
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name)
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
@@ -253,9 +253,9 @@ class DateTimeTool(BuiltinTool):
|
||||
return {
|
||||
"datetime": input_value,
|
||||
"timezone": timezone_str,
|
||||
"timestamp": int(dt.timestamp()) * 1000,
|
||||
"timestamp": int(dt.timestamp() * 1000),
|
||||
"iso_format": dt.isoformat(),
|
||||
"result_data": int(dt.timestamp()) * 1000
|
||||
"result_data": int(dt.timestamp() * 1000)
|
||||
}
|
||||
|
||||
def _calculate_datetime(self, kwargs) -> dict:
|
||||
|
||||
@@ -201,12 +201,15 @@ class VariablePool:
|
||||
|
||||
@staticmethod
|
||||
def _extract_field(struct: "VariableStruct", field: str | None) -> Any:
|
||||
"""If field is given, drill into a dict/object variable's value."""
|
||||
"""If field is given, drill into a dict/object/array[file] variable's value."""
|
||||
if field is None:
|
||||
return struct.instance.get_value()
|
||||
value = struct.instance.get_value()
|
||||
# array[file]: extract the field from every element, return a list
|
||||
if isinstance(value, list):
|
||||
return [item.get(field) if isinstance(item, dict) else getattr(item, field, None) for item in value]
|
||||
if not isinstance(value, dict):
|
||||
raise KeyError(f"Variable is not an object, cannot access field '{field}'")
|
||||
raise KeyError(f"Variable is not an object or array, cannot access field '{field}'")
|
||||
return value.get(field)
|
||||
|
||||
def get_instance(
|
||||
|
||||
@@ -28,86 +28,135 @@ class IterationRuntime:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_id: str,
|
||||
stream: bool,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool,
|
||||
child_variable_pool: VariablePool,
|
||||
cycle_nodes: list,
|
||||
cycle_edges: list,
|
||||
):
|
||||
"""
|
||||
Initialize the iteration runtime.
|
||||
|
||||
Args:
|
||||
graph: Compiled workflow graph capable of async invocation.
|
||||
node_id: Unique identifier of the loop node.
|
||||
config: Dictionary containing iteration node configuration.
|
||||
state: Current workflow state at the point of iteration.
|
||||
stream: Whether to run in streaming mode. When True, each iteration
|
||||
uses graph.astream and emits cycle_item events in real time.
|
||||
When False, graph.ainvoke is used instead.
|
||||
node_id: The unique identifier of the iteration node in the workflow.
|
||||
Also used as the variable namespace for item/index inside
|
||||
the subgraph (e.g. {{ node_id.item }}).
|
||||
config: Raw configuration dict for the iteration node, parsed into
|
||||
IterationNodeConfig. Controls input/output variable selectors,
|
||||
parallel execution settings, and output flattening.
|
||||
state: The parent workflow state at the point the iteration node is
|
||||
entered. Each task receives a copy of this state as its
|
||||
starting point.
|
||||
variable_pool: The parent VariablePool containing all variables available
|
||||
at the time the iteration node executes, including sys.*,
|
||||
conv.*, and outputs from upstream nodes. Used as the source
|
||||
for deep-copying into each task's independent child pool.
|
||||
cycle_nodes: List of node config dicts belonging to this iteration's
|
||||
subgraph (i.e. nodes whose cycle field equals node_id).
|
||||
Passed to GraphBuilder when constructing each task's subgraph.
|
||||
cycle_edges: List of edge config dicts connecting nodes within the subgraph.
|
||||
Passed to GraphBuilder alongside cycle_nodes.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.stream = stream
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
self.typed_config = IterationNodeConfig(**config)
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
self.cycle_nodes = cycle_nodes
|
||||
self.cycle_edges = cycle_edges
|
||||
self.event_write = get_stream_writer()
|
||||
self.checkpoint = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4()
|
||||
}
|
||||
)
|
||||
|
||||
self.output_value = None
|
||||
self.result: list = []
|
||||
|
||||
async def _init_iteration_state(self, item, idx):
|
||||
def _build_child_graph(self) -> tuple[CompiledStateGraph, VariablePool, str]:
|
||||
"""
|
||||
Initialize a per-iteration copy of the workflow state.
|
||||
Build an independent compiled subgraph for a single iteration task.
|
||||
|
||||
Args:
|
||||
item: Current element from the input array for this iteration.
|
||||
idx: Index of the element in the input array.
|
||||
Each call creates a brand-new VariablePool by deep-copying the parent pool,
|
||||
then passes it to GraphBuilder. GraphBuilder binds this pool to every node's
|
||||
execution closure at build time, so the pool and the subgraph always reference
|
||||
the same object. This is the key design invariant: item/index written into the
|
||||
pool after build will be visible to all nodes inside the subgraph.
|
||||
|
||||
Returns:
|
||||
A copy of the workflow state with iteration-specific variables set.
|
||||
graph: The compiled LangGraph subgraph ready for invocation.
|
||||
child_pool: The VariablePool bound to this subgraph's node closures.
|
||||
Callers must write item/index into this pool before invoking
|
||||
the graph, and read output from it after invocation.
|
||||
start_node_id: The ID of the CYCLE_START node inside the subgraph,
|
||||
used to set the initial activation signal in workflow state.
|
||||
"""
|
||||
loopstate = WorkflowState(
|
||||
**self.state
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
child_pool = VariablePool()
|
||||
child_pool.copy(self.variable_pool)
|
||||
builder = GraphBuilder(
|
||||
{"nodes": self.cycle_nodes, "edges": self.cycle_edges},
|
||||
stream=self.stream,
|
||||
variable_pool=child_pool,
|
||||
cycle=self.node_id,
|
||||
)
|
||||
self.child_variable_pool.copy(self.variable_pool)
|
||||
await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
|
||||
await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True)
|
||||
loopstate["node_outputs"][self.node_id] = {
|
||||
"item": item,
|
||||
"index": idx,
|
||||
}
|
||||
graph = builder.build()
|
||||
return graph, builder.variable_pool, builder.start_node_id
|
||||
|
||||
async def _init_iteration_state(self, item, idx, child_pool: VariablePool, start_id: str):
|
||||
"""
|
||||
Initialize the workflow state for a single iteration.
|
||||
|
||||
Writes the current item and its index into child_pool under the iteration
|
||||
node's namespace (e.g. iteration_xxx.item, iteration_xxx.index), making them
|
||||
accessible to downstream nodes inside the subgraph via variable selectors.
|
||||
|
||||
Also prepares a copy of the parent workflow state with:
|
||||
- node_outputs[node_id] set to {item, index} so the state snapshot is consistent
|
||||
with the pool values.
|
||||
- looping flag set to 1 (active) to signal the subgraph is inside a cycle.
|
||||
- activate[start_id] set to True to trigger the CYCLE_START node.
|
||||
|
||||
Args:
|
||||
item: The current element from the input array.
|
||||
idx: The zero-based index of this element in the input array.
|
||||
child_pool: The VariablePool bound to this iteration's subgraph.
|
||||
Must be the same object returned by _build_child_graph.
|
||||
start_id: The ID of the CYCLE_START node inside the subgraph.
|
||||
|
||||
Returns:
|
||||
A WorkflowState instance ready to be passed to graph.ainvoke or graph.astream.
|
||||
"""
|
||||
loopstate = WorkflowState(**self.state)
|
||||
await child_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
|
||||
await child_pool.new(self.node_id, "index", idx, VariableType.type_map(idx), mut=True)
|
||||
loopstate["node_outputs"][self.node_id] = {"item": item, "index": idx}
|
||||
loopstate["looping"] = 1
|
||||
loopstate["activate"][self.start_id] = True
|
||||
loopstate["activate"][start_id] = True
|
||||
return loopstate
|
||||
|
||||
def merge_conv_vars(self):
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables["conv"]
|
||||
)
|
||||
def _merge_conv_vars(self, child_pool: VariablePool):
|
||||
self.variable_pool.variables["conv"].update(child_pool.variables["conv"])
|
||||
|
||||
async def run_task(self, item, idx):
|
||||
"""
|
||||
Execute a single iteration asynchronously.
|
||||
Each task builds its own subgraph so the variable pool closure is independent.
|
||||
|
||||
Args:
|
||||
item: The input element for this iteration.
|
||||
idx: The index of this iteration.
|
||||
Returns:
|
||||
Tuple of (idx, output, result, child_pool, stopped)
|
||||
"""
|
||||
graph, child_pool, start_id = self._build_child_graph()
|
||||
checkpoint = RunnableConfig(configurable={"thread_id": uuid.uuid4()})
|
||||
init_state = await self._init_iteration_state(item, idx, child_pool, start_id)
|
||||
|
||||
if self.stream:
|
||||
async for event in self.graph.astream(
|
||||
await self._init_iteration_state(item, idx),
|
||||
async for event in graph.astream(
|
||||
init_state,
|
||||
stream_mode=["debug"],
|
||||
config=self.checkpoint
|
||||
config=checkpoint
|
||||
):
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
@@ -117,7 +166,6 @@ class IterationRuntime:
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if node_name and node_name.startswith("nop"):
|
||||
continue
|
||||
if event_type == "task_result":
|
||||
@@ -140,17 +188,13 @@ class IterationRuntime:
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
})
|
||||
result = self.graph.get_state(config=self.checkpoint).values
|
||||
result = graph.get_state(config=checkpoint).values
|
||||
else:
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
return result
|
||||
result = await graph.ainvoke(init_state)
|
||||
|
||||
output = child_pool.get_value(self.output_value)
|
||||
stopped = result["looping"] == 2
|
||||
return idx, output, result, child_pool, stopped
|
||||
|
||||
def _create_iteration_tasks(self, array_obj, idx):
|
||||
"""
|
||||
@@ -196,16 +240,32 @@ class IterationRuntime:
|
||||
tasks = self._create_iteration_tasks(array_obj, idx)
|
||||
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
||||
idx += self.typed_config.parallel_count
|
||||
child_state.extend(await asyncio.gather(*tasks))
|
||||
self.merge_conv_vars()
|
||||
batch = await asyncio.gather(*tasks)
|
||||
# Sort by idx to preserve order, then collect results
|
||||
batch_sorted = sorted(batch, key=lambda x: x[0])
|
||||
for _, output, result, child_pool, stopped in batch_sorted:
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
child_state.append(result)
|
||||
self._merge_conv_vars(child_pool)
|
||||
if stopped:
|
||||
self.looping = False
|
||||
else:
|
||||
# Execute iterations sequentially
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.run_task(item, idx)
|
||||
self.merge_conv_vars()
|
||||
_, output, result, child_pool, stopped = await self.run_task(item, idx)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
self._merge_conv_vars(child_pool)
|
||||
child_state.append(result)
|
||||
if stopped:
|
||||
self.looping = False
|
||||
idx += 1
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return {
|
||||
|
||||
@@ -123,7 +123,7 @@ class CycleGraphNode(BaseNode):
|
||||
|
||||
return cycle_nodes, cycle_edges
|
||||
|
||||
def build_graph(self):
|
||||
def build_graph(self, variable_pool: VariablePool):
|
||||
"""
|
||||
Build and compile the internal subgraph for this cycle node.
|
||||
|
||||
@@ -135,6 +135,7 @@ class CycleGraphNode(BaseNode):
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
|
||||
self.child_variable_pool = VariablePool()
|
||||
self.child_variable_pool.copy(variable_pool)
|
||||
builder = GraphBuilder(
|
||||
{
|
||||
"nodes": self.cycle_nodes,
|
||||
@@ -165,8 +166,8 @@ class CycleGraphNode(BaseNode):
|
||||
Raises:
|
||||
RuntimeError: If the node type is unsupported.
|
||||
"""
|
||||
self.build_graph()
|
||||
if self.node_type == NodeType.LOOP:
|
||||
self.build_graph(variable_pool)
|
||||
return await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=False,
|
||||
@@ -179,20 +180,19 @@ class CycleGraphNode(BaseNode):
|
||||
).run()
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
return await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=False,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool
|
||||
cycle_nodes=self.cycle_nodes,
|
||||
cycle_edges=self.cycle_edges,
|
||||
).run()
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
self.build_graph()
|
||||
if self.node_type == NodeType.LOOP:
|
||||
self.build_graph(variable_pool)
|
||||
yield {
|
||||
"__final__": True,
|
||||
"result": await LoopRuntime(
|
||||
@@ -211,14 +211,13 @@ class CycleGraphNode(BaseNode):
|
||||
yield {
|
||||
"__final__": True,
|
||||
"result": await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=True,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool
|
||||
cycle_nodes=self.cycle_nodes,
|
||||
cycle_edges=self.cycle_edges,
|
||||
).run()
|
||||
}
|
||||
return
|
||||
|
||||
@@ -72,8 +72,9 @@ class HttpContentTypeConfig(BaseModel):
|
||||
@classmethod
|
||||
def validate_data(cls, v, info):
|
||||
content_type = info.data.get("content_type")
|
||||
if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData):
|
||||
raise ValueError("When content_type is 'form-data', data must be of type HttpFormData")
|
||||
if content_type == HttpContentType.FROM_DATA and (
|
||||
not isinstance(v, list) or not all(isinstance(item, HttpFormData) for item in v)):
|
||||
raise ValueError("When content_type is 'form-data', data must be a list of HttpFormData")
|
||||
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
|
||||
raise ValueError("When content_type is JSON, data must be of type str")
|
||||
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):
|
||||
|
||||
@@ -260,17 +260,22 @@ class HttpRequestNode(BaseNode):
|
||||
))
|
||||
case HttpContentType.FROM_DATA:
|
||||
data = {}
|
||||
content["files"] = {}
|
||||
files = []
|
||||
for item in self.typed_config.body.data:
|
||||
key = self._render_template(item.key, variable_pool)
|
||||
if item.type == "text":
|
||||
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
|
||||
variable_pool)
|
||||
data[key] = self._render_template(item.value, variable_pool)
|
||||
elif item.type == "file":
|
||||
content["files"][self._render_template(item.key, variable_pool)] = (
|
||||
uuid.uuid4().hex,
|
||||
await variable_pool.get_instance(item.value).get_content()
|
||||
)
|
||||
file_instance = variable_pool.get_instance(item.value)
|
||||
if isinstance(file_instance, ArrayVariable):
|
||||
for v in file_instance.value:
|
||||
if isinstance(v, FileVariable):
|
||||
files.append((key, (uuid.uuid4().hex, await v.get_content())))
|
||||
elif isinstance(file_instance, FileVariable):
|
||||
files.append((key, (uuid.uuid4().hex, await file_instance.get_content())))
|
||||
content["data"] = data
|
||||
if files:
|
||||
content["files"] = files
|
||||
case HttpContentType.BINARY:
|
||||
content["files"] = []
|
||||
file_instence = variable_pool.get_instance(self.typed_config.body.data)
|
||||
|
||||
@@ -6,6 +6,30 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
|
||||
|
||||
class SubVariableConditionItem(BaseModel):
    """One field-level predicate applied to a file object.

    Instances live in ``SubVariableCondition.conditions``. ``key`` names the
    file field to inspect, ``operator`` selects the comparison, and ``value``
    is either a literal or, when ``input_type`` is ``variable``, a variable
    selector.
    """

    # Field name of the file object that is inspected.
    key: str = Field(..., description="Field name of the file object, e.g. type, size, name")
    # Comparison applied between the field and ``value``.
    operator: ComparisonOperator = Field(..., description="Comparison operator")
    # Literal operand, or a variable selector when input_type=variable.
    value: Any = Field(default=None, description="Value to compare with, or variable selector when input_type=variable")
    # Whether ``value`` is a constant or a variable reference.
    input_type: ValueInputType = Field(default=ValueInputType.CONSTANT, description="constant or variable")

    @field_validator("input_type", mode="before")
    @classmethod
    def lower_input_type(cls, v):
        """Accept ``input_type`` strings case-insensitively.

        Non-string inputs are passed through untouched so the regular field
        validation can handle them.

        Raises:
            ValueError: If the lower-cased string is not a ``ValueInputType``.
        """
        if not isinstance(v, str):
            return v
        normalized = v.lower()
        try:
            return ValueInputType(normalized)
        except ValueError:
            # Report the caller's original spelling, not the normalized one.
            raise ValueError(f"Invalid input_type: {v}")
|
||||
|
||||
|
||||
class SubVariableCondition(BaseModel):
    """Group of per-file sub-conditions for an ``array[file]`` variable.

    Each entry in ``conditions`` tests one field of a file element;
    ``logical_operator`` says how the individual results are combined.
    """

    # Combination mode for the results of ``conditions``; defaults to AND.
    logical_operator: LogicOperator = Field(default=LogicOperator.AND)
    # Field-level predicates; a fresh empty list is created per instance.
    conditions: list[SubVariableConditionItem] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ConditionDetail(BaseModel):
|
||||
operator: ComparisonOperator = Field(
|
||||
...,
|
||||
@@ -14,12 +38,12 @@ class ConditionDetail(BaseModel):
|
||||
|
||||
left: str = Field(
|
||||
...,
|
||||
description="Value to compare against"
|
||||
description="Variable selector, e.g. {{sys.files}}"
|
||||
)
|
||||
|
||||
right: Any = Field(
|
||||
default=None,
|
||||
description="Value to compare with"
|
||||
description="Value to compare with (unused when sub_variable_condition is set)"
|
||||
)
|
||||
|
||||
input_type: ValueInputType = Field(
|
||||
@@ -27,6 +51,11 @@ class ConditionDetail(BaseModel):
|
||||
description="Value input type for comparison"
|
||||
)
|
||||
|
||||
sub_variable_condition: SubVariableCondition | None = Field(
|
||||
default=None,
|
||||
description="Sub-conditions for array[file] fields. When set, operator must be contains/not_contains."
|
||||
)
|
||||
|
||||
@field_validator("input_type", mode="before")
|
||||
@classmethod
|
||||
def lower_input_type(cls, v):
|
||||
@@ -39,16 +68,19 @@ class ConditionDetail(BaseModel):
|
||||
|
||||
|
||||
class ConditionBranchConfig(BaseModel):
|
||||
"""Configuration for a conditional branch"""
|
||||
"""Configuration for a conditional branch.
|
||||
|
||||
logical_operator controls how all expressions are combined (AND/OR).
|
||||
"""
|
||||
|
||||
logical_operator: LogicOperator = Field(
|
||||
default=LogicOperator.AND,
|
||||
description="Logical operator used to combine multiple condition expressions"
|
||||
description="Logical operator used to combine all conditions"
|
||||
)
|
||||
|
||||
expressions: list[ConditionDetail] = Field(
|
||||
...,
|
||||
description="List of condition expressions within this branch"
|
||||
default_factory=list,
|
||||
description="List of conditions within this branch"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance, ArrayFileContainsOperator
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -90,11 +90,9 @@ class IfElseNode(BaseNode):
|
||||
list[str]: A list of Python boolean expression strings,
|
||||
ordered by branch priority.
|
||||
"""
|
||||
branch_index = 0
|
||||
conditions = []
|
||||
|
||||
for case_branch in self.typed_config.cases:
|
||||
branch_index += 1
|
||||
branch_result = []
|
||||
for expression in case_branch.expressions:
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
@@ -103,13 +101,18 @@ class IfElseNode(BaseNode):
|
||||
left_value = self.get_variable(left_string, variable_pool)
|
||||
except KeyError:
|
||||
left_value = None
|
||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||
variable_pool,
|
||||
expression.left,
|
||||
expression.right,
|
||||
expression.input_type
|
||||
)
|
||||
|
||||
if expression.sub_variable_condition is not None and isinstance(left_value, list):
|
||||
evaluator = ArrayFileContainsOperator(left_value, expression.sub_variable_condition, variable_pool)
|
||||
else:
|
||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||
variable_pool,
|
||||
expression.left,
|
||||
expression.right,
|
||||
expression.input_type
|
||||
)
|
||||
branch_result.append(self._evaluate(expression.operator, evaluator))
|
||||
|
||||
if case_branch.logical_operator == LogicOperator.AND:
|
||||
conditions.append(all(branch_result))
|
||||
else:
|
||||
|
||||
@@ -116,6 +116,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
description="Top-p 采样参数"
|
||||
)
|
||||
|
||||
json_output: bool = Field(
|
||||
default=False,
|
||||
description="是否以 JSON 格式输出"
|
||||
)
|
||||
|
||||
frequency_penalty: float | None = Field(
|
||||
default=None,
|
||||
ge=-2.0,
|
||||
|
||||
@@ -22,6 +22,7 @@ from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -126,7 +127,11 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
extra_params = {"streaming": stream} if stream else {}
|
||||
extra_params: dict[str, Any] = {"streaming": stream} if stream else {}
|
||||
if self.typed_config.temperature is not None:
|
||||
extra_params["temperature"] = self.typed_config.temperature
|
||||
if self.typed_config.max_tokens is not None:
|
||||
extra_params["max_tokens"] = self.typed_config.max_tokens
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
@@ -135,7 +140,9 @@ class LLMNode(BaseNode):
|
||||
api_key=model_info.api_key,
|
||||
base_url=model_info.api_base,
|
||||
extra_params=extra_params,
|
||||
is_omni=model_info.is_omni
|
||||
is_omni=model_info.is_omni,
|
||||
capability=model_info.capability,
|
||||
json_output=self.typed_config.json_output,
|
||||
),
|
||||
type=model_info.model_type
|
||||
)
|
||||
@@ -218,6 +225,19 @@ class LLMNode(BaseNode):
|
||||
rendered = self._render_template(prompt_template, variable_pool)
|
||||
self.messages = [{"role": "user", "content": rendered}]
|
||||
|
||||
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format,在 system prompt 中注入
|
||||
# VOLCANO 模型不支持 response_format,同样需要 system prompt 注入
|
||||
need_json_prompt = self.typed_config.json_output and (
|
||||
(model_info.provider.lower() == ModelProvider.DASHSCOPE and not model_info.is_omni)
|
||||
or model_info.provider.lower() == ModelProvider.VOLCANO
|
||||
)
|
||||
if need_json_prompt:
|
||||
system_msg = next((m for m in self.messages if m["role"] == "system"), None)
|
||||
if system_msg:
|
||||
system_msg["content"] += "\n请以JSON格式输出。"
|
||||
else:
|
||||
self.messages.insert(0, {"role": "system", "content": "请以JSON格式输出。"})
|
||||
|
||||
return llm
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
||||
|
||||
@@ -395,11 +395,73 @@ class NoneObjectComparisonOperator:
|
||||
return lambda *args, **kwargs: False
|
||||
|
||||
|
||||
class ArrayFileContainsOperator:
|
||||
"""Handles contains/not_contains on array[file] with sub_variable_condition."""
|
||||
|
||||
def __init__(self, left_value: list[dict], sub_variable_condition: Any, pool: VariablePool | None = None):
|
||||
self.left_value = left_value
|
||||
self.sub_variable_condition = sub_variable_condition
|
||||
self.pool = pool
|
||||
|
||||
def _resolve_value(self, cond: Any) -> Any:
|
||||
if cond.input_type == ValueInputType.VARIABLE and self.pool is not None:
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
selector = re.sub(pattern, r"\1", str(cond.value)).strip()
|
||||
return self.pool.get_value(selector, default=None, strict=False)
|
||||
return cond.value
|
||||
|
||||
def _match_item(self, file_item: dict) -> bool:
|
||||
results = []
|
||||
for cond in self.sub_variable_condition.conditions:
|
||||
field_val = file_item.get(cond.key)
|
||||
expected = self._resolve_value(cond)
|
||||
result = self._eval_sub(field_val, cond.operator.value, expected)
|
||||
results.append(result)
|
||||
if self.sub_variable_condition.logical_operator.value == "and":
|
||||
return all(results)
|
||||
return any(results)
|
||||
|
||||
@staticmethod
|
||||
def _eval_sub(field_val: Any, op: str, expected: Any) -> bool:
|
||||
if field_val is None:
|
||||
return op == "empty"
|
||||
match op:
|
||||
case "eq": return str(field_val) == str(expected)
|
||||
case "ne": return str(field_val) != str(expected)
|
||||
case "contains": return isinstance(field_val, str) and str(expected) in field_val
|
||||
case "not_contains": return isinstance(field_val, str) and str(expected) not in field_val
|
||||
case "in": return field_val in (expected if isinstance(expected, list) else [expected])
|
||||
case "not_in": return field_val not in (expected if isinstance(expected, list) else [expected])
|
||||
case "gt": return isinstance(field_val, (int, float)) and field_val > float(expected)
|
||||
case "ge": return isinstance(field_val, (int, float)) and field_val >= float(expected)
|
||||
case "lt": return isinstance(field_val, (int, float)) and field_val < float(expected)
|
||||
case "le": return isinstance(field_val, (int, float)) and field_val <= float(expected)
|
||||
case "empty": return field_val in (None, "", 0)
|
||||
case "not_empty": return field_val not in (None, "", 0)
|
||||
case _: return False
|
||||
|
||||
def contains(self) -> bool:
|
||||
return any(self._match_item(f) for f in self.left_value if isinstance(f, dict))
|
||||
|
||||
def not_contains(self) -> bool:
|
||||
return not self.contains()
|
||||
|
||||
def empty(self) -> bool:
|
||||
return not self.left_value
|
||||
|
||||
def not_empty(self) -> bool:
|
||||
return bool(self.left_value)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: False
|
||||
|
||||
|
||||
CompareOperatorInstance = Union[
|
||||
StringComparisonOperator,
|
||||
NumberComparisonOperator,
|
||||
BooleanComparisonOperator,
|
||||
ArrayComparisonOperator,
|
||||
ArrayFileContainsOperator,
|
||||
ObjectComparisonOperator
|
||||
]
|
||||
CompareOperatorType = Type[CompareOperatorInstance]
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.services.tool_service import ToolService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
||||
PURE_VARIABLE_PATTERN = re.compile(r"^\{\{\s*([\w.]+)\s*}}$")
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
@@ -52,13 +53,21 @@ class ToolNode(BaseNode):
|
||||
# 渲染工具参数
|
||||
rendered_parameters = {}
|
||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
||||
try:
|
||||
rendered_value = self._render_template(param_template, variable_pool)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||
if isinstance(param_template, str):
|
||||
pure_match = PURE_VARIABLE_PATTERN.match(param_template)
|
||||
if pure_match:
|
||||
# 纯单变量引用直接取原始值,保留 int/bool/float 等类型
|
||||
rendered_value = self.get_variable(pure_match.group(1), variable_pool, strict=False)
|
||||
if rendered_value is None:
|
||||
rendered_value = self._render_template(param_template, variable_pool)
|
||||
elif TEMPLATE_PATTERN.search(param_template):
|
||||
try:
|
||||
rendered_value = self._render_template(param_template, variable_pool)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||
else:
|
||||
rendered_value = param_template
|
||||
else:
|
||||
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
||||
rendered_value = param_template
|
||||
rendered_parameters[param_name] = rendered_value
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ class FileVariable(BaseVariable):
|
||||
total_bytes = 0
|
||||
chunks = []
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
async with client.stream("GET", self.value.url) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes(8192):
|
||||
|
||||
Reference in New Issue
Block a user