Merge branch 'develop' into fix/memoryconfig-update
This commit is contained in:
@@ -574,8 +574,12 @@ async def get_file_url(
|
|||||||
# For local storage, generate signed URL with expiration
|
# For local storage, generate signed URL with expiration
|
||||||
url = generate_signed_url(str(file_id), expires)
|
url = generate_signed_url(str(file_id), expires)
|
||||||
else:
|
else:
|
||||||
# For remote storage (OSS/S3), get presigned URL
|
# For remote storage (OSS/S3), get presigned URL with forced download
|
||||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
url = await storage_service.get_file_url(
|
||||||
|
file_key,
|
||||||
|
expires=expires,
|
||||||
|
file_name=file_metadata.file_name,
|
||||||
|
)
|
||||||
url = _match_scheme(request, url)
|
url = _match_scheme(request, url)
|
||||||
|
|
||||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||||
@@ -786,7 +790,7 @@ async def permanent_download_file(
|
|||||||
# For remote storage, redirect to presigned URL with long expiration
|
# For remote storage, redirect to presigned URL with long expiration
|
||||||
try:
|
try:
|
||||||
# Use a very long expiration (7 days max for most cloud providers)
|
# Use a very long expiration (7 days max for most cloud providers)
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
|
||||||
presigned_url = _match_scheme(request, presigned_url)
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -91,9 +91,11 @@ async def get_mcp_servers(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
cookies = api.get_cookies(token)
|
cookies = api.get_cookies(token)
|
||||||
|
headers=api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {token}'
|
||||||
r = api.session.put(
|
r = api.session.put(
|
||||||
url=api.mcp_base_url,
|
url=api.mcp_base_url,
|
||||||
headers=api.builder_headers(api.headers),
|
headers=headers,
|
||||||
json=body,
|
json=body,
|
||||||
cookies=cookies)
|
cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
@@ -173,6 +175,7 @@ async def get_operational_mcp_servers(
|
|||||||
|
|
||||||
url = f'{api.mcp_base_url}/operational'
|
url = f'{api.mcp_base_url}/operational'
|
||||||
headers = api.builder_headers(api.headers)
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {token}'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
||||||
@@ -260,7 +263,9 @@ async def create_mcp_market_config(
|
|||||||
api.login(create_data.token)
|
api.login(create_data.token)
|
||||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||||
cookies = api.get_cookies(create_data.token)
|
cookies = api.get_cookies(create_data.token)
|
||||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {create_data.token}'
|
||||||
|
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||||
@@ -290,9 +295,11 @@ async def create_mcp_market_config(
|
|||||||
'search': ""
|
'search': ""
|
||||||
}
|
}
|
||||||
cookies = api.get_cookies(token)
|
cookies = api.get_cookies(token)
|
||||||
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {token}'
|
||||||
r = api.session.put(
|
r = api.session.put(
|
||||||
url=api.mcp_base_url,
|
url=api.mcp_base_url,
|
||||||
headers=api.builder_headers(api.headers),
|
headers=headers,
|
||||||
json=body,
|
json=body,
|
||||||
cookies=cookies)
|
cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
@@ -393,7 +400,9 @@ async def update_mcp_market_config(
|
|||||||
api.login(update_data.token)
|
api.login(update_data.token)
|
||||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||||
cookies = api.get_cookies(update_data.token)
|
cookies = api.get_cookies(update_data.token)
|
||||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {update_data.token}'
|
||||||
|
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||||
|
|||||||
@@ -669,6 +669,7 @@ async def config_query(
|
|||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": release.config.get("variables"),
|
"variables": release.config.get("variables"),
|
||||||
|
"memory": release.config.get("memory", {}).get("enabled"),
|
||||||
"features": release.config.get("features")
|
"features": release.config.get("features")
|
||||||
}
|
}
|
||||||
elif release.app.type == AppType.MULTI_AGENT:
|
elif release.app.type == AppType.MULTI_AGENT:
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
|||||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||||
elif int(is_end_user_id) == int(scope):
|
elif int(is_end_user_id) == int(scope):
|
||||||
logger.info('写入长期记忆NEO4J')
|
logger.info('写入长期记忆NEO4J')
|
||||||
formatted_messages = (redis_messages)
|
formatted_messages = redis_messages
|
||||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||||
if hasattr(memory_config, 'config_id'):
|
if hasattr(memory_config, 'config_id'):
|
||||||
config_id = memory_config.config_id
|
config_id = memory_config.config_id
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
OpenAI Embedder 客户端实现
|
OpenAI Embedder 客户端实现
|
||||||
|
|
||||||
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
||||||
|
自动支持火山引擎的多模态 Embedding。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
|
|||||||
)
|
)
|
||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
from app.core.models.embedding import RedBearEmbeddings
|
from app.core.models.embedding import RedBearEmbeddings
|
||||||
|
from app.models.models_model import ModelProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
- 批量文本嵌入
|
- 批量文本嵌入
|
||||||
- 自动重试机制
|
- 自动重试机制
|
||||||
- 错误处理
|
- 错误处理
|
||||||
|
- 火山引擎多模态 Embedding(自动识别)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_config: RedBearModelConfig):
|
def __init__(self, model_config: RedBearModelConfig):
|
||||||
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
"""
|
"""
|
||||||
super().__init__(model_config)
|
super().__init__(model_config)
|
||||||
|
|
||||||
# 初始化 RedBearEmbeddings 模型
|
# 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||||
self.model = RedBearEmbeddings(
|
self.model = RedBearEmbeddings(
|
||||||
RedBearModelConfig(
|
RedBearModelConfig(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.is_multimodal = self.model.is_multimodal_supported()
|
||||||
|
|
||||||
logger.info("OpenAI Embedder 客户端初始化完成")
|
logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
|
||||||
|
|
||||||
async def response(
|
async def response(
|
||||||
self,
|
self,
|
||||||
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# 生成嵌入向量
|
# 生成嵌入向量
|
||||||
embeddings = await self.model.aembed_documents(texts)
|
if self.is_multimodal:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
embeddings = await self.model.aembed_multimodal(
|
||||||
|
[{"type": "text", "text": text} for text in texts]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 普通 Embedding
|
||||||
|
embeddings = await self.model.aembed_documents(texts)
|
||||||
|
|
||||||
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|||||||
@@ -1099,7 +1099,6 @@ class ExtractionOrchestrator:
|
|||||||
metadata=chunk.metadata,
|
metadata=chunk.metadata,
|
||||||
)
|
)
|
||||||
chunk_nodes.append(chunk_node)
|
chunk_nodes.append(chunk_node)
|
||||||
logger.error(f"chunk file: {chunk.files}")
|
|
||||||
|
|
||||||
for p, file_type in chunk.files:
|
for p, file_type in chunk.files:
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
|
|||||||
from .llm import RedBearLLM
|
from .llm import RedBearLLM
|
||||||
from .embedding import RedBearEmbeddings
|
from .embedding import RedBearEmbeddings
|
||||||
from .rerank import RedBearRerank
|
from .rerank import RedBearRerank
|
||||||
|
from .generation import RedBearImageGenerator, RedBearVideoGenerator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RedBearModelConfig",
|
"RedBearModelConfig",
|
||||||
@@ -9,5 +10,7 @@ __all__ = [
|
|||||||
"RedBearEmbeddings",
|
"RedBearEmbeddings",
|
||||||
"RedBearRerank",
|
"RedBearRerank",
|
||||||
"RedBearModelFactory",
|
"RedBearModelFactory",
|
||||||
"get_provider_llm_class"
|
"get_provider_llm_class",
|
||||||
|
"RedBearImageGenerator",
|
||||||
|
"RedBearVideoGenerator"
|
||||||
]
|
]
|
||||||
@@ -67,7 +67,7 @@ class RedBearModelFactory:
|
|||||||
**config.extra_params
|
**config.extra_params
|
||||||
}
|
}
|
||||||
|
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||||
# 这样可以分别控制连接超时和读取超时
|
# 这样可以分别控制连接超时和读取超时
|
||||||
import httpx
|
import httpx
|
||||||
@@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
|||||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||||
return ChatOpenAI
|
return ChatOpenAI
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
|
||||||
if type == ModelType.LLM:
|
if type == ModelType.LLM:
|
||||||
return OpenAI
|
return OpenAI
|
||||||
elif type == ModelType.CHAT:
|
elif type == ModelType.CHAT:
|
||||||
return ChatOpenAI
|
return ChatOpenAI
|
||||||
|
else:
|
||||||
|
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||||
elif provider == ModelProvider.DASHSCOPE:
|
elif provider == ModelProvider.DASHSCOPE:
|
||||||
return ChatTongyi
|
return ChatTongyi
|
||||||
elif provider == ModelProvider.OLLAMA:
|
elif provider == ModelProvider.OLLAMA:
|
||||||
|
|||||||
@@ -1,23 +1,190 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional, TypeVar, Callable
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory
|
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
|
||||||
|
from app.models.models_model import ModelProvider
|
||||||
|
|
||||||
|
|
||||||
class RedBearEmbeddings(Embeddings):
|
class RedBearEmbeddings(Embeddings):
|
||||||
"""Embedding → 完全符合 LangChain Embeddings"""
|
"""统一的 Embedding 类,自动支持多模态(根据 provider 判断)"""
|
||||||
|
|
||||||
def __init__(self, config: RedBearModelConfig):
|
def __init__(self, config: RedBearModelConfig):
|
||||||
self._model = self._create_model(config)
|
|
||||||
self._config = config
|
self._config = config
|
||||||
|
self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
|
||||||
|
|
||||||
|
if self._is_volcano:
|
||||||
|
# 火山引擎使用 Ark SDK
|
||||||
|
self._client = self._create_volcano_client(config)
|
||||||
|
self._model = None
|
||||||
|
else:
|
||||||
|
# 其他 provider 使用 LangChain
|
||||||
|
self._model = self._create_model(config)
|
||||||
|
self._client = None
|
||||||
|
|
||||||
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
|
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
|
||||||
"""根据配置创建模型"""
|
"""根据配置创建 LangChain 模型"""
|
||||||
embedding_class = get_provider_embedding_class(config.provider)
|
embedding_class = get_provider_embedding_class(config.provider)
|
||||||
model_params = RedBearModelFactory.get_model_params(config)
|
model_params = RedBearModelFactory.get_model_params(config)
|
||||||
return embedding_class(**model_params)
|
return embedding_class(**model_params)
|
||||||
|
|
||||||
|
def _create_volcano_client(self, config: RedBearModelConfig):
|
||||||
|
"""创建火山引擎客户端"""
|
||||||
|
from volcenginesdkarkruntime import Ark
|
||||||
|
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
|
||||||
|
# ==================== LangChain 标准接口 ====================
|
||||||
|
|
||||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||||
return self._model.embed_documents(texts)
|
"""批量文本向量化(LangChain 标准接口)"""
|
||||||
|
if self._is_volcano:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
contents = [{"type": "text", "text": text} for text in texts]
|
||||||
|
response = self._client.multimodal_embeddings.create(
|
||||||
|
model=self._config.model_name,
|
||||||
|
input=contents,
|
||||||
|
encoding_format="float"
|
||||||
|
)
|
||||||
|
return [response.data.embedding]
|
||||||
|
else:
|
||||||
|
# 其他 provider
|
||||||
|
return self._model.embed_documents(texts)
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
return self._model.embed_query(text)
|
"""单个文本向量化(LangChain 标准接口)"""
|
||||||
|
if self._is_volcano:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
result = self.embed_documents([text])
|
||||||
|
return result[0] if result else []
|
||||||
|
else:
|
||||||
|
# 其他 provider
|
||||||
|
return self._model.embed_query(text)
|
||||||
|
|
||||||
|
# ==================== 多模态扩展方法 ====================
|
||||||
|
|
||||||
|
def embed_multimodal(
|
||||||
|
self,
|
||||||
|
contents: List[Dict[str, Any]],
|
||||||
|
**kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""
|
||||||
|
多模态向量化(仅火山引擎支持)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contents: 内容列表,格式:
|
||||||
|
- 文本: {"type": "text", "text": "..."}
|
||||||
|
- 图片: {"type": "image_url", "image_url": {"url": "..."}}
|
||||||
|
- 视频: {"type": "video_url", "video_url": {"url": "..."}}
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
向量列表
|
||||||
|
"""
|
||||||
|
if not self._is_volcano:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self._client.multimodal_embeddings.create(
|
||||||
|
model=self._config.model_name,
|
||||||
|
input=contents,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return [response.data.embedding]
|
||||||
|
|
||||||
|
async def aembed_multimodal(
|
||||||
|
self,
|
||||||
|
contents: List[Dict[str, Any]],
|
||||||
|
**kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""异步多模态向量化"""
|
||||||
|
# 火山引擎 SDK 暂不支持异步,使用同步方法
|
||||||
|
return self.embed_multimodal(contents, **kwargs)
|
||||||
|
|
||||||
|
def embed_text(self, text: str, **kwargs) -> List[float]:
|
||||||
|
"""文本向量化(便捷方法)"""
|
||||||
|
if self._is_volcano:
|
||||||
|
result = self.embed_multimodal(
|
||||||
|
[{"type": "text", "text": text}],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return result[0] if result else []
|
||||||
|
else:
|
||||||
|
return self.embed_query(text)
|
||||||
|
|
||||||
|
def embed_image(self, image_url: str, **kwargs) -> List[float]:
|
||||||
|
"""图片向量化(仅火山引擎支持)"""
|
||||||
|
if not self._is_volcano:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.embed_multimodal(
|
||||||
|
[{"type": "image_url", "image_url": {"url": image_url}}],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return result[0] if result else []
|
||||||
|
|
||||||
|
def embed_video(self, video_url: str, **kwargs) -> List[float]:
|
||||||
|
"""视频向量化(仅火山引擎支持)"""
|
||||||
|
if not self._is_volcano:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.embed_multimodal(
|
||||||
|
[{"type": "video_url", "video_url": {"url": video_url}}],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return result[0] if result else []
|
||||||
|
|
||||||
|
def embed_batch(
|
||||||
|
self,
|
||||||
|
items: List[Union[str, Dict[str, Any]]],
|
||||||
|
**kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""
|
||||||
|
批量向量化(支持混合类型)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items: 可以是字符串列表或内容字典列表
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
向量列表
|
||||||
|
"""
|
||||||
|
# 如果全是字符串,使用标准方法
|
||||||
|
if all(isinstance(item, str) for item in items):
|
||||||
|
return self.embed_documents(items)
|
||||||
|
|
||||||
|
# 如果包含字典,需要多模态支持
|
||||||
|
if not self._is_volcano:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 标准化输入格式
|
||||||
|
contents = []
|
||||||
|
for item in items:
|
||||||
|
if isinstance(item, str):
|
||||||
|
contents.append({"type": "text", "text": item})
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
contents.append(item)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的输入类型: {type(item)}")
|
||||||
|
|
||||||
|
return self.embed_multimodal(contents, **kwargs)
|
||||||
|
|
||||||
|
# ==================== 工具方法 ====================
|
||||||
|
|
||||||
|
def is_multimodal_supported(self) -> bool:
|
||||||
|
"""检查是否支持多模态"""
|
||||||
|
return self._is_volcano
|
||||||
|
|
||||||
|
def get_provider(self) -> str:
|
||||||
|
"""获取 provider"""
|
||||||
|
return self._config.provider
|
||||||
|
|
||||||
|
|
||||||
|
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
|
||||||
|
RedBearMultimodalEmbeddings = RedBearEmbeddings
|
||||||
|
|||||||
344
api/app/core/models/generation.py
Normal file
344
api/app/core/models/generation.py
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
"""
|
||||||
|
图片和视频生成模型封装
|
||||||
|
|
||||||
|
支持的 Provider:
|
||||||
|
- Volcano (火山引擎): 使用 volcenginesdkarkruntime
|
||||||
|
- OpenAI: 使用 openai SDK
|
||||||
|
"""
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from volcenginesdkarkruntime import Ark
|
||||||
|
from volcenginesdkarkruntime.types.images.images import (
|
||||||
|
SequentialImageGenerationOptions,
|
||||||
|
ContentGenerationTool,
|
||||||
|
OptimizePromptOptions
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.models.base import RedBearModelConfig
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.models.models_model import ModelProvider
|
||||||
|
|
||||||
|
|
||||||
|
class RedBearImageGenerator:
|
||||||
|
"""图片生成模型封装"""
|
||||||
|
|
||||||
|
def __init__(self, config: RedBearModelConfig):
|
||||||
|
self._config = config
|
||||||
|
self._client = self._create_client(config)
|
||||||
|
|
||||||
|
def _create_client(self, config: RedBearModelConfig):
|
||||||
|
"""根据 provider 创建客户端"""
|
||||||
|
provider = config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
# elif provider == ModelProvider.OPENAI:
|
||||||
|
# from openai import OpenAI
|
||||||
|
# return OpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的图片生成提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image: Optional[Any] = None,
|
||||||
|
size: Optional[str] = "2K",
|
||||||
|
output_format: str = "png",
|
||||||
|
response_format: str = "url",
|
||||||
|
watermark: bool = False,
|
||||||
|
sequential_image_generation: Optional[str] = None,
|
||||||
|
sequential_image_generation_options: Optional[Dict] = None,
|
||||||
|
tools: Optional[list] = None,
|
||||||
|
optimize_prompt_options: Optional[Dict] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
生成图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 提示词
|
||||||
|
image: 参考图片URL或URL列表(图文生图/多图融合)
|
||||||
|
size: 图片尺寸,支持 "2K", "2048x2048", "1920x1080" 等(至少3686400像素)
|
||||||
|
output_format: 输出格式,如 "png", "jpg"
|
||||||
|
response_format: 返回格式,"url" 或 "b64_json"
|
||||||
|
watermark: 是否添加水印
|
||||||
|
sequential_image_generation: 组图生成模式,"auto" 或 "disabled"
|
||||||
|
sequential_image_generation_options: 组图生成选项,如 {"max_images": 4}
|
||||||
|
tools: 工具列表,如 [{"type": "web_search"}] 用于联网搜索生图
|
||||||
|
optimize_prompt_options: 提示词优化选项,如 {"mode": "fast"}
|
||||||
|
stream: 是否使用流式生成
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成结果
|
||||||
|
"""
|
||||||
|
provider = self._config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
params = {
|
||||||
|
"model": self._config.model_name,
|
||||||
|
"prompt": prompt,
|
||||||
|
"size": size,
|
||||||
|
"output_format": output_format,
|
||||||
|
"response_format": response_format,
|
||||||
|
"watermark": watermark,
|
||||||
|
}
|
||||||
|
|
||||||
|
if image is not None:
|
||||||
|
params["image"] = image
|
||||||
|
|
||||||
|
if sequential_image_generation:
|
||||||
|
params["sequential_image_generation"] = sequential_image_generation
|
||||||
|
if sequential_image_generation_options:
|
||||||
|
params["sequential_image_generation_options"] = SequentialImageGenerationOptions(
|
||||||
|
**sequential_image_generation_options
|
||||||
|
)
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
params["tools"] = [ContentGenerationTool(**tool) if isinstance(tool, dict) else tool for tool in tools]
|
||||||
|
|
||||||
|
if optimize_prompt_options:
|
||||||
|
params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
params["stream"] = True
|
||||||
|
|
||||||
|
params.update(kwargs)
|
||||||
|
response = self._client.images.generate(**params)
|
||||||
|
|
||||||
|
# elif provider == ModelProvider.OPENAI:
|
||||||
|
# response = self._client.images.generate(
|
||||||
|
# model=self._config.model_name,
|
||||||
|
# prompt=prompt,
|
||||||
|
# size=size,
|
||||||
|
# n=n,
|
||||||
|
# **kwargs
|
||||||
|
# )
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||||
|
|
||||||
|
async def agenerate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image: Optional[Any] = None,
|
||||||
|
size: Optional[str] = "2K",
|
||||||
|
output_format: str = "png",
|
||||||
|
response_format: str = "url",
|
||||||
|
watermark: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""异步生成图片"""
|
||||||
|
return self.generate(prompt, image, size, output_format, response_format, watermark, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class RedBearVideoGenerator:
|
||||||
|
"""视频生成模型封装"""
|
||||||
|
|
||||||
|
def __init__(self, config: RedBearModelConfig):
|
||||||
|
self._config = config
|
||||||
|
self._client = self._create_client(config)
|
||||||
|
|
||||||
|
def _create_client(self, config: RedBearModelConfig):
|
||||||
|
"""根据 provider 创建客户端"""
|
||||||
|
provider = config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的视频生成提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image_url: Optional[str] = None,
|
||||||
|
first_frame_url: Optional[str] = None,
|
||||||
|
last_frame_url: Optional[str] = None,
|
||||||
|
reference_images: Optional[list] = None,
|
||||||
|
draft_task_id: Optional[str] = None,
|
||||||
|
duration: Optional[int] = None,
|
||||||
|
frames: Optional[int] = None,
|
||||||
|
ratio: Optional[str] = None,
|
||||||
|
resolution: Optional[str] = None,
|
||||||
|
generate_audio: bool = False,
|
||||||
|
watermark: bool = False,
|
||||||
|
camera_fixed: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_last_frame: bool = False,
|
||||||
|
service_tier: str = "default",
|
||||||
|
execution_expires_after: Optional[int] = None,
|
||||||
|
draft: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
生成视频
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 提示词
|
||||||
|
image_url: 首帧图片URL(图生视频-基于首帧)
|
||||||
|
first_frame_url: 首帧图片URL(图生视频-基于首尾帧)
|
||||||
|
last_frame_url: 尾帧图片URL(图生视频-基于首尾帧)
|
||||||
|
reference_images: 参考图片URL列表(图生视频-基于参考图)
|
||||||
|
draft_task_id: Draft任务ID(基于Draft生成正式视频)
|
||||||
|
duration: 视频时长(秒),与frames二选一
|
||||||
|
frames: 视频帧数,与duration二选一
|
||||||
|
ratio: 视频比例,如 "16:9", "9:16", "adaptive"
|
||||||
|
resolution: 视频分辨率,如 "720p", "1080p"
|
||||||
|
generate_audio: 是否生成音频
|
||||||
|
watermark: 是否添加水印
|
||||||
|
camera_fixed: 是否固定镜头
|
||||||
|
seed: 随机种子
|
||||||
|
return_last_frame: 是否返回最后一帧
|
||||||
|
service_tier: 服务层级,"default" 或 "flex"(离线推理)
|
||||||
|
execution_expires_after: 任务过期时间(秒)
|
||||||
|
draft: 是否生成样片
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成结果(包含任务ID,需要轮询获取结果)
|
||||||
|
"""
|
||||||
|
provider = self._config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
content = [{"type": "text", "text": prompt}]
|
||||||
|
|
||||||
|
if draft_task_id:
|
||||||
|
content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}]
|
||||||
|
else:
|
||||||
|
if image_url:
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||||
|
|
||||||
|
if first_frame_url:
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"})
|
||||||
|
if last_frame_url:
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"})
|
||||||
|
|
||||||
|
if reference_images:
|
||||||
|
for ref_url in reference_images:
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"})
|
||||||
|
|
||||||
|
params = {"model": self._config.model_name, "content": content, "watermark": watermark}
|
||||||
|
|
||||||
|
if duration:
|
||||||
|
params["duration"] = duration
|
||||||
|
if frames:
|
||||||
|
params["frames"] = frames
|
||||||
|
if ratio:
|
||||||
|
params["ratio"] = ratio
|
||||||
|
if resolution:
|
||||||
|
params["resolution"] = resolution
|
||||||
|
if generate_audio:
|
||||||
|
params["generate_audio"] = generate_audio
|
||||||
|
if camera_fixed:
|
||||||
|
params["camera_fixed"] = camera_fixed
|
||||||
|
if seed is not None:
|
||||||
|
params["seed"] = seed
|
||||||
|
if return_last_frame:
|
||||||
|
params["return_last_frame"] = return_last_frame
|
||||||
|
if service_tier != "default":
|
||||||
|
params["service_tier"] = service_tier
|
||||||
|
if execution_expires_after:
|
||||||
|
params["execution_expires_after"] = execution_expires_after
|
||||||
|
if draft:
|
||||||
|
params["draft"] = draft
|
||||||
|
|
||||||
|
params.update(kwargs)
|
||||||
|
response = self._client.content_generation.tasks.create(**params)
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||||
|
|
||||||
|
async def agenerate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image_url: Optional[str] = None,
|
||||||
|
duration: Optional[int] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""异步生成视频"""
|
||||||
|
return self.generate(prompt, image_url=image_url, duration=duration, **kwargs)
|
||||||
|
|
||||||
|
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
查询视频生成任务状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务状态信息
|
||||||
|
"""
|
||||||
|
provider = self._config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
response = self._client.content_generation.tasks.get(task_id=task_id)
|
||||||
|
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aget_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||||
|
"""异步查询任务状态"""
|
||||||
|
return self.get_task_status(task_id)
|
||||||
|
|
||||||
|
def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
查询视频生成任务列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page_size: 每页数量
|
||||||
|
status: 任务状态筛选,如 "succeeded", "failed", "pending"
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务列表
|
||||||
|
"""
|
||||||
|
provider = self._config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
params = {"page_size": page_size}
|
||||||
|
if status:
|
||||||
|
params["status"] = status
|
||||||
|
params.update(kwargs)
|
||||||
|
response = self._client.content_generation.tasks.list(**params)
|
||||||
|
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_task(self, task_id: str) -> None:
|
||||||
|
"""
|
||||||
|
删除或取消视频生成任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务ID
|
||||||
|
"""
|
||||||
|
provider = self._config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
self._client.content_generation.tasks.delete(task_id=task_id)
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
334
api/app/core/models/scripts/volcano_models.yaml
Normal file
334
api/app/core/models/scripts/volcano_models.yaml
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
provider: volcano
|
||||||
|
models:
|
||||||
|
# Doubao-Seed 2.0 系列
|
||||||
|
- name: doubao-seed-2-0-pro-260215
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-2-0-lite-260215
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 面向高频企业场景兼顾性能与成本的均衡型模型,综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-2-0-mini-260215
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 面向低时延、高并发与成本敏感场景,提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解,适合成本和速度优先的轻量级任务。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-2-0-code-preview-260215
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills,可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao-Seed 1.x 系列
|
||||||
|
- name: doubao-seed-1-8-251228
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上,Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面,视觉基础能力显著提升,可低帧率理解超长视频,视频运动理解、复杂空间理解及文档结构化解析能力也有所优化,还原生支持智能上下文管理,用户可配置上下文策略。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-1-6-251015
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: Doubao-Seed-1.6全新多模态深度思考模型,同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-1-6-lite-251015
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 更高性价比,常见任务的最佳选择,支持minimal、low、medium、high 四种reasoning_effort思考深度
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-1-6-flash-250828
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型,TPOT低至10ms; 同时支持文本和视觉理解,文本理解能力超过上一代lite,视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-code-preview-251028
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 面向Agentic编程任务进行了深度优化。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-1-6-vision-250815
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 全新Doubao-Seed-1.6系列视觉深度思考模型,视觉理解能力显著增强,并支持image_process视觉工具
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao 1.5 系列
|
||||||
|
- name: doubao-1-5-vision-pro-32k-250115
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-1-5-pro-32k-250115
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability: []
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-1-5-lite-32k-250115
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability: []
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao-Seedance 视频生成系列
|
||||||
|
- name: doubao-seedance-1-5-pro-251215
|
||||||
|
type: video
|
||||||
|
provider: volcano
|
||||||
|
description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 视频生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedance-1-0-pro-250528
|
||||||
|
type: video
|
||||||
|
provider: volcano
|
||||||
|
description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 视频生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedance-1-0-pro-fast-251015
|
||||||
|
type: video
|
||||||
|
provider: volcano
|
||||||
|
description: 一款价格触底、效能封顶的全面模型,在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 视频生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedance-1-0-lite-i2v-250428
|
||||||
|
type: video
|
||||||
|
provider: volcano
|
||||||
|
description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 视频生成
|
||||||
|
- 图生视频
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedance-1-0-lite-t2v-250428
|
||||||
|
type: video
|
||||||
|
provider: volcano
|
||||||
|
description: 基于文本提示词生成视频
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability: []
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 视频生成
|
||||||
|
- 文生视频
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao-Seedream 图像生成系列
|
||||||
|
- name: doubao-seedream-5-0-260128
|
||||||
|
type: image
|
||||||
|
provider: volcano
|
||||||
|
description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 图像生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedream-4-5-251128
|
||||||
|
type: image
|
||||||
|
provider: volcano
|
||||||
|
description: 字节跳动最新推出的图像多模态模型,整合了文生图、图生图、组图输出等能力,融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 图像生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedream-4-0-250828
|
||||||
|
type: image
|
||||||
|
provider: volcano
|
||||||
|
description: 基于领先架构的SOTA级多模态图像创作模型,其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一,原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 图像生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedream-3-0-t2i-250415
|
||||||
|
type: image
|
||||||
|
provider: volcano
|
||||||
|
description: 一款支持原生高分辨率的中英双语图像生成基础模型,综合能力媲美GPT-4o,处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability: []
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 图像生成
|
||||||
|
- 文生图
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao 翻译系列
|
||||||
|
- name: doubao-seed-translation-250915
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 通用多语言翻译模型,支持30余种语言互译,支持 4K 上下文窗口,输出长度支持最大 3K tokens
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability: []
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 翻译模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao Embedding 系列
|
||||||
|
- name: doubao-embedding-vision-251215
|
||||||
|
type: embedding
|
||||||
|
provider: volcano
|
||||||
|
description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 向量模型
|
||||||
|
- 多模态模型
|
||||||
|
logo: volcano
|
||||||
@@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel):
|
|||||||
class ElasticSearchVector(BaseVector):
|
class ElasticSearchVector(BaseVector):
|
||||||
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
||||||
super().__init__(index_name.lower())
|
super().__init__(index_name.lower())
|
||||||
# self.embeddings = XinferenceEmbeddings(
|
|
||||||
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port
|
# 初始化 Embedding 模型(自动支持火山引擎多模态)
|
||||||
# model_uid="bge-m3" # replace model_uid with the model UID return from launching the model
|
|
||||||
# )
|
|
||||||
# Remove debug printing to avoid leaking sensitive information
|
|
||||||
# print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base)
|
|
||||||
self.embeddings = RedBearEmbeddings(RedBearModelConfig(
|
self.embeddings = RedBearEmbeddings(RedBearModelConfig(
|
||||||
model_name=embedding_config.model_name,
|
model_name=embedding_config.model_name,
|
||||||
provider=embedding_config.provider,
|
provider=embedding_config.provider,
|
||||||
api_key=embedding_config.api_key,
|
api_key=embedding_config.api_key,
|
||||||
base_url=embedding_config.api_base
|
base_url=embedding_config.api_base
|
||||||
))
|
))
|
||||||
# self.reranker = XinferenceRerank(
|
self.is_multimodal_embedding = self.embeddings.is_multimodal_supported()
|
||||||
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"),
|
|
||||||
# model_uid="bge-reranker-large"
|
|
||||||
# )
|
|
||||||
# Remove debug printing to avoid leaking sensitive information
|
|
||||||
# print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base)
|
|
||||||
self.reranker = RedBearRerank(RedBearModelConfig(
|
self.reranker = RedBearRerank(RedBearModelConfig(
|
||||||
model_name=reranker_config.model_name,
|
model_name=reranker_config.model_name,
|
||||||
provider=reranker_config.provider,
|
provider=reranker_config.provider,
|
||||||
@@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
||||||
# 实现 Elasticsearch 保存向量
|
# 实现 Elasticsearch 保存向量
|
||||||
texts = [chunk.page_content for chunk in chunks]
|
texts = [chunk.page_content for chunk in chunks]
|
||||||
embeddings = self.embeddings.embed_documents(list(texts))
|
if self.is_multimodal_embedding:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
embeddings = self.embeddings.embed_batch(texts)
|
||||||
|
else:
|
||||||
|
embeddings = self.embeddings.embed_documents(list(texts))
|
||||||
self.create(chunks, embeddings, **kwargs)
|
self.create(chunks, embeddings, **kwargs)
|
||||||
|
|
||||||
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
||||||
@@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
updated count.
|
updated count.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||||
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
if self.is_multimodal_embedding:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
chunk.vector = self.embeddings.embed_text(chunk.page_content)
|
||||||
|
else:
|
||||||
|
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
"script": {
|
"script": {
|
||||||
@@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
|
|
||||||
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
|
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
|
||||||
"""Search the nearest neighbors to a vector."""
|
"""Search the nearest neighbors to a vector."""
|
||||||
query_vector = self.embeddings.embed_query(query)
|
if self.is_multimodal_embedding:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
query_vector = self.embeddings.embed_text(query)
|
||||||
|
else:
|
||||||
|
query_vector = self.embeddings.embed_query(query)
|
||||||
top_k = kwargs.get("top_k", 1024)
|
top_k = kwargs.get("top_k", 1024)
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.3)
|
score_threshold = float(kwargs.get("score_threshold") or 0.3)
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||||
|
|||||||
@@ -109,17 +109,13 @@ class StorageBackend(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_url(
|
||||||
"""
|
self,
|
||||||
Get an access URL for the file.
|
file_key: str,
|
||||||
|
expires: int = 3600,
|
||||||
Args:
|
file_name: Optional[str] = None
|
||||||
file_key: Unique identifier for the file in the storage system.
|
) -> str:
|
||||||
expires: URL validity period in seconds (default: 1 hour).
|
"""Get an access URL for the file."""
|
||||||
|
|
||||||
Returns:
|
|
||||||
URL for accessing the file.
|
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def get_permanent_url(self, file_key: str) -> Optional[str]:
|
async def get_permanent_url(self, file_key: str) -> Optional[str]:
|
||||||
|
|||||||
@@ -210,7 +210,12 @@ class LocalStorage(StorageBackend):
|
|||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_url(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
expires: int = 3600,
|
||||||
|
file_name: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Get an access URL for the file.
|
Get an access URL for the file.
|
||||||
|
|
||||||
@@ -220,6 +225,7 @@ class LocalStorage(StorageBackend):
|
|||||||
Args:
|
Args:
|
||||||
file_key: Unique identifier for the file in the storage system.
|
file_key: Unique identifier for the file in the storage system.
|
||||||
expires: URL validity period in seconds (not used for local storage).
|
expires: URL validity period in seconds (not used for local storage).
|
||||||
|
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A relative URL path for accessing the file.
|
A relative URL path for accessing the file.
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ Storage Service (OSS) using the oss2 SDK.
|
|||||||
|
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import urllib.parse
|
||||||
from typing import AsyncIterator, Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
import oss2
|
import oss2
|
||||||
@@ -242,24 +243,33 @@ class OSSStorage(StorageBackend):
|
|||||||
logger.error(f"Failed to check file existence in OSS {file_key}: {e}")
|
logger.error(f"Failed to check file existence in OSS {file_key}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_url(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
expires: int = 3600,
|
||||||
|
file_name: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Get a presigned URL for accessing the file.
|
Get a presigned URL for accessing the file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_key: Unique identifier for the file in the storage system.
|
file_key: Unique identifier for the file in the storage system.
|
||||||
expires: URL validity period in seconds (default: 1 hour).
|
expires: URL validity period in seconds (default: 1 hour).
|
||||||
|
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A presigned URL for accessing the file.
|
A presigned URL for accessing the file.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
url = self.bucket.sign_url("GET", file_key, expires)
|
params = {}
|
||||||
|
if file_name:
|
||||||
|
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
|
||||||
|
params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
|
||||||
|
url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None)
|
||||||
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
|
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
|
||||||
return url
|
return url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
||||||
# Return a basic URL format as fallback
|
|
||||||
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
|
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
|
||||||
|
|
||||||
async def get_permanent_url(self, file_key: str) -> str:
|
async def get_permanent_url(self, file_key: str) -> str:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ using the boto3 SDK.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import io
|
import io
|
||||||
|
import urllib.parse
|
||||||
import logging
|
import logging
|
||||||
from typing import AsyncIterator, Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
@@ -352,31 +353,37 @@ class S3Storage(StorageBackend):
|
|||||||
logger.error(f"Failed to check file existence in S3 {file_key}: {e}")
|
logger.error(f"Failed to check file existence in S3 {file_key}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_url(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
expires: int = 3600,
|
||||||
|
file_name: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Get a presigned URL for accessing the file.
|
Get a presigned URL for accessing the file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_key: Unique identifier for the file in the storage system.
|
file_key: Unique identifier for the file in the storage system.
|
||||||
expires: URL validity period in seconds (default: 1 hour).
|
expires: URL validity period in seconds (default: 1 hour).
|
||||||
|
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A presigned URL for accessing the file.
|
A presigned URL for accessing the file.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
params = {"Bucket": self.bucket_name, "Key": file_key}
|
||||||
|
if file_name:
|
||||||
|
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
|
||||||
|
params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
|
||||||
url = self.client.generate_presigned_url(
|
url = self.client.generate_presigned_url(
|
||||||
"get_object",
|
"get_object",
|
||||||
Params={
|
Params=params,
|
||||||
"Bucket": self.bucket_name,
|
|
||||||
"Key": file_key,
|
|
||||||
},
|
|
||||||
ExpiresIn=expires,
|
ExpiresIn=expires,
|
||||||
)
|
)
|
||||||
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
|
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
|
||||||
return url
|
return url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
||||||
# Return a basic URL format as fallback
|
|
||||||
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
||||||
|
|
||||||
async def get_permanent_url(self, file_key: str) -> str:
|
async def get_permanent_url(self, file_key: str) -> str:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.core.workflow.adapters.errors import ExceptionDefineition
|
from app.core.workflow.adapters.errors import ExceptionDefinition
|
||||||
from app.schemas.workflow_schema import (
|
from app.schemas.workflow_schema import (
|
||||||
EdgeDefinition,
|
EdgeDefinition,
|
||||||
NodeDefinition,
|
NodeDefinition,
|
||||||
@@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel):
|
|||||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowImportResult(BaseModel):
|
class WorkflowImportResult(BaseModel):
|
||||||
@@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel):
|
|||||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class BasePlatformAdapter(ABC):
|
class BasePlatformAdapter(ABC):
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ from urllib.parse import quote
|
|||||||
|
|
||||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||||
from app.core.workflow.adapters.errors import (
|
from app.core.workflow.adapters.errors import (
|
||||||
UnsupportVariableType,
|
UnsupportedVariableType,
|
||||||
UnknowModelWarning,
|
UnknownModelWarning,
|
||||||
ExceptionDefineition,
|
ExceptionDefinition,
|
||||||
ExceptionType
|
ExceptionType
|
||||||
)
|
)
|
||||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||||
@@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import (
|
|||||||
HttpFormData,
|
HttpFormData,
|
||||||
HttpTimeOutConfig,
|
HttpTimeOutConfig,
|
||||||
HttpRetryConfig,
|
HttpRetryConfig,
|
||||||
HttpErrorDefaultTamplete,
|
HttpErrorDefaultTemplate,
|
||||||
HttpErrorHandleConfig
|
HttpErrorHandleConfig
|
||||||
)
|
)
|
||||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||||
@@ -108,7 +108,7 @@ class DifyConverter(BaseConverter):
|
|||||||
try:
|
try:
|
||||||
return config.model_validate(value)
|
return config.model_validate(value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
@@ -138,7 +138,7 @@ class DifyConverter(BaseConverter):
|
|||||||
var_selector = mapping.get(var_selector, var_selector)
|
var_selector = mapping.get(var_selector, var_selector)
|
||||||
return var_selector
|
return var_selector
|
||||||
|
|
||||||
def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
|
def _process_list_variable_literal(self, variable_selector: list) -> str | None:
|
||||||
if not self.process_var_selector(".".join(variable_selector)):
|
if not self.process_var_selector(".".join(variable_selector)):
|
||||||
return None
|
return None
|
||||||
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
|
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
|
||||||
@@ -269,7 +269,7 @@ class DifyConverter(BaseConverter):
|
|||||||
var_type = self.variable_type_map(var["type"])
|
var_type = self.variable_type_map(var["type"])
|
||||||
if not var_type:
|
if not var_type:
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
UnsupportVariableType(
|
UnsupportedVariableType(
|
||||||
scope=node["id"],
|
scope=node["id"],
|
||||||
name=var["variable"],
|
name=var["variable"],
|
||||||
var_type=var["type"],
|
var_type=var["type"],
|
||||||
@@ -281,7 +281,7 @@ class DifyConverter(BaseConverter):
|
|||||||
|
|
||||||
if var_type in ["file", "array[file]"]:
|
if var_type in ["file", "array[file]"]:
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
ExceptionDefineition(
|
ExceptionDefinition(
|
||||||
type=ExceptionType.VARIABLE,
|
type=ExceptionType.VARIABLE,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
@@ -311,7 +311,7 @@ class DifyConverter(BaseConverter):
|
|||||||
def convert_question_classifier_node_config(self, node: dict) -> dict:
|
def convert_question_classifier_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
self.warnings.append(
|
self.warnings.append(
|
||||||
UnknowModelWarning(
|
UnknownModelWarning(
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
model_name=node_data["model"].get("name")
|
model_name=node_data["model"].get("name")
|
||||||
@@ -327,7 +327,7 @@ class DifyConverter(BaseConverter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = QuestionClassifierNodeConfig.model_construct(
|
result = QuestionClassifierNodeConfig.model_construct(
|
||||||
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
|
input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")),
|
||||||
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
|
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
|
||||||
categories=categories,
|
categories=categories,
|
||||||
).model_dump()
|
).model_dump()
|
||||||
@@ -337,13 +337,13 @@ class DifyConverter(BaseConverter):
|
|||||||
def convert_llm_node_config(self, node: dict) -> dict:
|
def convert_llm_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
self.warnings.append(
|
self.warnings.append(
|
||||||
UnknowModelWarning(
|
UnknownModelWarning(
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
model_name=node_data["model"].get("name")
|
model_name=node_data["model"].get("name")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
|
context = self._process_list_variable_literal(node_data["context"]["variable_selector"])
|
||||||
memory = MemoryWindowSetting(
|
memory = MemoryWindowSetting(
|
||||||
enable=bool(node_data.get("memory")),
|
enable=bool(node_data.get("memory")),
|
||||||
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
|
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
|
||||||
@@ -367,7 +367,7 @@ class DifyConverter(BaseConverter):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
vision = node_data["vision"]["enabled"]
|
vision = node_data["vision"]["enabled"]
|
||||||
vision_input = self._process_list_variable_litearl(
|
vision_input = self._process_list_variable_literal(
|
||||||
node_data["vision"]["configs"]["variable_selector"]
|
node_data["vision"]["configs"]["variable_selector"]
|
||||||
) if vision else None
|
) if vision else None
|
||||||
result = LLMNodeConfig.model_construct(
|
result = LLMNodeConfig.model_construct(
|
||||||
@@ -433,7 +433,7 @@ class DifyConverter(BaseConverter):
|
|||||||
conditions.append(
|
conditions.append(
|
||||||
LoopConditionDetail.model_construct(
|
LoopConditionDetail.model_construct(
|
||||||
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||||
left=self._process_list_variable_litearl(condition["variable_selector"]),
|
left=self._process_list_variable_literal(condition["variable_selector"]),
|
||||||
right=self.trans_variable_format(
|
right=self.trans_variable_format(
|
||||||
right_value
|
right_value
|
||||||
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
||||||
@@ -453,7 +453,7 @@ class DifyConverter(BaseConverter):
|
|||||||
right_input_type = variable["value_type"]
|
right_input_type = variable["value_type"]
|
||||||
right_value_type = self.variable_type_map(variable["var_type"])
|
right_value_type = self.variable_type_map(variable["var_type"])
|
||||||
if right_input_type == ValueInputType.VARIABLE:
|
if right_input_type == ValueInputType.VARIABLE:
|
||||||
right_value = self._process_list_variable_litearl(variable.get("value", ""))
|
right_value = self._process_list_variable_literal(variable.get("value", ""))
|
||||||
else:
|
else:
|
||||||
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
|
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
|
||||||
loop_variables.append(
|
loop_variables.append(
|
||||||
@@ -475,10 +475,10 @@ class DifyConverter(BaseConverter):
|
|||||||
def convert_iteration_node_config(self, node: dict) -> dict:
|
def convert_iteration_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
result = IterationNodeConfig.model_construct(
|
result = IterationNodeConfig.model_construct(
|
||||||
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
|
input=self._process_list_variable_literal(node_data["iterator_selector"]),
|
||||||
parallel=node_data["is_parallel"],
|
parallel=node_data["is_parallel"],
|
||||||
parallel_count=node_data["parallel_nums"],
|
parallel_count=node_data["parallel_nums"],
|
||||||
output=self._process_list_variable_litearl(node_data["output_selector"]),
|
output=self._process_list_variable_literal(node_data["output_selector"]),
|
||||||
output_type=self.variable_type_map(node_data.get("output_type")),
|
output_type=self.variable_type_map(node_data.get("output_type")),
|
||||||
flatten=node_data["flatten_output"],
|
flatten=node_data["flatten_output"],
|
||||||
).model_dump()
|
).model_dump()
|
||||||
@@ -494,8 +494,8 @@ class DifyConverter(BaseConverter):
|
|||||||
continue
|
continue
|
||||||
assignments.append(
|
assignments.append(
|
||||||
AssignmentItem(
|
AssignmentItem(
|
||||||
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]),
|
variable_selector=self._process_list_variable_literal(assignment["variable_selector"]),
|
||||||
value=self._process_list_variable_litearl(
|
value=self._process_list_variable_literal(
|
||||||
assignment["value"]
|
assignment["value"]
|
||||||
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
|
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
|
||||||
operation=self.convert_assignment_operator(assignment["operation"])
|
operation=self.convert_assignment_operator(assignment["operation"])
|
||||||
@@ -514,7 +514,7 @@ class DifyConverter(BaseConverter):
|
|||||||
input_variables.append(
|
input_variables.append(
|
||||||
InputVariable.model_construct(
|
InputVariable.model_construct(
|
||||||
name=input_variable["variable"],
|
name=input_variable["variable"],
|
||||||
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
|
variable=self._process_list_variable_literal(input_variable["value_selector"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -570,7 +570,7 @@ class DifyConverter(BaseConverter):
|
|||||||
else:
|
else:
|
||||||
if node_data["body"]["data"]:
|
if node_data["body"]["data"]:
|
||||||
body_content = (node_data["body"]["data"][0].get("value") or
|
body_content = (node_data["body"]["data"][0].get("value") or
|
||||||
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
|
self._process_list_variable_literal(node_data["body"]["data"][0].get("file")))
|
||||||
else:
|
else:
|
||||||
body_content = ""
|
body_content = ""
|
||||||
|
|
||||||
@@ -585,7 +585,7 @@ class DifyConverter(BaseConverter):
|
|||||||
self.trans_variable_format(key_value[0])
|
self.trans_variable_format(key_value[0])
|
||||||
] = self.trans_variable_format(key_value[1])
|
] = self.trans_variable_format(key_value[1])
|
||||||
else:
|
else:
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
@@ -603,7 +603,7 @@ class DifyConverter(BaseConverter):
|
|||||||
self.trans_variable_format(key_value[0])
|
self.trans_variable_format(key_value[0])
|
||||||
] = self.trans_variable_format(key_value[1])
|
] = self.trans_variable_format(key_value[1])
|
||||||
else:
|
else:
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
@@ -625,7 +625,7 @@ class DifyConverter(BaseConverter):
|
|||||||
default_header = var["value"]
|
default_header = var["value"]
|
||||||
elif var["key"] == "status_code":
|
elif var["key"] == "status_code":
|
||||||
default_status_code = var["value"]
|
default_status_code = var["value"]
|
||||||
default_value = HttpErrorDefaultTamplete(
|
default_value = HttpErrorDefaultTemplate(
|
||||||
body=default_body,
|
body=default_body,
|
||||||
headers=default_header,
|
headers=default_header,
|
||||||
status_code=default_status_code,
|
status_code=default_status_code,
|
||||||
@@ -668,7 +668,7 @@ class DifyConverter(BaseConverter):
|
|||||||
for variable in node_data["variables"]:
|
for variable in node_data["variables"]:
|
||||||
mapping.append(VariablesMappingConfig.model_construct(
|
mapping.append(VariablesMappingConfig.model_construct(
|
||||||
name=variable["variable"],
|
name=variable["variable"],
|
||||||
value=self._process_list_variable_litearl(variable["value_selector"])
|
value=self._process_list_variable_literal(variable["value_selector"])
|
||||||
))
|
))
|
||||||
result = JinjaRenderNodeConfig.model_construct(
|
result = JinjaRenderNodeConfig.model_construct(
|
||||||
template=node_data["template"],
|
template=node_data["template"],
|
||||||
@@ -679,14 +679,14 @@ class DifyConverter(BaseConverter):
|
|||||||
|
|
||||||
def convert_knowledge_node_config(self, node: dict) -> dict:
|
def convert_knowledge_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
detail=f"Please reconfigure the Knowledge Retrieval node.",
|
detail=f"Please reconfigure the Knowledge Retrieval node.",
|
||||||
))
|
))
|
||||||
result = KnowledgeRetrievalNodeConfig.model_construct(
|
result = KnowledgeRetrievalNodeConfig.model_construct(
|
||||||
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
query=self._process_list_variable_literal(node_data["query_variable_selector"]),
|
||||||
).model_dump()
|
).model_dump()
|
||||||
|
|
||||||
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
|
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
|
||||||
@@ -695,7 +695,7 @@ class DifyConverter(BaseConverter):
|
|||||||
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
|
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
self.warnings.append(
|
self.warnings.append(
|
||||||
UnknowModelWarning(
|
UnknownModelWarning(
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
model_name=node_data["model"].get("name")
|
model_name=node_data["model"].get("name")
|
||||||
@@ -712,7 +712,7 @@ class DifyConverter(BaseConverter):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = ParameterExtractorNodeConfig.model_construct(
|
result = ParameterExtractorNodeConfig.model_construct(
|
||||||
text=self._process_list_variable_litearl(node_data["query"]),
|
text=self._process_list_variable_literal(node_data["query"]),
|
||||||
params=params,
|
params=params,
|
||||||
prompt=node_data.get("instruction")
|
prompt=node_data.get("instruction")
|
||||||
).model_dump()
|
).model_dump()
|
||||||
@@ -727,14 +727,14 @@ class DifyConverter(BaseConverter):
|
|||||||
group_type = {}
|
group_type = {}
|
||||||
if not advanced_settings or not advanced_settings["group_enabled"]:
|
if not advanced_settings or not advanced_settings["group_enabled"]:
|
||||||
group_variables = [
|
group_variables = [
|
||||||
self._process_list_variable_litearl(variable)
|
self._process_list_variable_literal(variable)
|
||||||
for variable in node_data["variables"]
|
for variable in node_data["variables"]
|
||||||
]
|
]
|
||||||
group_type["output"] = node_data["output_type"]
|
group_type["output"] = node_data["output_type"]
|
||||||
else:
|
else:
|
||||||
for group in advanced_settings["groups"]:
|
for group in advanced_settings["groups"]:
|
||||||
group_variables[group["group_name"]] = [
|
group_variables[group["group_name"]] = [
|
||||||
self._process_list_variable_litearl(variable)
|
self._process_list_variable_literal(variable)
|
||||||
for variable in group["variables"]
|
for variable in group["variables"]
|
||||||
]
|
]
|
||||||
group_type[group["group_name"]] = group["output_type"]
|
group_type[group["group_name"]] = group["output_type"]
|
||||||
@@ -751,7 +751,7 @@ class DifyConverter(BaseConverter):
|
|||||||
|
|
||||||
def convert_tool_node_config(self, node: dict) -> dict:
|
def convert_tool_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import (
|
|||||||
WorkflowParserResult
|
WorkflowParserResult
|
||||||
)
|
)
|
||||||
from app.core.workflow.adapters.dify.converter import DifyConverter
|
from app.core.workflow.adapters.dify.converter import DifyConverter
|
||||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.schemas.workflow_schema import (
|
from app.schemas.workflow_schema import (
|
||||||
NodeDefinition,
|
NodeDefinition,
|
||||||
@@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
if not all(field in self.config for field in require_fields):
|
if not all(field in self.config for field in require_fields):
|
||||||
return False
|
return False
|
||||||
if self.config.get("app", {}).get("mode") == "workflow":
|
if self.config.get("app", {}).get("mode") == "workflow":
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.PLATFORM,
|
type=ExceptionType.PLATFORM,
|
||||||
detail="workflow mode is not supported"
|
detail="workflow mode is not supported"
|
||||||
))
|
))
|
||||||
@@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
edge = self._convert_edge(edge)
|
edge = self._convert_edge(edge)
|
||||||
if edge:
|
if edge:
|
||||||
self.edges.append(edge)
|
self.edges.append(edge)
|
||||||
#
|
|
||||||
for variable in self.config.get("workflow").get("conversation_variables"):
|
for variable in self.config.get("workflow").get("conversation_variables"):
|
||||||
con_var = self._convert_variable(variable)
|
con_var = self._convert_variable(variable)
|
||||||
if variable:
|
if variable:
|
||||||
self.conv_variables.append(con_var)
|
self.conv_variables.append(con_var)
|
||||||
#
|
|
||||||
# for variables in config.get("workflow").get("environment_variables"):
|
# for variables in config.get("workflow").get("environment_variables"):
|
||||||
# variable = self._convert_variable(variables)
|
# variable = self._convert_variable(variables)
|
||||||
# conv_variables.append(variable)
|
# conv_variables.append(variable)
|
||||||
@@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
"y": node["position"]["y"] + position["y"]
|
"y": node["position"]["y"] + position["y"]
|
||||||
}
|
}
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
ExceptionDefineition(
|
ExceptionDefinition(
|
||||||
type=ExceptionType.NODE,
|
type=ExceptionType.NODE,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
detail="parent cycle node not found"
|
detail="parent cycle node not found"
|
||||||
@@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
converter = self.get_node_convert(node_type)
|
converter = self.get_node_convert(node_type)
|
||||||
if node_type == NodeType.UNKNOWN:
|
if node_type == NodeType.UNKNOWN:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.NODE,
|
type=ExceptionType.NODE,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node["data"]["title"],
|
node_name=node["data"]["title"],
|
||||||
@@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
))
|
))
|
||||||
return converter(node)
|
return converter(node)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.NODE,
|
type=ExceptionType.NODE,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node["data"]["title"],
|
node_name=node["data"]["title"],
|
||||||
@@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
|
|
||||||
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
|
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
source = edge["source"]
|
source = edge["source"]
|
||||||
target = edge["target"]
|
target = edge["target"]
|
||||||
label = None
|
label = None
|
||||||
@@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
label=label,
|
label=label,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.EDGE,
|
type=ExceptionType.EDGE,
|
||||||
detail=f"convert edge error - {e}",
|
detail=f"convert edge error - {e}",
|
||||||
))
|
))
|
||||||
@@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
description=variable.get("description")
|
description=variable.get("description")
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.VARIABLE,
|
type=ExceptionType.VARIABLE,
|
||||||
name=variable.get("name"),
|
name=variable.get("name"),
|
||||||
detail=f"convert variable error - {e}",
|
detail=f"convert variable error - {e}",
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class ExceptionType(StrEnum):
|
|||||||
UNKNOWN = "unknown"
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
class ExceptionDefineition(BaseModel):
|
class ExceptionDefinition(BaseModel):
|
||||||
type: ExceptionType
|
type: ExceptionType
|
||||||
detail: str
|
detail: str
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel):
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class UnknowModelWarning(ExceptionDefineition):
|
class UnknownModelWarning(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.NODE
|
type: ExceptionType = ExceptionType.NODE
|
||||||
|
|
||||||
def __init__(self, node_id, node_name, model_name):
|
def __init__(self, node_id, node_name, model_name):
|
||||||
@@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UnknowError(ExceptionDefineition):
|
class UnknownError(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.UNKNOWN
|
type: ExceptionType = ExceptionType.UNKNOWN
|
||||||
|
|
||||||
def __init__(self, detail: str, **kwargs):
|
def __init__(self, detail: str, **kwargs):
|
||||||
super().__init__(detail=detail, **kwargs)
|
super().__init__(detail=detail, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class UnsupportPlatform(ExceptionDefineition):
|
class UnsupportedPlatform(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.PLATFORM
|
type: ExceptionType = ExceptionType.PLATFORM
|
||||||
|
|
||||||
def __init__(self, platform: str):
|
def __init__(self, platform: str):
|
||||||
super().__init__(detail=f"Unsupport platform {platform}")
|
super().__init__(detail=f"Unsupported platform {platform}")
|
||||||
|
|
||||||
|
|
||||||
class UnsupportVariableType(ExceptionDefineition):
|
class UnsupportedVariableType(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.VARIABLE
|
type: ExceptionType = ExceptionType.VARIABLE
|
||||||
|
|
||||||
def __init__(self, scope, name, var_type: str, **kwargs):
|
def __init__(self, scope, name, var_type: str, **kwargs):
|
||||||
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs)
|
super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class InvalidConfiguration(ExceptionDefineition):
|
class InvalidConfiguration(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.CONFIG
|
type: ExceptionType = ExceptionType.CONFIG
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(detail="Invalid workflow configuration format")
|
super().__init__(detail="Invalid workflow configuration format")
|
||||||
|
|
||||||
|
|
||||||
class UnsupportNodeType(ExceptionDefineition):
|
class UnsupportedNodeType(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.NODE
|
type: ExceptionType = ExceptionType.NODE
|
||||||
|
|
||||||
def __init__(self, node_id: str, node_type: str):
|
def __init__(self, node_id: str, node_type: str):
|
||||||
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}")
|
super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}")
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import (
|
|||||||
BasePlatformAdapter,
|
BasePlatformAdapter,
|
||||||
WorkflowParserResult
|
WorkflowParserResult
|
||||||
)
|
)
|
||||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
|
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType
|
||||||
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
||||||
@@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
|||||||
try:
|
try:
|
||||||
node_type = self.map_node_type(node["type"])
|
node_type = self.map_node_type(node["type"])
|
||||||
if node_type == NodeType.UNKNOWN:
|
if node_type == NodeType.UNKNOWN:
|
||||||
self.errors.append(UnsupportNodeType(
|
self.errors.append(UnsupportedNodeType(
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_type=node["type"]
|
node_type=node["type"]
|
||||||
))
|
))
|
||||||
@@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
|||||||
|
|
||||||
return NodeDefinition(**node)
|
return NodeDefinition(**node)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.NODE,
|
type=ExceptionType.NODE,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
@@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
|||||||
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
||||||
try:
|
try:
|
||||||
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
type=ExceptionType.EDGE,
|
type=ExceptionType.EDGE,
|
||||||
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
||||||
))
|
))
|
||||||
return None
|
return None
|
||||||
return EdgeDefinition(**edge)
|
return EdgeDefinition(**edge)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.EDGE,
|
type=ExceptionType.EDGE,
|
||||||
detail=f"convert edge error - {e}"
|
detail=f"convert edge error - {e}"
|
||||||
))
|
))
|
||||||
@@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
|||||||
try:
|
try:
|
||||||
return VariableDefinition(**variable)
|
return VariableDefinition(**variable)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
type=ExceptionType.VARIABLE,
|
type=ExceptionType.VARIABLE,
|
||||||
name=variable.get("name"),
|
name=variable.get("name"),
|
||||||
detail=f"convert variable error - {e}"
|
detail=f"convert variable error - {e}"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# -*- coding: UTF-8 -*-
|
# -*- coding: UTF-8 -*-
|
||||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
from app.core.workflow.nodes.configs import (
|
from app.core.workflow.nodes.configs import (
|
||||||
StartNodeConfig,
|
StartNodeConfig,
|
||||||
@@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter):
|
|||||||
try:
|
try:
|
||||||
return config_cls.model_validate(value)
|
return config_cls.model_validate(value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable, Callable
|
||||||
|
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
@@ -41,48 +41,31 @@ class GraphBuilder:
|
|||||||
self,
|
self,
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
subgraph: bool = False,
|
cycle: str = '',
|
||||||
variable_pool: VariablePool | None = None
|
variable_pool: VariablePool | None = None
|
||||||
):
|
):
|
||||||
self.workflow_config = workflow_config
|
self.workflow_config = workflow_config
|
||||||
|
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.subgraph = subgraph
|
self.cycle = cycle
|
||||||
|
|
||||||
self.start_node_id: str | None = None
|
self.start_node_id: str | None = None
|
||||||
|
|
||||||
self.node_map = {node["id"]: node for node in self.nodes}
|
self.node_map: dict[str, dict] = {}
|
||||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||||
self._find_upstream_activation_dep = lru_cache(
|
self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
|
||||||
maxsize=len(self.nodes) * 2
|
|
||||||
)(self._find_upstream_activation_dep)
|
|
||||||
if variable_pool:
|
if variable_pool:
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
else:
|
else:
|
||||||
self.variable_pool = VariablePool()
|
self.variable_pool = VariablePool()
|
||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph: StateGraph | None = None
|
||||||
self.add_nodes()
|
self.nodes: list = []
|
||||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
self.edges: list = []
|
||||||
self.end_nodes = [
|
self.reachable_nodes: set[str] | None = None
|
||||||
node
|
self.end_nodes: list[dict] = []
|
||||||
for node in self.nodes
|
|
||||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
|
||||||
]
|
|
||||||
self.add_edges()
|
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
|
||||||
|
|
||||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||||
self._build_reverse_adj()
|
self._adj: dict[str, list[str]] = defaultdict(list)
|
||||||
self._analyze_end_node_output()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def nodes(self) -> list[dict[str, Any]]:
|
|
||||||
return self.workflow_config.get("nodes", [])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def edges(self) -> list[dict[str, Any]]:
|
|
||||||
return self.workflow_config.get("edges", [])
|
|
||||||
|
|
||||||
def get_node_type(self, node_id: str) -> str:
|
def get_node_type(self, node_id: str) -> str:
|
||||||
"""Retrieve the type of node given its ID.
|
"""Retrieve the type of node given its ID.
|
||||||
@@ -108,13 +91,14 @@ class GraphBuilder:
|
|||||||
result[node[0]].append(node[1])
|
result[node[0]].append(node[1])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _build_reverse_adj(self):
|
def _build_adj(self):
|
||||||
for edge in self.edges:
|
for edge in self.edges:
|
||||||
if edge["source"] not in self.reachable_nodes:
|
if edge["source"] not in self.reachable_nodes:
|
||||||
continue
|
continue
|
||||||
self._reverse_adj[edge.get("target")].append({
|
self._reverse_adj[edge.get("target")].append({
|
||||||
"id": edge["source"], "branch": edge.get("label")
|
"id": edge["source"], "branch": edge.get("label")
|
||||||
})
|
})
|
||||||
|
self._adj[edge.get("source")].append(edge["target"])
|
||||||
|
|
||||||
def _find_upstream_activation_dep(
|
def _find_upstream_activation_dep(
|
||||||
self,
|
self,
|
||||||
@@ -302,22 +286,13 @@ class GraphBuilder:
|
|||||||
"""
|
"""
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
node_type = node.get("type")
|
node_type = node.get("type")
|
||||||
if node_type == NodeType.NOTES:
|
|
||||||
continue
|
|
||||||
node_id = node.get("id")
|
node_id = node.get("id")
|
||||||
cycle_node = node.get("cycle")
|
if node_id not in self.reachable_nodes:
|
||||||
if cycle_node:
|
continue
|
||||||
# Nodes within a loop subgraph are constructed by CycleGraphNode
|
|
||||||
if not self.subgraph:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Record start and end node IDs
|
|
||||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
|
||||||
self.start_node_id = node_id
|
|
||||||
|
|
||||||
# Create node instance (start and end nodes are also created)
|
# Create node instance (start and end nodes are also created)
|
||||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
|
||||||
|
|
||||||
if node_type in BRANCH_NODES:
|
if node_type in BRANCH_NODES:
|
||||||
|
|
||||||
@@ -413,11 +388,12 @@ class GraphBuilder:
|
|||||||
# Add conditional edges
|
# Add conditional edges
|
||||||
for source_node, branches in conditional_edges.items():
|
for source_node, branches in conditional_edges.items():
|
||||||
def make_router(src, branch_list):
|
def make_router(src, branch_list):
|
||||||
"""reate a router function for each source node that routes to a NOP node for later merging."""
|
"""Create a router function for each source node that routes to a NOP node for later merging."""
|
||||||
|
|
||||||
def make_branch_node(node_name, targets):
|
def make_branch_node(node_name, targets):
|
||||||
def node(s):
|
def node(s):
|
||||||
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
# NOTE: NOP NODE USED FOR ROUTING ONLY.
|
||||||
|
# MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS.
|
||||||
return {
|
return {
|
||||||
"activate": {
|
"activate": {
|
||||||
node_id: s["activate"][node_name]
|
node_id: s["activate"][node_name]
|
||||||
@@ -504,14 +480,52 @@ class GraphBuilder:
|
|||||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||||
|
|
||||||
# Connect End nodes to the global END node
|
# Connect End nodes to the global END node
|
||||||
for end_node in self.end_nodes:
|
for node in self.reachable_nodes:
|
||||||
end_node_id = end_node.get("id")
|
if not self._adj[node]:
|
||||||
if end_node_id:
|
self.graph.add_edge(node, END)
|
||||||
self.graph.add_edge(end_node_id, END)
|
|
||||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def build(self) -> CompiledStateGraph:
|
def build(self) -> CompiledStateGraph:
|
||||||
|
nodes = self.workflow_config.get("nodes", [])
|
||||||
|
edges = self.workflow_config.get("edges", [])
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
if (node.get("cycle") or '') == self.cycle:
|
||||||
|
node_type = node.get("type")
|
||||||
|
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
|
self.start_node_id = node.get("id")
|
||||||
|
elif node_type == NodeType.NOTES:
|
||||||
|
continue
|
||||||
|
self.nodes.append(node)
|
||||||
|
self.node_map[node.get("id")] = node
|
||||||
|
|
||||||
|
for edge in edges:
|
||||||
|
source_in = edge.get("source") in self.node_map
|
||||||
|
target_in = edge.get("target") in self.node_map
|
||||||
|
if source_in ^ target_in:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cycle node is connected to external node, "
|
||||||
|
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if source_in and target_in:
|
||||||
|
self.edges.append(edge)
|
||||||
|
|
||||||
|
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||||
|
self.end_nodes = [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||||
|
]
|
||||||
|
self._build_adj()
|
||||||
|
self._find_upstream_activation_dep: Callable = lru_cache(
|
||||||
|
maxsize=len(self.nodes)*2
|
||||||
|
)(self._find_upstream_activation_dep)
|
||||||
|
|
||||||
|
self.graph = StateGraph(WorkflowState)
|
||||||
|
self.add_nodes()
|
||||||
|
self.add_edges()
|
||||||
|
|
||||||
|
self._analyze_end_node_output()
|
||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
self.graph = self.graph.compile(checkpointer=checkpointer)
|
return self.graph.compile(checkpointer=checkpointer)
|
||||||
return self.graph
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# Author: Eternity
|
# Author: Eternity
|
||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/10 13:33
|
# @Time : 2026/2/10 13:33
|
||||||
|
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
|
|
||||||
|
|
||||||
@@ -9,6 +10,7 @@ class WorkflowResultBuilder:
|
|||||||
def build_final_output(
|
def build_final_output(
|
||||||
self,
|
self,
|
||||||
result: dict,
|
result: dict,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
elapsed_time: float,
|
elapsed_time: float,
|
||||||
final_output: str,
|
final_output: str,
|
||||||
@@ -26,6 +28,8 @@ class WorkflowResultBuilder:
|
|||||||
- "node_outputs" (dict): Outputs of executed nodes.
|
- "node_outputs" (dict): Outputs of executed nodes.
|
||||||
- "messages" (list): Conversation messages exchanged during execution.
|
- "messages" (list): Conversation messages exchanged during execution.
|
||||||
- "error" (str, optional): Error message if any node failed.
|
- "error" (str, optional): Error message if any node failed.
|
||||||
|
execution_context (ExecutionContext): The execution context containing metadata like
|
||||||
|
execution ID, workspace ID, and user ID.)
|
||||||
variable_pool (VariablePool): Variable Pool
|
variable_pool (VariablePool): Variable Pool
|
||||||
elapsed_time (float): Total execution time in seconds.
|
elapsed_time (float): Total execution time in seconds.
|
||||||
final_output (Any): The aggregated or final output content of the workflow
|
final_output (Any): The aggregated or final output content of the workflow
|
||||||
@@ -48,18 +52,23 @@ class WorkflowResultBuilder:
|
|||||||
"""
|
"""
|
||||||
node_outputs = result.get("node_outputs", {})
|
node_outputs = result.get("node_outputs", {})
|
||||||
token_usage = self.aggregate_token_usage(node_outputs)
|
token_usage = self.aggregate_token_usage(node_outputs)
|
||||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
conversation_vars = {}
|
||||||
|
sys_vars = {}
|
||||||
|
|
||||||
|
if variable_pool:
|
||||||
|
conversation_vars = variable_pool.get_all_conversation_vars()
|
||||||
|
sys_vars = variable_pool.get_all_system_vars()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "completed" if success else "failed",
|
"status": "completed" if success else "failed",
|
||||||
"output": final_output,
|
"output": final_output,
|
||||||
"variables": {
|
"variables": {
|
||||||
"conv": variable_pool.get_all_conversation_vars(),
|
"conv": conversation_vars,
|
||||||
"sys": variable_pool.get_all_system_vars()
|
"sys": sys_vars
|
||||||
},
|
},
|
||||||
"node_outputs": node_outputs,
|
"node_outputs": node_outputs,
|
||||||
"messages": result.get("messages", []),
|
"messages": result.get("messages", []),
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": execution_context.conversation_id,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"token_usage": token_usage,
|
"token_usage": token_usage,
|
||||||
"error": result.get("error"),
|
"error": result.get("error"),
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class ExecutionContext(BaseModel):
|
|||||||
execution_id: str
|
execution_id: str
|
||||||
workspace_id: str
|
workspace_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
|
conversation_id: str
|
||||||
memory_storage_type: str
|
memory_storage_type: str
|
||||||
user_rag_memory_id: str
|
user_rag_memory_id: str
|
||||||
checkpoint_config: RunnableConfig
|
checkpoint_config: RunnableConfig
|
||||||
@@ -22,6 +23,7 @@ class ExecutionContext(BaseModel):
|
|||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
conversation_id: str,
|
||||||
memory_storage_type: str,
|
memory_storage_type: str,
|
||||||
user_rag_memory_id: str
|
user_rag_memory_id: str
|
||||||
):
|
):
|
||||||
@@ -29,6 +31,7 @@ class ExecutionContext(BaseModel):
|
|||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
memory_storage_type=memory_storage_type,
|
memory_storage_type=memory_storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/9 13:51
|
# @Time : 2026/2/9 13:51
|
||||||
import datetime
|
import datetime
|
||||||
|
import time
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -82,13 +83,15 @@ class WorkflowExecutor:
|
|||||||
CompiledStateGraph: The compiled and ready-to-run state graph.
|
CompiledStateGraph: The compiled and ready-to-run state graph.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
|
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
|
||||||
|
start_time = time.time()
|
||||||
builder = GraphBuilder(
|
builder = GraphBuilder(
|
||||||
self.workflow_config,
|
self.workflow_config,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.graph = builder.build()
|
||||||
self.start_node_id = builder.start_node_id
|
self.start_node_id = builder.start_node_id
|
||||||
self.variable_pool = builder.variable_pool
|
self.variable_pool = builder.variable_pool
|
||||||
self.graph = builder.build()
|
|
||||||
|
|
||||||
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
||||||
self.event_handler = EventStreamHandler(
|
self.event_handler = EventStreamHandler(
|
||||||
@@ -96,7 +99,8 @@ class WorkflowExecutor:
|
|||||||
variable_pool=self.variable_pool,
|
variable_pool=self.variable_pool,
|
||||||
execution_id=self.execution_context.execution_id
|
execution_id=self.execution_context.execution_id
|
||||||
)
|
)
|
||||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}")
|
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, "
|
||||||
|
f"cost: {time.time() - start_time:.4f}s")
|
||||||
|
|
||||||
return self.graph
|
return self.graph
|
||||||
|
|
||||||
@@ -134,94 +138,12 @@ class WorkflowExecutor:
|
|||||||
return event.get("data")
|
return event.get("data")
|
||||||
return self.result_builder.build_final_output(
|
return self.result_builder.build_final_output(
|
||||||
{"error": "Workflow execution did not end as expected"},
|
{"error": "Workflow execution did not end as expected"},
|
||||||
|
self.execution_context,
|
||||||
self.variable_pool,
|
self.variable_pool,
|
||||||
(datetime.datetime.now() - start).total_seconds(),
|
(datetime.datetime.now() - start).total_seconds(),
|
||||||
"",
|
"",
|
||||||
success=False
|
success=False
|
||||||
)
|
)
|
||||||
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
|
||||||
#
|
|
||||||
# start_time = datetime.datetime.now()
|
|
||||||
#
|
|
||||||
# # Execute the workflow
|
|
||||||
# try:
|
|
||||||
# # Build the workflow graph
|
|
||||||
# graph = self.build_graph()
|
|
||||||
#
|
|
||||||
# # Initialize the variable pool with input data
|
|
||||||
# await self.variable_initializer.initialize(
|
|
||||||
# variable_pool=self.variable_pool,
|
|
||||||
# input_data=input_data,
|
|
||||||
# execution_context=self.execution_context
|
|
||||||
# )
|
|
||||||
# initial_state = self.state_manager.create_initial_state(
|
|
||||||
# workflow_config=self.workflow_config,
|
|
||||||
# input_data=input_data,
|
|
||||||
# execution_context=self.execution_context,
|
|
||||||
# start_node_id=self.start_node_id
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
|
||||||
#
|
|
||||||
# # Aggregate output from all End nodes
|
|
||||||
# full_content = ''
|
|
||||||
# for end_id in self.stream_coordinator.end_outputs.keys():
|
|
||||||
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
|
||||||
#
|
|
||||||
# # Append messages for user and assistant
|
|
||||||
# if input_data.get("files"):
|
|
||||||
# result["messages"].extend(
|
|
||||||
# [
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": input_data.get("message", '')
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": input_data.get("files")
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "role": "assistant",
|
|
||||||
# "content": full_content
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# result["messages"].extend(
|
|
||||||
# [
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": input_data.get("message", '')
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "role": "assistant",
|
|
||||||
# "content": full_content
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
# )
|
|
||||||
# # Calculate elapsed time
|
|
||||||
# end_time = datetime.datetime.now()
|
|
||||||
# elapsed_time = (end_time - start_time).total_seconds()
|
|
||||||
#
|
|
||||||
# logger.info(
|
|
||||||
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
|
||||||
#
|
|
||||||
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
|
||||||
#
|
|
||||||
# except Exception as e:
|
|
||||||
# end_time = datetime.datetime.now()
|
|
||||||
# elapsed_time = (end_time - start_time).total_seconds()
|
|
||||||
#
|
|
||||||
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
|
||||||
# exc_info=True)
|
|
||||||
# return {
|
|
||||||
# "status": "failed",
|
|
||||||
# "error": str(e),
|
|
||||||
# "output": None,
|
|
||||||
# "node_outputs": {},
|
|
||||||
# "elapsed_time": elapsed_time,
|
|
||||||
# "token_usage": None
|
|
||||||
# }
|
|
||||||
|
|
||||||
async def execute_stream(
|
async def execute_stream(
|
||||||
self,
|
self,
|
||||||
@@ -255,7 +177,7 @@ class WorkflowExecutor:
|
|||||||
"data": {
|
"data": {
|
||||||
"execution_id": self.execution_context.execution_id,
|
"execution_id": self.execution_context.execution_id,
|
||||||
"workspace_id": self.execution_context.workspace_id,
|
"workspace_id": self.execution_context.workspace_id,
|
||||||
"conversation_id": input_data.get("conversation_id"),
|
"conversation_id": self.execution_context.conversation_id,
|
||||||
"timestamp": int(start_time.timestamp() * 1000)
|
"timestamp": int(start_time.timestamp() * 1000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -376,6 +298,7 @@ class WorkflowExecutor:
|
|||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self.result_builder.build_final_output(
|
"data": self.result_builder.build_final_output(
|
||||||
result,
|
result,
|
||||||
|
self.execution_context,
|
||||||
self.variable_pool,
|
self.variable_pool,
|
||||||
elapsed_time,
|
elapsed_time,
|
||||||
full_content,
|
full_content,
|
||||||
@@ -396,6 +319,7 @@ class WorkflowExecutor:
|
|||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self.result_builder.build_final_output(
|
"data": self.result_builder.build_final_output(
|
||||||
result,
|
result,
|
||||||
|
self.execution_context,
|
||||||
self.variable_pool,
|
self.variable_pool,
|
||||||
elapsed_time,
|
elapsed_time,
|
||||||
full_content,
|
full_content,
|
||||||
@@ -432,6 +356,7 @@ async def execute_workflow(
|
|||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
conversation_id=input_data.get("conversation_id"),
|
||||||
memory_storage_type=memory_storage_type,
|
memory_storage_type=memory_storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
@@ -471,6 +396,7 @@ async def execute_workflow_stream(
|
|||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
memory_storage_type=memory_storage_type,
|
memory_storage_type=memory_storage_type,
|
||||||
|
conversation_id=input_data.get("conversation_id"),
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
executor = WorkflowExecutor(
|
executor = WorkflowExecutor(
|
||||||
|
|||||||
@@ -65,8 +65,6 @@ class AgentNode(BaseNode):
|
|||||||
if not release:
|
if not release:
|
||||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return release, message
|
return release, message
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AssignerNode(BaseNode):
|
class AssignerNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.variable_updater = True
|
self.variable_updater = True
|
||||||
self.typed_config: AssignerNodeConfig | None = None
|
self.typed_config: AssignerNodeConfig | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class BaseNode(ABC):
|
|||||||
All node types should inherit from this class and implement the `execute` method.
|
All node types should inherit from this class and implement the `execute` method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
"""Initialize the node.
|
"""Initialize the node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -41,6 +41,7 @@ class BaseNode(ABC):
|
|||||||
self.node_type = node_config["type"]
|
self.node_type = node_config["type"]
|
||||||
self.cycle = node_config.get("cycle")
|
self.cycle = node_config.get("cycle")
|
||||||
self.node_name = node_config.get("name", self.node_id)
|
self.node_name = node_config.get("name", self.node_id)
|
||||||
|
self.down_stream_nodes = down_stream_nodes
|
||||||
# 使用 or 运算符处理 None 值
|
# 使用 or 运算符处理 None 值
|
||||||
self.config = node_config.get("config") or {}
|
self.config = node_config.get("config") or {}
|
||||||
self.error_handling = node_config.get("error_handling") or {}
|
self.error_handling = node_config.get("error_handling") or {}
|
||||||
@@ -93,18 +94,16 @@ class BaseNode(ABC):
|
|||||||
dict: A dict with a single key 'activate', mapping node IDs to
|
dict: A dict with a single key 'activate', mapping node IDs to
|
||||||
their activation status (True/False).
|
their activation status (True/False).
|
||||||
"""
|
"""
|
||||||
edges = self.workflow_config.get("edges")
|
activate_flag = self.check_activate(state)
|
||||||
under_stream_nodes = [
|
|
||||||
edge.get("target")
|
if self.node_type not in BRANCH_NODES:
|
||||||
for edge in edges
|
activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
|
||||||
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
|
else:
|
||||||
]
|
activate = {}
|
||||||
return {
|
|
||||||
"activate": {
|
activate[self.node_id] = activate_flag
|
||||||
node_id: self.check_activate(state)
|
|
||||||
for node_id in under_stream_nodes
|
return {"activate": activate}
|
||||||
} | {self.node_id: self.check_activate(state)}
|
|
||||||
}
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
@@ -315,8 +314,8 @@ class BaseNode(ABC):
|
|||||||
|
|
||||||
elapsed_time = (time.time() - start_time) * 1000
|
elapsed_time = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
logger.info(f"Node {self.node_id} streaming execution finished, "
|
logger.debug(f"Node {self.node_id} streaming execution finished, "
|
||||||
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
|
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
|
||||||
|
|
||||||
# Extract processed output (call subclass's _extract_output)
|
# Extract processed output (call subclass's _extract_output)
|
||||||
extracted_output = self._extract_output(final_result)
|
extracted_output = self._extract_output(final_result)
|
||||||
@@ -428,8 +427,8 @@ class BaseNode(ABC):
|
|||||||
when an error edge exists. If no error edge exists, this method
|
when an error edge exists. If no error edge exists, this method
|
||||||
raises an exception to stop the workflow.
|
raises an exception to stop the workflow.
|
||||||
"""
|
"""
|
||||||
# Check if the node has an error edge defined
|
# # Check if the node has an error edge defined
|
||||||
error_edge = self._find_error_edge()
|
# error_edge = self._find_error_edge()
|
||||||
|
|
||||||
# Extract input data (for logging or audit purposes)
|
# Extract input data (for logging or audit purposes)
|
||||||
input_data = self._extract_input(state, variable_pool)
|
input_data = self._extract_input(state, variable_pool)
|
||||||
@@ -447,27 +446,26 @@ class BaseNode(ABC):
|
|||||||
"error": error_message
|
"error": error_message
|
||||||
}
|
}
|
||||||
|
|
||||||
if error_edge:
|
# if error_edge:
|
||||||
# If an error edge exists, log a warning and continue to error node
|
# # If an error edge exists, log a warning and continue to error node
|
||||||
logger.warning(
|
# logger.warning(
|
||||||
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
# f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
||||||
)
|
# )
|
||||||
return {
|
# return {
|
||||||
"node_outputs": {
|
# "node_outputs": {
|
||||||
self.node_id: node_output
|
# self.node_id: node_output
|
||||||
},
|
# },
|
||||||
"error": error_message,
|
# "error": error_message,
|
||||||
"error_node": self.node_id
|
# "error_node": self.node_id
|
||||||
}
|
# }
|
||||||
else:
|
# else:
|
||||||
# If no error edge, send the error via stream writer and stop the workflow
|
writer = get_stream_writer()
|
||||||
writer = get_stream_writer()
|
writer({
|
||||||
writer({
|
"type": "node_error",
|
||||||
"type": "node_error",
|
**node_output
|
||||||
**node_output
|
})
|
||||||
})
|
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
||||||
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""Extracts the input data for this node (used for logging or audit).
|
"""Extracts the input data for this node (used for logging or audit).
|
||||||
@@ -644,7 +642,7 @@ class BaseNode(ABC):
|
|||||||
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
|
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
|
||||||
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
|
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
multimodal_service = MultimodalService(db, api_config=api_config)
|
||||||
file_obj = FileInput(
|
file_obj = FileInput(
|
||||||
type=content.type,
|
type=content.type,
|
||||||
url=content.url,
|
url=content.url,
|
||||||
@@ -653,7 +651,7 @@ class BaseNode(ABC):
|
|||||||
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
||||||
)
|
)
|
||||||
file_obj.set_content(content.get_content())
|
file_obj.set_content(content.get_content())
|
||||||
message = await multimodel_service.process_files(
|
message = await multimodal_service.process_files(
|
||||||
[file_obj],
|
[file_obj],
|
||||||
)
|
)
|
||||||
content.set_content(file_obj.get_content())
|
content.set_content(file_obj.get_content())
|
||||||
@@ -661,7 +659,7 @@ class BaseNode(ABC):
|
|||||||
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
|
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
|
||||||
return message
|
return message
|
||||||
return None
|
return None
|
||||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
raise TypeError(f'Unexpected input value type - {type(content)}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_model_output(content) -> str:
|
def process_model_output(content) -> str:
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ console.log(result)
|
|||||||
|
|
||||||
|
|
||||||
class CodeNode(BaseNode):
|
class CodeNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: CodeNodeConfig | None = None
|
self.typed_config: CodeNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -30,17 +30,13 @@ class CycleGraphNode(BaseNode):
|
|||||||
It acts as a container and execution controller for a subgraph.
|
It acts as a container and execution controller for a subgraph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
|
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
|
||||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
|
||||||
self.start_node_id = None # ID of the start node within the cycle
|
self.start_node_id = None # ID of the start node within the cycle
|
||||||
|
|
||||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||||
self.child_variable_pool: VariablePool | None = None
|
self.child_variable_pool: VariablePool | None = None
|
||||||
self.build_graph()
|
|
||||||
self.iteration_flag = True
|
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
||||||
@@ -119,11 +115,11 @@ class CycleGraphNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
remain_edges.append(edge)
|
remain_edges.append(edge)
|
||||||
|
|
||||||
# Update workflow_config by removing cycle nodes and internal edges
|
# # Update workflow_config by removing cycle nodes and internal edges
|
||||||
self.workflow_config["nodes"] = [
|
# self.workflow_config["nodes"] = [
|
||||||
node for node in nodes if node.get("cycle") != self.node_id
|
# node for node in nodes if node.get("cycle") != self.node_id
|
||||||
]
|
# ]
|
||||||
self.workflow_config["edges"] = remain_edges
|
# self.workflow_config["edges"] = remain_edges
|
||||||
|
|
||||||
return cycle_nodes, cycle_edges
|
return cycle_nodes, cycle_edges
|
||||||
|
|
||||||
@@ -137,18 +133,18 @@ class CycleGraphNode(BaseNode):
|
|||||||
3. Compile the graph for runtime execution
|
3. Compile the graph for runtime execution
|
||||||
"""
|
"""
|
||||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
|
||||||
self.child_variable_pool = VariablePool()
|
self.child_variable_pool = VariablePool()
|
||||||
builder = GraphBuilder(
|
builder = GraphBuilder(
|
||||||
{
|
{
|
||||||
"nodes": self.cycle_nodes,
|
"nodes": self.cycle_nodes,
|
||||||
"edges": self.cycle_edges,
|
"edges": self.cycle_edges,
|
||||||
},
|
},
|
||||||
subgraph=True,
|
variable_pool=self.child_variable_pool,
|
||||||
variable_pool=self.child_variable_pool
|
cycle=self.node_id
|
||||||
)
|
)
|
||||||
self.start_node_id = builder.start_node_id
|
|
||||||
self.graph = builder.build()
|
self.graph = builder.build()
|
||||||
|
self.start_node_id = builder.start_node_id
|
||||||
self.child_variable_pool = builder.variable_pool
|
self.child_variable_pool = builder.variable_pool
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If the node type is unsupported.
|
RuntimeError: If the node type is unsupported.
|
||||||
"""
|
"""
|
||||||
|
self.build_graph()
|
||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
return await LoopRuntime(
|
return await LoopRuntime(
|
||||||
start_id=self.start_node_id,
|
start_id=self.start_node_id,
|
||||||
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
raise RuntimeError("Unknown cycle node type")
|
raise RuntimeError("Unknown cycle node type")
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||||
|
self.build_graph()
|
||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
yield {
|
yield {
|
||||||
"__final__": True,
|
"__final__": True,
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
"""End 节点配置"""
|
"""End 节点配置"""
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
|
||||||
|
|
||||||
|
|
||||||
class EndNodeConfig(BaseNodeConfig):
|
class EndNodeConfig(BaseNodeConfig):
|
||||||
|
|||||||
@@ -36,8 +36,6 @@ class EndNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
最终输出字符串
|
最终输出字符串
|
||||||
"""
|
"""
|
||||||
logger.info(f"节点 {self.node_id} (End) 开始执行")
|
|
||||||
|
|
||||||
# 获取配置的输出模板
|
# 获取配置的输出模板
|
||||||
output_template = self.config.get("output")
|
output_template = self.config.get("output")
|
||||||
|
|
||||||
@@ -46,11 +44,4 @@ class EndNode(BaseNode):
|
|||||||
output = self._render_template(output_template, variable_pool, strict=False)
|
output = self._render_template(output_template, variable_pool, strict=False)
|
||||||
else:
|
else:
|
||||||
output = ""
|
output = ""
|
||||||
|
|
||||||
# 统计信息(用于日志)
|
|
||||||
node_outputs = state.get("node_outputs", {})
|
|
||||||
total_nodes = len(node_outputs)
|
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class NodeType(StrEnum):
|
|||||||
NOTES = "notes"
|
NOTES = "notes"
|
||||||
|
|
||||||
|
|
||||||
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
|
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
|
||||||
|
|
||||||
|
|
||||||
class ComparisonOperator(StrEnum):
|
class ComparisonOperator(StrEnum):
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HttpErrorDefaultTamplete(BaseModel):
|
class HttpErrorDefaultTemplate(BaseModel):
|
||||||
body: str = Field(
|
body: str = Field(
|
||||||
default="",
|
default="",
|
||||||
description="Default body returned on HTTP error",
|
description="Default body returned on HTTP error",
|
||||||
@@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel):
|
|||||||
description="Error handling strategy: 'none', 'default', or 'branch'",
|
description="Error handling strategy: 'none', 'default', or 'branch'",
|
||||||
)
|
)
|
||||||
|
|
||||||
default: HttpErrorDefaultTamplete | None = Field(
|
default: HttpErrorDefaultTemplate | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Default response template for error handling",
|
description="Default response template for error handling",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||||
from app.core.workflow.utils.file_processer import mime_to_file_type
|
from app.core.workflow.utils.file_processor import mime_to_file_type
|
||||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||||
from app.schemas import FileType, TransferMethod
|
from app.schemas import FileType, TransferMethod
|
||||||
@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
|
|||||||
or a branch identifier string when error branching is enabled.
|
or a branch identifier string when error branching is enabled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: HttpRequestNodeConfig | None = None
|
self.typed_config: HttpRequestNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class IfElseNode(BaseNode):
|
class IfElseNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: IfElseNodeConfig | None = None
|
self.typed_config: IfElseNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class JinjaRenderNode(BaseNode):
|
class JinjaRenderNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: JinjaRenderNodeConfig | None = None
|
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalNode(BaseNode):
|
class KnowledgeRetrievalNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||||
self.vector_service: ElasticSearchVector | None = None
|
self.vector_service: ElasticSearchVector | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
|
|||||||
- ai/assistant: AI 消息(AIMessage)
|
- ai/assistant: AI 消息(AIMessage)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: LLMNodeConfig | None = None
|
self.typed_config: LLMNodeConfig | None = None
|
||||||
self.messages = []
|
self.messages = []
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ from app.tasks import write_message_task
|
|||||||
|
|
||||||
|
|
||||||
class MemoryReadNode(BaseNode):
|
class MemoryReadNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: MemoryReadNodeConfig | None = None
|
self.typed_config: MemoryReadNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
@@ -45,8 +45,8 @@ class MemoryReadNode(BaseNode):
|
|||||||
|
|
||||||
|
|
||||||
class MemoryWriteNode(BaseNode):
|
class MemoryWriteNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: MemoryWriteNodeConfig | None = None
|
self.typed_config: MemoryWriteNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -104,13 +104,15 @@ class NodeFactory:
|
|||||||
def create_node(
|
def create_node(
|
||||||
cls,
|
cls,
|
||||||
node_config: dict[str, Any],
|
node_config: dict[str, Any],
|
||||||
workflow_config: dict[str, Any]
|
workflow_config: dict[str, Any],
|
||||||
|
down_stream_nodes: list[str]
|
||||||
) -> WorkflowNode | None:
|
) -> WorkflowNode | None:
|
||||||
"""创建节点实例
|
"""创建节点实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_config: 节点配置
|
node_config: 节点配置
|
||||||
workflow_config: 工作流配置
|
workflow_config: 工作流配置
|
||||||
|
down_stream_nodes: 下游节点
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
节点实例或 None(对于不支持的节点类型)
|
节点实例或 None(对于不支持的节点类型)
|
||||||
@@ -127,7 +129,7 @@ class NodeFactory:
|
|||||||
|
|
||||||
# 创建节点实例
|
# 创建节点实例
|
||||||
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
|
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
|
||||||
return node_class(node_config, workflow_config)
|
return node_class(node_config, workflow_config, down_stream_nodes)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_types(cls) -> list[str]:
|
def get_supported_types(cls) -> list[str]:
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ParameterExtractorNode(BaseNode):
|
class ParameterExtractorNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||||
self.response_metadata = {}
|
self.response_metadata = {}
|
||||||
|
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
|
|||||||
class QuestionClassifierNode(BaseNode):
|
class QuestionClassifierNode(BaseNode):
|
||||||
"""问题分类器节点"""
|
"""问题分类器节点"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||||
self.category_to_case_map = {}
|
self.category_to_case_map = {}
|
||||||
self.response_metadata = {}
|
self.response_metadata = {}
|
||||||
|
|||||||
@@ -27,14 +27,8 @@ class StartNode(BaseNode):
|
|||||||
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
"""初始化 Start 节点
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
|
|
||||||
Args:
|
|
||||||
node_config: 节点配置
|
|
||||||
workflow_config: 工作流配置
|
|
||||||
"""
|
|
||||||
super().__init__(node_config, workflow_config)
|
|
||||||
|
|
||||||
# 解析并验证配置
|
# 解析并验证配置
|
||||||
self.typed_config: StartNodeConfig | None = None
|
self.typed_config: StartNodeConfig | None = None
|
||||||
@@ -62,7 +56,6 @@ class StartNode(BaseNode):
|
|||||||
包含系统参数、会话变量和自定义变量的字典
|
包含系统参数、会话变量和自定义变量的字典
|
||||||
"""
|
"""
|
||||||
self.typed_config = StartNodeConfig(**self.config)
|
self.typed_config = StartNodeConfig(**self.config)
|
||||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
|
||||||
|
|
||||||
# 处理自定义变量(传入 pool 避免重复创建)
|
# 处理自定义变量(传入 pool 避免重复创建)
|
||||||
custom_vars = self._process_custom_variables(variable_pool)
|
custom_vars = self._process_custom_variables(variable_pool)
|
||||||
@@ -77,9 +70,9 @@ class StartNode(BaseNode):
|
|||||||
**custom_vars # 自定义变量作为节点输出的一部分
|
**custom_vars # 自定义变量作为节点输出的一部分
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"节点 {self.node_id} (Start) 执行完成,"
|
f"Node {self.node_id} (Start) execution completed, "
|
||||||
f"输出了 {len(custom_vars)} 个自定义变量"
|
f"outputting {len(custom_vars)} custom variables"
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
|||||||
class ToolNode(BaseNode):
|
class ToolNode(BaseNode):
|
||||||
"""工具节点"""
|
"""工具节点"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: ToolNodeConfig | None = None
|
self.typed_config: ToolNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class VariableAggregatorNode(BaseNode):
|
class VariableAggregatorNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: VariableAggregatorNodeConfig | None = None
|
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -153,7 +153,8 @@ class TemplateRenderer:
|
|||||||
|
|
||||||
|
|
||||||
# 全局渲染器实例(严格模式)
|
# 全局渲染器实例(严格模式)
|
||||||
_default_renderer = TemplateRenderer(strict=True)
|
_strict_renderer = TemplateRenderer(strict=True)
|
||||||
|
_lenient_renderer = TemplateRenderer(strict=False)
|
||||||
|
|
||||||
|
|
||||||
def render_template(
|
def render_template(
|
||||||
@@ -184,7 +185,7 @@ def render_template(
|
|||||||
... )
|
... )
|
||||||
'请分析: 这是一段文本'
|
'请分析: 这是一段文本'
|
||||||
"""
|
"""
|
||||||
renderer = TemplateRenderer(strict=strict)
|
renderer = _strict_renderer if strict else _lenient_renderer
|
||||||
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||||
|
|
||||||
|
|
||||||
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
|
|||||||
Returns:
|
Returns:
|
||||||
错误列表
|
错误列表
|
||||||
"""
|
"""
|
||||||
return _default_renderer.validate(template)
|
return _strict_renderer.validate(template)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict, deque
|
||||||
from typing import Any, Union, TYPE_CHECKING
|
from typing import Any, Union, TYPE_CHECKING
|
||||||
|
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
@@ -119,7 +120,6 @@ class WorkflowValidator:
|
|||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
graphs = cls.get_subgraph(workflow_config)
|
graphs = cls.get_subgraph(workflow_config)
|
||||||
logger.info(graphs)
|
|
||||||
for index, graph in enumerate(graphs):
|
for index, graph in enumerate(graphs):
|
||||||
nodes = graph.get("nodes", [])
|
nodes = graph.get("nodes", [])
|
||||||
edges = graph.get("edges", [])
|
edges = graph.get("edges", [])
|
||||||
@@ -183,7 +183,7 @@ class WorkflowValidator:
|
|||||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||||
if has_cycle:
|
if has_cycle:
|
||||||
errors.append(
|
errors.append(
|
||||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 8. 验证变量名
|
# 8. 验证变量名
|
||||||
@@ -204,18 +204,18 @@ class WorkflowValidator:
|
|||||||
Returns:
|
Returns:
|
||||||
可达节点 ID 集合
|
可达节点 ID 集合
|
||||||
"""
|
"""
|
||||||
|
adj = defaultdict(list)
|
||||||
|
for edge in edges:
|
||||||
|
adj[edge["source"]].append(edge["target"])
|
||||||
|
|
||||||
reachable = {start_id}
|
reachable = {start_id}
|
||||||
queue = [start_id]
|
queue = deque([start_id])
|
||||||
|
|
||||||
while queue:
|
while queue:
|
||||||
current = queue.pop(0)
|
current = queue.popleft()
|
||||||
for edge in edges:
|
for target in adj[current]:
|
||||||
if edge.get("source") == current:
|
if target not in reachable:
|
||||||
target = edge.get("target")
|
reachable.add(target)
|
||||||
if target and target not in reachable:
|
queue.append(target)
|
||||||
reachable.add(target)
|
|
||||||
queue.append(target)
|
|
||||||
|
|
||||||
return reachable
|
return reachable
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -229,10 +229,6 @@ class WorkflowValidator:
|
|||||||
Returns:
|
Returns:
|
||||||
(has_cycle, cycle_path): 是否有循环和循环路径
|
(has_cycle, cycle_path): 是否有循环和循环路径
|
||||||
"""
|
"""
|
||||||
# 排除 loop 类型的节点
|
|
||||||
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
|
||||||
|
|
||||||
# 构建邻接表(排除 loop 节点的边和错误边)
|
|
||||||
graph: dict[str, list[str]] = {}
|
graph: dict[str, list[str]] = {}
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
source = edge.get("source")
|
source = edge.get("source")
|
||||||
@@ -243,10 +239,6 @@ class WorkflowValidator:
|
|||||||
if edge_type == "error":
|
if edge_type == "error":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果涉及 loop 节点,跳过
|
|
||||||
if source in loop_nodes or target in loop_nodes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if source and target:
|
if source and target:
|
||||||
if source not in graph:
|
if source not in graph:
|
||||||
graph[source] = []
|
graph[source] = []
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class DictVariable(BaseVariable):
|
|||||||
|
|
||||||
def valid_value(self, value) -> dict:
|
def valid_value(self, value) -> dict:
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def to_literal(self) -> str:
|
def to_literal(self) -> str:
|
||||||
|
|||||||
@@ -27,9 +27,9 @@ class ModelType(StrEnum):
|
|||||||
RERANK = "rerank"
|
RERANK = "rerank"
|
||||||
# TTS = "tts"
|
# TTS = "tts"
|
||||||
# SPEECH2TEXT = "speech2text"
|
# SPEECH2TEXT = "speech2text"
|
||||||
# IMAGE = "image"
|
IMAGE = "image"
|
||||||
# AUDIO = "audio"
|
# AUDIO = "audio"
|
||||||
# VISION = "vision"
|
VIDEO = "video"
|
||||||
|
|
||||||
|
|
||||||
class ModelProvider(StrEnum):
|
class ModelProvider(StrEnum):
|
||||||
@@ -46,6 +46,7 @@ class ModelProvider(StrEnum):
|
|||||||
XINFERENCE = "xinference"
|
XINFERENCE = "xinference"
|
||||||
GPUSTACK = "gpustack"
|
GPUSTACK = "gpustack"
|
||||||
BEDROCK = "bedrock"
|
BEDROCK = "bedrock"
|
||||||
|
VOLCANO = "volcano"
|
||||||
COMPOSITE = "composite"
|
COMPOSITE = "composite"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,17 @@ class Tenants(Base):
|
|||||||
default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言
|
default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言
|
||||||
supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表
|
supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表
|
||||||
|
|
||||||
|
# 租户联系信息
|
||||||
|
contact_name = Column(String(100), nullable=True) # 联系人姓名
|
||||||
|
contact_email = Column(String(255), nullable=True) # 联系人邮箱
|
||||||
|
contact_phone = Column(String(50), nullable=True) # 联系人电话
|
||||||
|
|
||||||
|
# 租户套餐信息
|
||||||
|
plan = Column(String(50), nullable=True) # 套餐类型
|
||||||
|
plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间
|
||||||
|
api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制
|
||||||
|
status = Column(String(50), nullable=True, default='active') # 租户状态
|
||||||
|
|
||||||
# Relationship to users - one tenant has many users
|
# Relationship to users - one tenant has many users
|
||||||
users = relationship("User", back_populates="tenant")
|
users = relationship("User", back_populates="tenant")
|
||||||
|
|
||||||
|
|||||||
@@ -439,7 +439,6 @@ class ModelConfigRepository:
|
|||||||
ModelConfig.is_public
|
ModelConfig.is_public
|
||||||
),
|
),
|
||||||
ModelConfig.provider == provider,
|
ModelConfig.provider == provider,
|
||||||
ModelConfig.is_active,
|
|
||||||
~ModelConfig.is_composite
|
~ModelConfig.is_composite
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
|
|||||||
@@ -325,27 +325,30 @@ class FileStorageService:
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_file_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_file_url(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
expires: int = 3600,
|
||||||
|
file_name: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Get an access URL for a file.
|
Get an access URL for a file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_key: The file key.
|
file_key: The file key.
|
||||||
expires: URL validity period in seconds (default: 1 hour).
|
expires: URL validity period in seconds (default: 1 hour).
|
||||||
|
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
URL for accessing the file.
|
URL for accessing the file.
|
||||||
"""
|
"""
|
||||||
logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s")
|
logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url = await self.storage.get_url(file_key, expires)
|
url = await self.storage.get_url(file_key, expires, file_name=file_name)
|
||||||
logger.debug(f"File URL generated: file_key={file_key}")
|
logger.debug(f"File URL generated: file_key={file_key}")
|
||||||
return url
|
return url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Error getting file URL: file_key={file_key}, error={str(e)}")
|
||||||
f"Error getting file URL: file_key={file_key}, error={str(e)}"
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
162
api/app/services/generation_service.py
Normal file
162
api/app/services/generation_service.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""
|
||||||
|
图片和视频生成服务
|
||||||
|
|
||||||
|
提供统一的生成接口,支持多种 Provider
|
||||||
|
"""
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.models.models_model import ModelType
|
||||||
|
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
|
||||||
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationService:
|
||||||
|
"""生成服务"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
async def generate_image(
|
||||||
|
self,
|
||||||
|
model_config_id: str,
|
||||||
|
prompt: str,
|
||||||
|
size: Optional[str] = "2k",
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
生成图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_config_id: 模型配置ID
|
||||||
|
prompt: 提示词
|
||||||
|
size: 图片尺寸
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成结果
|
||||||
|
"""
|
||||||
|
# 获取模型配置
|
||||||
|
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
|
||||||
|
if not model_config:
|
||||||
|
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
if model_config.type != ModelType.IMAGE:
|
||||||
|
raise BusinessException(
|
||||||
|
f"模型类型错误,期望 {ModelType.IMAGE},实际 {model_config.type}",
|
||||||
|
code=BizCode.INVALID_PARAMETER
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取 API Key
|
||||||
|
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
|
||||||
|
if not api_key_info:
|
||||||
|
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
# 创建配置
|
||||||
|
config = RedBearModelConfig(
|
||||||
|
model_name=api_key_info.model_name,
|
||||||
|
provider=api_key_info.provider,
|
||||||
|
api_key=api_key_info.api_key,
|
||||||
|
base_url=api_key_info.api_base,
|
||||||
|
extra_params=api_key_info.config or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 生成图片
|
||||||
|
generator = RedBearImageGenerator(config)
|
||||||
|
result = await generator.agenerate(prompt, size, **kwargs)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def generate_video(
|
||||||
|
self,
|
||||||
|
model_config_id: str,
|
||||||
|
prompt: str,
|
||||||
|
duration: Optional[int] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
生成视频
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_config_id: 模型配置ID
|
||||||
|
prompt: 提示词
|
||||||
|
duration: 视频时长(秒)
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成结果(包含任务ID)
|
||||||
|
"""
|
||||||
|
# 获取模型配置
|
||||||
|
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
|
||||||
|
if not model_config:
|
||||||
|
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
if model_config.type != ModelType.VIDEO:
|
||||||
|
raise BusinessException(
|
||||||
|
f"模型类型错误,期望 {ModelType.VIDEO},实际 {model_config.type}",
|
||||||
|
code=BizCode.INVALID_PARAMETER
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取 API Key
|
||||||
|
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
|
||||||
|
if not api_key_info:
|
||||||
|
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
# 创建配置
|
||||||
|
config = RedBearModelConfig(
|
||||||
|
model_name=api_key_info.model_name,
|
||||||
|
provider=api_key_info.provider,
|
||||||
|
api_key=api_key_info.api_key,
|
||||||
|
base_url=api_key_info.api_base,
|
||||||
|
extra_params=api_key_info.config or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 生成视频
|
||||||
|
generator = RedBearVideoGenerator(config)
|
||||||
|
result = await generator.agenerate(prompt, duration, **kwargs)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_video_task_status(
|
||||||
|
self,
|
||||||
|
model_config_id: str,
|
||||||
|
task_id: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
查询视频生成任务状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_config_id: 模型配置ID
|
||||||
|
task_id: 任务ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务状态信息
|
||||||
|
"""
|
||||||
|
# 获取模型配置
|
||||||
|
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
|
||||||
|
if not model_config:
|
||||||
|
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
# 获取 API Key
|
||||||
|
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
|
||||||
|
if not api_key_info:
|
||||||
|
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
# 创建配置
|
||||||
|
config = RedBearModelConfig(
|
||||||
|
model_name=api_key_info.model_name,
|
||||||
|
provider=api_key_info.provider,
|
||||||
|
api_key=api_key_info.api_key,
|
||||||
|
base_url=api_key_info.api_base,
|
||||||
|
extra_params=api_key_info.config or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 查询任务状态
|
||||||
|
generator = RedBearVideoGenerator(config)
|
||||||
|
result = await generator.aget_task_status(task_id)
|
||||||
|
|
||||||
|
return result
|
||||||
@@ -357,6 +357,7 @@ class MemoryAgentService:
|
|||||||
if file_object is None:
|
if file_object is None:
|
||||||
continue
|
continue
|
||||||
message["file_content"].append((file_object, file["type"]))
|
message["file_content"].append((file_object, file["type"]))
|
||||||
|
logger.info(messages)
|
||||||
|
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
try:
|
try:
|
||||||
@@ -606,7 +607,7 @@ class MemoryAgentService:
|
|||||||
retrieved_content.append({query: statements})
|
retrieved_content.append({query: statements})
|
||||||
|
|
||||||
# 如果 retrieved_content 为空,设置为空字符串
|
# 如果 retrieved_content 为空,设置为空字符串
|
||||||
if retrieved_content == []:
|
if not retrieved_content:
|
||||||
retrieved_content = ''
|
retrieved_content = ''
|
||||||
|
|
||||||
# 只有当回答不是"信息不足"且不是快速检索时才保存
|
# 只有当回答不是"信息不足"且不是快速检索时才保存
|
||||||
|
|||||||
@@ -154,10 +154,17 @@ class ModelConfigService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
elif model_type_lower == "embedding":
|
elif model_type_lower == "embedding":
|
||||||
# Embedding 模型验证(在线程中运行同步方法)
|
# Embedding 模型验证
|
||||||
|
# 统一使用 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||||
embedding = RedBearEmbeddings(model_config)
|
embedding = RedBearEmbeddings(model_config)
|
||||||
test_texts = [test_message, "测试文本"]
|
test_texts = [test_message, "测试文本"]
|
||||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
|
||||||
|
# 火山引擎使用 embed_batch,其他使用 embed_documents
|
||||||
|
if provider.lower() == "volcano":
|
||||||
|
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
|
||||||
|
else:
|
||||||
|
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -194,6 +201,56 @@ class ModelConfigService:
|
|||||||
"error": None
|
"error": None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
elif model_type_lower == "image":
|
||||||
|
# 图片生成模型验证
|
||||||
|
from app.core.models.generation import RedBearImageGenerator
|
||||||
|
|
||||||
|
generator = RedBearImageGenerator(model_config)
|
||||||
|
result = await generator.agenerate(
|
||||||
|
prompt="a cute panda",
|
||||||
|
size="2K"
|
||||||
|
)
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
logger.info(f"成功生成图片,结果: {result}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"valid": True,
|
||||||
|
"message": "图片生成模型配置验证成功",
|
||||||
|
"response": f"成功生成图片,结果: {result}",
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"usage": {
|
||||||
|
"prompt_length": len("a cute panda"),
|
||||||
|
"image_count": 1
|
||||||
|
},
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
|
elif model_type_lower == "video":
|
||||||
|
# 视频生成模型验证
|
||||||
|
from app.core.models.generation import RedBearVideoGenerator
|
||||||
|
|
||||||
|
generator = RedBearVideoGenerator(model_config)
|
||||||
|
result = await generator.agenerate(
|
||||||
|
prompt="a cute panda playing in bamboo forest",
|
||||||
|
duration=5
|
||||||
|
)
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# 视频生成是异步任务,返回任务ID
|
||||||
|
task_id = result.get("task_id") if isinstance(result, dict) else None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"valid": True,
|
||||||
|
"message": "视频生成模型配置验证成功",
|
||||||
|
"response": f"成功创建视频生成任务,任务ID: {task_id}",
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"usage": {
|
||||||
|
"prompt_length": len("a cute panda playing in bamboo forest"),
|
||||||
|
"task_id": task_id
|
||||||
|
},
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"valid": False,
|
"valid": False,
|
||||||
|
|||||||
@@ -294,6 +294,7 @@ PROVIDER_STRATEGIES = {
|
|||||||
"bedrock": BedrockFormatStrategy,
|
"bedrock": BedrockFormatStrategy,
|
||||||
"anthropic": BedrockFormatStrategy,
|
"anthropic": BedrockFormatStrategy,
|
||||||
"openai": OpenAIFormatStrategy,
|
"openai": OpenAIFormatStrategy,
|
||||||
|
"volcano": OpenAIFormatStrategy,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -250,6 +250,20 @@ def deactivate_user(db: Session, user_id_to_deactivate: uuid.UUID, current_user:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 检查是否为租户联系人
|
||||||
|
from app.models.tenant_model import Tenants
|
||||||
|
tenant = db.query(Tenants).filter(Tenants.id == db_user.tenant_id).first()
|
||||||
|
if tenant and tenant.contact_email and tenant.contact_email == db_user.email:
|
||||||
|
business_logger.warning(f"尝试停用租户联系人: {db_user.email}, tenant_id={db_user.tenant_id}")
|
||||||
|
raise BusinessException(
|
||||||
|
"该管理员是租户联系人,请先在租户信息中更换联系邮箱,再禁用此管理员",
|
||||||
|
code=BizCode.FORBIDDEN,
|
||||||
|
context={
|
||||||
|
"user_id": str(user_id_to_deactivate),
|
||||||
|
"tenant_id": str(db_user.tenant_id)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 停用用户
|
# 停用用户
|
||||||
business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})")
|
business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})")
|
||||||
db_user.is_active = False
|
db_user.is_active = False
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.aioRedis import aio_redis_set, aio_redis_get
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult
|
from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult
|
||||||
from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration
|
from app.core.workflow.adapters.errors import UnsupportedPlatform, InvalidConfiguration
|
||||||
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
||||||
from app.schemas import AppCreate
|
from app.schemas import AppCreate
|
||||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||||
@@ -46,7 +46,7 @@ class WorkflowImportService:
|
|||||||
success=False,
|
success=False,
|
||||||
temp_id=None,
|
temp_id=None,
|
||||||
workflow_id=None,
|
workflow_id=None,
|
||||||
errors=[UnsupportPlatform(platform=platform)]
|
errors=[UnsupportedPlatform(platform=platform)]
|
||||||
)
|
)
|
||||||
|
|
||||||
adapter = self.registry.get_adapter(platform, config)
|
adapter = self.registry.get_adapter(platform, config)
|
||||||
|
|||||||
42
api/migrations/versions/1ea8fe97b5b7_202603252115.py
Normal file
42
api/migrations/versions/1ea8fe97b5b7_202603252115.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""202603252115
|
||||||
|
|
||||||
|
Revision ID: 1ea8fe97b5b7
|
||||||
|
Revises: e28bcc212da5
|
||||||
|
Create Date: 2026-03-25 21:14:41.825048
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '1ea8fe97b5b7'
|
||||||
|
down_revision: Union[str, None] = 'e28bcc212da5'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('tenants', sa.Column('contact_name', sa.String(length=100), nullable=True))
|
||||||
|
op.add_column('tenants', sa.Column('contact_email', sa.String(length=255), nullable=True))
|
||||||
|
op.add_column('tenants', sa.Column('contact_phone', sa.String(length=50), nullable=True))
|
||||||
|
op.add_column('tenants', sa.Column('plan', sa.String(length=50), nullable=True))
|
||||||
|
op.add_column('tenants', sa.Column('plan_expired_at', sa.DateTime(), nullable=True))
|
||||||
|
op.add_column('tenants', sa.Column('api_ops_rate_limit', sa.String(length=100), nullable=True))
|
||||||
|
op.add_column('tenants', sa.Column('status', sa.String(length=50), nullable=True))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('tenants', 'status')
|
||||||
|
op.drop_column('tenants', 'api_ops_rate_limit')
|
||||||
|
op.drop_column('tenants', 'plan_expired_at')
|
||||||
|
op.drop_column('tenants', 'plan')
|
||||||
|
op.drop_column('tenants', 'contact_phone')
|
||||||
|
op.drop_column('tenants', 'contact_email')
|
||||||
|
op.drop_column('tenants', 'contact_name')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -147,6 +147,7 @@ dependencies = [
|
|||||||
"modelscope>=1.34.0",
|
"modelscope>=1.34.0",
|
||||||
"python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'",
|
"python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'",
|
||||||
"python-magic-bin>=0.4.14; sys_platform=='win32'",
|
"python-magic-bin>=0.4.14; sys_platform=='win32'",
|
||||||
|
"volcengine-python-sdk[ark]==5.0.19"
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
Reference in New Issue
Block a user