feat(workflow): support nested variable access and DashScope rerank provider

This commit is contained in:
Timebomb2018
2026-04-10 16:21:49 +08:00
parent 807dee8460
commit e5e6699168
10 changed files with 88 additions and 19 deletions

View File

@@ -41,6 +41,7 @@ class BizCode(IntEnum):
FILE_NOT_FOUND = 4006
APP_NOT_FOUND = 4007
RELEASE_NOT_FOUND = 4008
USER_NO_ACCESS = 4009
# 冲突/状态5xxx
DUPLICATE_NAME = 5001
@@ -118,6 +119,7 @@ HTTP_MAPPING = {
BizCode.WORKSPACE_ACCESS_DENIED: 403,
BizCode.NOT_FOUND: 400,
BizCode.USER_NOT_FOUND: 200,
BizCode.USER_NO_ACCESS: 401,
BizCode.WORKSPACE_NOT_FOUND: 400,
BizCode.MODEL_NOT_FOUND: 400,
BizCode.KNOWLEDGE_NOT_FOUND: 400,

View File

@@ -206,10 +206,15 @@ class RedBearModelFactory:
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
return {
"model": config.model_name,
# "base_url": config.base_url,
"jina_api_key": config.api_key,
**config.extra_params
}
elif provider == ModelProvider.DASHSCOPE:
return {
"model": config.model_name,
"dashscope_api_key": config.api_key,
**config.extra_params
}
else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
@@ -265,6 +270,9 @@ def get_provider_rerank_class(provider: str):
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
from langchain_community.document_compressors import JinaRerank
return JinaRerank
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank
return DashScopeRerank
# elif provider == ModelProvider.OLLAMA:
# from langchain_ollama import OllamaEmbeddings
# return OllamaEmbeddings

View File

@@ -36,9 +36,7 @@ class RedBearEmbeddings(Embeddings):
"base_url": config.base_url,
"api_key": config.api_key,
"timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
"max_retries": config.max_retries,
"check_embedding_ctx_length": False,
"encoding_format": "float"
"max_retries": config.max_retries
}
elif provider == ModelProvider.DASHSCOPE:
params = {

View File

@@ -76,5 +76,9 @@ class RedBearRerank(BaseDocumentCompressor):
from langchain_community.document_compressors import JinaRerank
model_instance: JinaRerank = self._model
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank
model_instance: DashScopeRerank = self._model
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
else:
raise ValueError(f"不支持的模型提供商: {provider}")

View File

@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
# ["Hello ", "{{user.name}}", "!"]
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
# Strict variable format: {{ node_id.field_name }}
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)?\s*}}')
class GraphBuilder:

View File

@@ -14,7 +14,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
logger = get_logger(__name__)
SCOPE_PATTERN = re.compile(
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}"
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)?\s*}}"
)

View File

@@ -34,19 +34,22 @@ class LazyVariableDict:
return self._cache[key]
var_struct = self._source.get(key)
if var_struct is None:
raise KeyError(key)
value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
return None
raw = var_struct.instance.get_value()
# literal 模式下 dict/list 保留结构,让 Jinja2 能继续访问子字段(如 .type
value = raw if (not self._literal or isinstance(raw, (dict, list))) else var_struct.instance.to_literal()
self._cache[key] = value
return value
def get(self, key, default=None):
try:
return self._resolve(key)
except KeyError:
return default
value = self._resolve(key)
return default if value is None else value
def __getitem__(self, key):
return self._resolve(key)
value = self._resolve(key)
if value is None:
raise KeyError(key)
return value
def __getattr__(self, key):
if key.startswith('_'):
@@ -164,7 +167,7 @@ class VariablePool:
def transform_selector(selector):
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if len(selector) != 2:
if len(selector) not in (2, 3):
raise ValueError(f"Selector not valid - {selector}")
return selector
@@ -196,6 +199,16 @@ class VariablePool:
return None
return var_instance
@staticmethod
def _extract_field(struct: "VariableStruct", field: str | None) -> Any:
"""If field is given, drill into a dict/object variable's value."""
if field is None:
return struct.instance.get_value()
value = struct.instance.get_value()
if not isinstance(value, dict):
raise KeyError(f"Variable is not an object, cannot access field '{field}'")
return value.get(field)
def get_instance(
self,
selector: str,
@@ -250,12 +263,14 @@ class VariablePool:
Raises:
KeyError: If strict is True and the variable does not exist.
"""
path = self.transform_selector(selector)
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
if len(path) == 3:
return self._extract_field(variable_struct, path[2])
return variable_struct.instance.get_value()
def get_literal(
@@ -282,12 +297,15 @@ class VariablePool:
Raises:
KeyError: If strict is True and the variable does not exist.
"""
path = self.transform_selector(selector)
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
if len(path) == 3:
value = self._extract_field(variable_struct, path[2])
return str(value) if value is not None else ""
return variable_struct.instance.to_literal()
async def set(
@@ -345,7 +363,14 @@ class VariablePool:
Returns:
变量是否存在
"""
return self._get_variable_struct(selector) is not None
path = self.transform_selector(selector)
struct = self._get_variable_struct(selector)
if struct is None:
return False
if len(path) == 3:
value = struct.instance.get_value()
return isinstance(value, dict) and path[2] in value
return True
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
return LazyVariableDict(self.variables.get(namespace, {}), literal)