Merge branch 'develop' into feature/tool_yjp
This commit is contained in:
2
api/app/cache/memory/__init__.py
vendored
2
api/app/cache/memory/__init__.py
vendored
@@ -4,7 +4,9 @@ Memory 缓存模块
|
||||
提供记忆系统相关的缓存功能
|
||||
"""
|
||||
from .interest_memory import InterestMemoryCache
|
||||
from .activity_stats_cache import ActivityStatsCache
|
||||
|
||||
__all__ = [
|
||||
"InterestMemoryCache",
|
||||
"ActivityStatsCache",
|
||||
]
|
||||
|
||||
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Recent Activity Stats Cache
|
||||
|
||||
记忆提取活动统计缓存模块
|
||||
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放
|
||||
查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
ACTIVITY_STATS_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class ActivityStatsCache:
|
||||
"""记忆提取活动统计缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:activity_stats"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, workspace_id: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
|
||||
|
||||
@classmethod
|
||||
async def set_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
stats: Dict[str, Any],
|
||||
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
stats: 统计数据,格式:
|
||||
{
|
||||
"chunk_count": int,
|
||||
"statements_count": int,
|
||||
"triplet_entities_count": int,
|
||||
"triplet_relations_count": int,
|
||||
"temporal_count": int,
|
||||
}
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
payload = {
|
||||
"stats": stats,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"workspace_id": workspace_id,
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""获取记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
统计数据字典,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
value = await aio_redis.get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中活动统计缓存: {key}")
|
||||
return payload
|
||||
logger.info(f"活动统计缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
) -> bool:
|
||||
"""删除记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -113,6 +113,7 @@ celery_app.conf.update(
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -149,6 +149,17 @@ async def get_workspace_end_users(
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||
try:
|
||||
from app.celery_app import celery_app as _celery_app
|
||||
_celery_app.send_task(
|
||||
"app.tasks.init_implicit_emotions_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
api_logger.info(f"已触发隐性记忆按需初始化任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发隐性记忆按需初始化任务失败(不影响主流程): {e}")
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
get_memory_configs(),
|
||||
|
||||
@@ -544,10 +544,11 @@ async def clear_hot_memory_tags_cache(
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info("Recent activity stats requested")
|
||||
) -> dict:
|
||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await analytics_recent_activity_stats()
|
||||
result = await analytics_recent_activity_stats(workspace_id=workspace_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
|
||||
@@ -371,6 +371,11 @@ def update_model(
|
||||
|
||||
if model_data.type is not None or model_data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
if model_data.is_active:
|
||||
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
|
||||
if not active_keys:
|
||||
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
|
||||
@@ -97,6 +97,12 @@ async def create_tool(
|
||||
):
|
||||
"""创建工具"""
|
||||
try:
|
||||
# 将 MCP 来源字段合并进 config
|
||||
if request.tool_type == ToolType.MCP:
|
||||
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
|
||||
val = getattr(request, key, None)
|
||||
if val is not None:
|
||||
request.config[key] = val
|
||||
tool_id = service.create_tool(
|
||||
name=request.name,
|
||||
tool_type=request.tool_type,
|
||||
|
||||
@@ -192,8 +192,10 @@ class Settings:
|
||||
# Celery configuration (internal)
|
||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||
# 详见 docs/celery-env-bug-report.md
|
||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "1"))
|
||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "2"))
|
||||
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
|
||||
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
|
||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
||||
|
||||
# SMTP Email Configuration
|
||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||
|
||||
@@ -111,7 +111,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"content_length": len(content),
|
||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||
@@ -221,7 +221,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"questions_count": len(databasets),
|
||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Problem_Extension error details: {error_details}")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.logging_config import get_agent_logger
|
||||
@@ -40,6 +41,15 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
|
||||
for lang in ["zh", "en"]:
|
||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=lang,
|
||||
)
|
||||
if deleted:
|
||||
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||
|
||||
write_result = {
|
||||
"status": "success",
|
||||
"data": structured_messages,
|
||||
|
||||
@@ -82,7 +82,9 @@ async def get_chunked_dialogs(
|
||||
pruning_config = PruningConfig(
|
||||
pruning_switch=memory_config.pruning_enabled,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=memory_config.pruning_threshold
|
||||
pruning_threshold=memory_config.pruning_threshold,
|
||||
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||
ontology_classes=memory_config.ontology_classes,
|
||||
)
|
||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||
|
||||
|
||||
@@ -225,5 +225,24 @@ async def write(
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
# 将提取统计写入 Redis,按 workspace_id 存储
|
||||
try:
|
||||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||
|
||||
stats_to_cache = {
|
||||
"chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
|
||||
"statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
|
||||
"triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
|
||||
"triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
|
||||
"temporal_count": 0,
|
||||
}
|
||||
await ActivityStatsCache.set_activity_stats(
|
||||
workspace_id=str(memory_config.workspace_id),
|
||||
stats=stats_to_cache,
|
||||
)
|
||||
logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
@@ -10,7 +10,7 @@ Classes:
|
||||
TemporalSearchParams: Parameters for temporal search queries
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -55,17 +55,26 @@ class PruningConfig(BaseModel):
|
||||
|
||||
Attributes:
|
||||
pruning_switch: Enable or disable semantic pruning
|
||||
pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound')
|
||||
pruning_scene: Scene name for pruning, either a built-in key
|
||||
('education', 'online_service', 'outbound') or a custom scene_name
|
||||
from ontology_scene table
|
||||
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
||||
scene_id: Optional ontology scene UUID, used to load custom ontology classes
|
||||
ontology_classes: List of class_name strings from ontology_class table,
|
||||
injected into the prompt when pruning_scene is not a built-in scene
|
||||
"""
|
||||
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
||||
pruning_scene: str = Field(
|
||||
"education",
|
||||
description="Scene for pruning: one of 'education', 'online_service', 'outbound'.",
|
||||
description="Scene for pruning: built-in key or custom scene_name from ontology_scene.",
|
||||
)
|
||||
pruning_threshold: float = Field(
|
||||
0.5, ge=0.0, le=0.9,
|
||||
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
||||
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
|
||||
ontology_classes: Optional[List[str]] = Field(
|
||||
None, description="Class names from ontology_class table for custom scenes."
|
||||
)
|
||||
|
||||
|
||||
class TemporalSearchParams(BaseModel):
|
||||
|
||||
@@ -86,19 +86,26 @@ class SemanticPruner:
|
||||
self._detailed_prune_logging = True # 是否启用详细日志
|
||||
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
||||
|
||||
# 加载场景特定配置
|
||||
# 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则)
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
|
||||
self.config.pruning_scene,
|
||||
fallback_to_generic=True
|
||||
)
|
||||
|
||||
# 检查场景是否有专门支持
|
||||
is_supported = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
||||
if is_supported:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用专门配置")
|
||||
# 判断是否为内置专门场景
|
||||
self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
||||
|
||||
# 自定义场景的本体类型列表(用于注入提示词)
|
||||
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
|
||||
|
||||
if self._is_builtin_scene:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 未预定义,使用通用配置(保守策略)")
|
||||
self._log(f"[剪枝-初始化] 支持的场景: {SceneConfigRegistry.get_all_scenes()}")
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入")
|
||||
if self._ontology_classes:
|
||||
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||
|
||||
# Load Jinja2 template
|
||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||
@@ -424,12 +431,16 @@ class SemanticPruner:
|
||||
self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目")
|
||||
|
||||
rendered = self.template.render(
|
||||
pruning_scene=self.config.pruning_scene,
|
||||
pruning_scene=self.config.pruning_scene,
|
||||
is_builtin_scene=self._is_builtin_scene,
|
||||
ontology_classes=self._ontology_classes,
|
||||
dialog_text=dialog_text,
|
||||
language=self.language
|
||||
)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {
|
||||
"pruning_scene": self.config.pruning_scene,
|
||||
"is_builtin_scene": self._is_builtin_scene,
|
||||
"ontology_classes_count": len(self._ontology_classes),
|
||||
"language": self.language
|
||||
})
|
||||
log_prompt_rendering("pruning-extract", rendered)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{#
|
||||
对话级抽取与相关性判定模板(用于剪枝加速)
|
||||
输入:pruning_scene, dialog_text
|
||||
输入:pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language
|
||||
输出:严格 JSON(不要包含任何多余文本),字段:
|
||||
- is_related: bool,是否与所选场景相关
|
||||
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
||||
@@ -16,7 +16,8 @@
|
||||
- 仅输出上述键;避免多余解释或字段。
|
||||
#}
|
||||
|
||||
{% set scene_instructions = {
|
||||
{# ── 内置场景的固定说明 ── #}
|
||||
{% set builtin_scene_instructions = {
|
||||
'education': {
|
||||
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
||||
@@ -31,16 +32,40 @@
|
||||
}
|
||||
} %}
|
||||
|
||||
{% set scene_key = pruning_scene %}
|
||||
{% if scene_key not in scene_instructions %}
|
||||
{% set scene_key = 'education' %}
|
||||
{# ── 确定最终使用的场景说明 ── #}
|
||||
{% if is_builtin_scene %}
|
||||
{# 内置专门场景:使用固定说明 #}
|
||||
{% set scene_key = pruning_scene %}
|
||||
{% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %}
|
||||
{% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% else %}
|
||||
{# 自定义场景:使用场景名称 + 本体类型列表构建说明 #}
|
||||
{% if ontology_classes and ontology_classes | length > 0 %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = ontology_classes | join(', ') %}
|
||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
||||
{% else %}
|
||||
{% set custom_types_str = ontology_classes | join('、') %}
|
||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{# 无本体类型时退化为通用说明 #}
|
||||
{% if language == 'en' %}
|
||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||
{% else %}
|
||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||
{% endif %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
||||
场景说明:{{ instruction }}
|
||||
{% if not is_builtin_scene and custom_types_str %}
|
||||
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。
|
||||
{% endif %}
|
||||
|
||||
对话全文:
|
||||
"""
|
||||
@@ -60,6 +85,9 @@
|
||||
{% else %}
|
||||
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
||||
Scenario Description: {{ instruction }}
|
||||
{% if not is_builtin_scene and custom_types_str %}
|
||||
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
|
||||
{% endif %}
|
||||
|
||||
Full Dialogue:
|
||||
"""
|
||||
|
||||
@@ -8,34 +8,60 @@ from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \
|
||||
from app.core.workflow.adapters.errors import (
|
||||
UnsupportVariableType,
|
||||
UnknowModelWarning,
|
||||
ExceptionDefineition,
|
||||
ExceptionType
|
||||
from app.core.workflow.nodes.assigner import AssignerNodeConfig
|
||||
)
|
||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
||||
from app.core.workflow.nodes.code import CodeNodeConfig
|
||||
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
LLMNodeConfig,
|
||||
AssignerNodeConfig,
|
||||
CodeNodeConfig,
|
||||
LoopNodeConfig,
|
||||
IterationNodeConfig,
|
||||
EndNodeConfig,
|
||||
HttpRequestNodeConfig,
|
||||
IfElseNodeConfig,
|
||||
JinjaRenderNodeConfig,
|
||||
KnowledgeRetrievalNodeConfig,
|
||||
NoteNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
VariableAggregatorNodeConfig
|
||||
)
|
||||
from app.core.workflow.nodes.cycle_graph.config import (
|
||||
ConditionDetail as LoopConditionDetail,
|
||||
ConditionsConfig,
|
||||
CycleVariable
|
||||
from app.core.workflow.nodes.end import EndNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \
|
||||
HttpContentType, HttpErrorHandle
|
||||
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig
|
||||
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \
|
||||
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
)
|
||||
from app.core.workflow.nodes.enums import (
|
||||
ValueInputType,
|
||||
ComparisonOperator,
|
||||
AssignmentOperator,
|
||||
HttpAuthType,
|
||||
HttpContentType,
|
||||
HttpErrorHandle,
|
||||
NodeType
|
||||
)
|
||||
from app.core.workflow.nodes.http_request.config import (
|
||||
HttpAuthConfig,
|
||||
HttpContentTypeConfig,
|
||||
HttpFormData,
|
||||
HttpTimeOutConfig,
|
||||
HttpRetryConfig,
|
||||
HttpErrorDefaultTamplete,
|
||||
HttpErrorHandleConfig
|
||||
)
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
|
||||
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
|
||||
|
||||
@@ -48,24 +74,24 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
def __init__(self):
|
||||
self.CONFIG_CONVERT_MAP = {
|
||||
"start": self.convert_start_node_config,
|
||||
"llm": self.convert_llm_node_config,
|
||||
"answer": self.convert_end_node_config,
|
||||
"if-else": self.convert_if_else_node_config,
|
||||
"loop": self.convert_loop_node_config,
|
||||
"iteration": self.convert_iteration_node_config,
|
||||
"assigner": self.convert_assigner_node_config,
|
||||
"code": self.convert_code_node_config,
|
||||
"http-request": self.convert_http_node_config,
|
||||
"template-transform": self.convert_jinja_render_node_config,
|
||||
"knowledge-retrieval": self.convert_knowledge_node_config,
|
||||
"parameter-extractor": self.convert_parameter_extractor_node_config,
|
||||
"question-classifier": self.convert_question_classifier_node_config,
|
||||
"variable-aggregator": self.convert_variable_aggregator_node_config,
|
||||
"tool": self.convert_tool_node_config,
|
||||
"loop-start": lambda x: {},
|
||||
"iteration-start": lambda x: {},
|
||||
"loop-end": lambda x: {},
|
||||
NodeType.START: self.convert_start_node_config,
|
||||
NodeType.LLM: self.convert_llm_node_config,
|
||||
NodeType.END: self.convert_end_node_config,
|
||||
NodeType.IF_ELSE: self.convert_if_else_node_config,
|
||||
NodeType.LOOP: self.convert_loop_node_config,
|
||||
NodeType.ITERATION: self.convert_iteration_node_config,
|
||||
NodeType.ASSIGNER: self.convert_assigner_node_config,
|
||||
NodeType.CODE: self.convert_code_node_config,
|
||||
NodeType.HTTP_REQUEST: self.convert_http_node_config,
|
||||
NodeType.JINJARENDER: self.convert_jinja_render_node_config,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: self.convert_knowledge_node_config,
|
||||
NodeType.PARAMETER_EXTRACTOR: self.convert_parameter_extractor_node_config,
|
||||
NodeType.QUESTION_CLASSIFIER: self.convert_question_classifier_node_config,
|
||||
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
|
||||
NodeType.TOOL: self.convert_tool_node_config,
|
||||
NodeType.NOTES: self.convert_notes_config,
|
||||
NodeType.CYCLE_START: lambda x: {},
|
||||
NodeType.BREAK: lambda x: {},
|
||||
}
|
||||
|
||||
def get_node_convert(self, node_type):
|
||||
@@ -129,11 +155,11 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
@staticmethod
|
||||
def _convert_file(var):
|
||||
pass
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_file(var):
|
||||
pass
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def variable_type_map(source_type) -> VariableType | None:
|
||||
@@ -185,6 +211,9 @@ class DifyConverter(BaseConverter):
|
||||
"not empty": ComparisonOperator.NOT_EMPTY,
|
||||
"start with": ComparisonOperator.START_WITH,
|
||||
"end with": ComparisonOperator.END_WITH,
|
||||
"not contains": ComparisonOperator.NOT_CONTAINS,
|
||||
"exists": ComparisonOperator.NOT_EMPTY,
|
||||
"not exists": ComparisonOperator.EMPTY
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@@ -198,7 +227,7 @@ class DifyConverter(BaseConverter):
|
||||
"over-write": AssignmentOperator.COVER,
|
||||
"remove-last": AssignmentOperator.REMOVE_LAST,
|
||||
"remove-first": AssignmentOperator.REMOVE_FIRST,
|
||||
|
||||
"set": AssignmentOperator.ASSIGN,
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@@ -267,10 +296,10 @@ class DifyConverter(BaseConverter):
|
||||
type=var_type,
|
||||
required=var["required"],
|
||||
default=self.convert_variable_type(
|
||||
var_type, var["default"]
|
||||
var_type, var.get("default")
|
||||
),
|
||||
description=var["label"],
|
||||
max_length=var.get("max_length"),
|
||||
max_length=var.get("max_length", 50),
|
||||
)
|
||||
start_vars.append(var_def)
|
||||
result = StartNodeConfig.model_construct(
|
||||
@@ -333,7 +362,7 @@ class DifyConverter(BaseConverter):
|
||||
MessageConfig(
|
||||
role="user",
|
||||
content=self.trans_variable_format(
|
||||
node_data["memory"].get("query_prompt_template", "{{#sys.query#}}")
|
||||
node_data["memory"].get("query_prompt_template") or "{{#sys.query#}}"
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -364,7 +393,7 @@ class DifyConverter(BaseConverter):
|
||||
node_data = node["data"]
|
||||
cases = []
|
||||
for case in node_data["cases"]:
|
||||
case_id = case["id"]
|
||||
case_id = case.get("id") or case.get("case_id")
|
||||
logical_operator = case["logical_operator"]
|
||||
conditions = []
|
||||
for condition in case["conditions"]:
|
||||
@@ -540,7 +569,8 @@ class DifyConverter(BaseConverter):
|
||||
] = self.trans_variable_format(content["value"])
|
||||
else:
|
||||
if node_data["body"]["data"]:
|
||||
body_content = node_data["body"]["data"][0]["value"]
|
||||
body_content = (node_data["body"]["data"][0].get("value") or
|
||||
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
|
||||
else:
|
||||
body_content = ""
|
||||
|
||||
@@ -612,7 +642,7 @@ class DifyConverter(BaseConverter):
|
||||
),
|
||||
headers=headers,
|
||||
params=params,
|
||||
verify_ssl=node_data["ssl_verify"],
|
||||
verify_ssl=node_data.get("ssl_verify", False),
|
||||
timeouts=HttpTimeOutConfig.model_construct(
|
||||
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
|
||||
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
|
||||
@@ -696,7 +726,7 @@ class DifyConverter(BaseConverter):
|
||||
group_variables = {}
|
||||
group_type = {}
|
||||
if not advanced_settings or not advanced_settings["group_enabled"]:
|
||||
group_variables["output"] = [
|
||||
group_variables = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
for variable in node_data["variables"]
|
||||
]
|
||||
@@ -728,3 +758,16 @@ class DifyConverter(BaseConverter):
|
||||
detail=f"Please reconfigure the tool node.",
|
||||
))
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def convert_notes_config(node: dict):
|
||||
node_data = node["data"]
|
||||
result = NoteNodeConfig.model_construct(
|
||||
author=node_data.get("author", ""),
|
||||
text=node_data.get("text", ""),
|
||||
width=node_data.get("width", 80),
|
||||
height=node_data.get("height", 80),
|
||||
theme=node_data.get("theme", "blue"),
|
||||
show_author=node_data.get("showAuthor", True)
|
||||
).model_dump()
|
||||
return result
|
||||
|
||||
@@ -44,12 +44,13 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"variable-aggregator": NodeType.VAR_AGGREGATOR,
|
||||
"tool": NodeType.TOOL
|
||||
"tool": NodeType.TOOL,
|
||||
"": NodeType.NOTES
|
||||
}
|
||||
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
DifyConverter.__init__(self)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
@@ -58,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
def map_node_type(self, platform_node_type) -> NodeType:
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||
|
||||
@property
|
||||
@@ -83,6 +84,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
if self.config.get("app", {}).get("mode") == "workflow":
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.PLATFORM,
|
||||
detail="workflow mode is not supported"
|
||||
))
|
||||
return False
|
||||
|
||||
for node in self.origin_nodes:
|
||||
if not self._valid_nodes(node):
|
||||
@@ -134,6 +141,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
for node in self.origin_nodes:
|
||||
if self.map_node_type(node["data"]["type"]) == NodeType.LLM:
|
||||
self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output"
|
||||
elif self.map_node_type(node["data"]["type"]) == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
self.node_output_map[f"{node['id']}.result"] = f"{node['id']}.output"
|
||||
|
||||
def _convert_cycle_node_position(self, node_id: str, position: dict):
|
||||
for node in self.origin_nodes:
|
||||
@@ -154,13 +163,14 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||
node_data = node["data"]
|
||||
try:
|
||||
node_type = self.map_node_type(node_data["type"])
|
||||
return NodeDefinition(
|
||||
id=node["id"],
|
||||
type=self.map_node_type(node_data["type"]),
|
||||
name=node_data.get("title"),
|
||||
type=node_type,
|
||||
name=node_data.get("title") or "notes",
|
||||
cycle=node.get("parentId"),
|
||||
description=None,
|
||||
config=self._convert_node_config(node),
|
||||
config=self._convert_node_config(node_type, node),
|
||||
position={
|
||||
"x": node["position"]["x"],
|
||||
"y": node["position"]["y"]
|
||||
@@ -174,17 +184,16 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
except Exception as e:
|
||||
logger.debug(f"convert node error - {e}", exc_info=True)
|
||||
|
||||
def _convert_node_config(self, node: dict):
|
||||
node_data = node["data"]
|
||||
node_type = node_data["type"]
|
||||
def _convert_node_config(self, node_type: NodeType, node: dict):
|
||||
try:
|
||||
node_data = node["data"]
|
||||
converter = self.get_node_convert(node_type)
|
||||
if node_type not in self.CONFIG_CONVERT_MAP:
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
detail=f"node type {node_type} is unsupported",
|
||||
detail=f"node type {node_data.get('type')} is unsupported",
|
||||
))
|
||||
return converter(node)
|
||||
except Exception as e:
|
||||
@@ -201,16 +210,15 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
|
||||
source = edge["source"]
|
||||
target = edge["target"]
|
||||
edge_id = edge["id"]
|
||||
label = None
|
||||
if source in self.branch_node_cache:
|
||||
case_id = "-".join(edge_id.split("-")[1:-2])
|
||||
case_id = edge["sourceHandle"]
|
||||
if case_id == "false":
|
||||
label = f'CASE{len(self.branch_node_cache[source])+1}'
|
||||
label = f'CASE{len(self.branch_node_cache[source]) + 1}'
|
||||
else:
|
||||
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
|
||||
if source in self.error_branch_node_cache:
|
||||
case_id = "-".join(edge_id.split("-")[1:-2])
|
||||
case_id = edge["sourceHandle"]
|
||||
if case_id == "source":
|
||||
label = "SUCCESS"
|
||||
else:
|
||||
@@ -235,6 +243,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
name=variable["name"],
|
||||
default=variable["value"],
|
||||
type=self.variable_type_map(variable["value_type"]),
|
||||
description=variable.get("description")
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
@@ -248,5 +257,3 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
|
||||
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
|
||||
return ExecutionConfig()
|
||||
|
||||
|
||||
|
||||
@@ -292,6 +292,8 @@ class GraphBuilder:
|
||||
"""
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
if node_type == NodeType.NOTES:
|
||||
continue
|
||||
node_id = node.get("id")
|
||||
cycle_node = node.get("cycle")
|
||||
if cycle_node:
|
||||
@@ -320,7 +322,7 @@ class GraphBuilder:
|
||||
# Used later to determine which branch to take based on the node's output
|
||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
|
||||
|
||||
if node_instance:
|
||||
# Wrap node's run method to avoid closure issues
|
||||
|
||||
@@ -158,18 +158,36 @@ class WorkflowExecutor:
|
||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
|
||||
# Append messages for user and assistant
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
if input_data.get("files"):
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("files")
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
# Calculate elapsed time
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
@@ -308,18 +326,36 @@ class WorkflowExecutor:
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
# Append messages for user and assistant
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
if input_data.get("files"):
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("files")
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}"
|
||||
|
||||
@@ -85,20 +85,20 @@ class BaseNodeConfig(BaseModel):
|
||||
- tags: 节点标签(用于分类和搜索)
|
||||
"""
|
||||
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="节点名称(显示名称),如果不设置则使用节点 ID"
|
||||
)
|
||||
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="节点描述,说明节点的作用"
|
||||
)
|
||||
|
||||
tags: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="节点标签,用于分类和搜索"
|
||||
)
|
||||
# name: str | None = Field(
|
||||
# default=None,
|
||||
# description="节点名称(显示名称),如果不设置则使用节点 ID"
|
||||
# )
|
||||
#
|
||||
# description: str | None = Field(
|
||||
# default=None,
|
||||
# description="节点描述,说明节点的作用"
|
||||
# )
|
||||
#
|
||||
# tags: list[str] = Field(
|
||||
# default_factory=list,
|
||||
# description="节点标签,用于分类和搜索"
|
||||
# )
|
||||
|
||||
class Config:
|
||||
"""Pydantic 配置"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
||||
from app.schemas import FileInput
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
@@ -617,17 +618,31 @@ class BaseNode(ABC):
|
||||
return variable_pool.has(selector)
|
||||
|
||||
@staticmethod
|
||||
async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None:
|
||||
async def process_message(
|
||||
provider: str,
|
||||
is_omni: bool,
|
||||
content: str | dict | FileObject,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
if isinstance(content, dict):
|
||||
content = FileObject(
|
||||
type=content.get("type"),
|
||||
url=content.get("url"),
|
||||
transfer_method=content.get("transfer_method"),
|
||||
origin_file_type=content.get("origin_file_type"),
|
||||
file_id=content.get("file_id"),
|
||||
is_file=True
|
||||
)
|
||||
if isinstance(content, str):
|
||||
if enable_file:
|
||||
return {"text": content}
|
||||
return [{"type": "text", "text": content}]
|
||||
return content
|
||||
|
||||
elif isinstance(content, FileObject):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider)
|
||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||
message = await multimodel_service.process_files(
|
||||
[FileInput.model_construct(
|
||||
type=content.type,
|
||||
@@ -637,10 +652,9 @@ class BaseNode(ABC):
|
||||
upload_file_id=content.file_id
|
||||
)]
|
||||
)
|
||||
|
||||
if message:
|
||||
content.content_cache[provider] = message[0]
|
||||
return message[0]
|
||||
content.content_cache[provider] = message
|
||||
return message
|
||||
return None
|
||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||
|
||||
@@ -658,3 +672,12 @@ class BaseNode(ABC):
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def model_balance(model_config: ModelConfig) -> ModelApiKey:
|
||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||
if not api_keys:
|
||||
raise ValueError("No active API keys available for model")
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
return api_keys[0]
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -47,5 +48,6 @@ __all__ = [
|
||||
"ToolNodeConfig",
|
||||
"MemoryReadNodeConfig",
|
||||
"MemoryWriteNodeConfig",
|
||||
"CodeNodeConfig"
|
||||
"CodeNodeConfig",
|
||||
"NoteNodeConfig"
|
||||
]
|
||||
|
||||
@@ -25,6 +25,7 @@ class NodeType(StrEnum):
|
||||
MEMORY_WRITE = "memory-write"
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
NOTES = "notes"
|
||||
|
||||
|
||||
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
|
||||
|
||||
@@ -180,6 +180,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
RuntimeError: If no valid knowledge base is found or access is denied.
|
||||
"""
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
if not self.typed_config.knowledge_bases:
|
||||
return []
|
||||
query = self._render_template(self.typed_config.query, variable_pool)
|
||||
with get_db_read() as db:
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
|
||||
@@ -112,11 +112,12 @@ class LLMNode(BaseNode):
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = config.api_keys[0]
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
@@ -129,7 +130,8 @@ class LLMNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
extra_params=extra_params
|
||||
extra_params=extra_params,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
@@ -151,39 +153,53 @@ class LLMNode(BaseNode):
|
||||
if role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": content
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": content
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": content
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(provider, file.value, self.typed_config.vision)
|
||||
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.append(content)
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
messages[-1]['content'] = [messages[-1]["content"]] + file_content
|
||||
messages[-1]['content'] = messages[-1]["content"] + file_content
|
||||
else:
|
||||
messages.append({"role": "user", "content": file_content})
|
||||
|
||||
if self.typed_config.memory.enable:
|
||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||
history_message = []
|
||||
for message in state["messages"][-self.typed_config.memory.window_size:]:
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
{"role": message["role"], "content": file_content}
|
||||
)
|
||||
else:
|
||||
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
|
||||
history_message.append(message)
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
|
||||
0
api/app/core/workflow/nodes/notes/__init__.py
Normal file
0
api/app/core/workflow/nodes/notes/__init__.py
Normal file
12
api/app/core/workflow/nodes/notes/config.py
Normal file
12
api/app/core/workflow/nodes/notes/config.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class NoteNodeConfig(BaseNodeConfig):
|
||||
author: str = Field(default="", description="author")
|
||||
text: str = Field(default="", description="note content")
|
||||
width: int = Field(default=80)
|
||||
height: int = Field(default=80)
|
||||
theme: str = Field(default="blue")
|
||||
show_author: bool = Field(default=True)
|
||||
@@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode):
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
api_config = config.api_keys[0]
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
llm = RedBearLLM(
|
||||
@@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode):
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
api_config = config.api_keys[0]
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
base_url = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
return RedBearLLM(
|
||||
@@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -138,7 +138,7 @@ class WorkflowValidator:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
|
||||
# 3. 验证节点 ID 唯一性
|
||||
node_ids = [n.get("id") for n in nodes]
|
||||
node_ids = [n.get("id") for n in nodes if n.get("type") != NodeType.NOTES]
|
||||
if len(node_ids) != len(set(node_ids)):
|
||||
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
||||
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean
|
||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -163,6 +163,17 @@ class CustomToolConfig(Base):
|
||||
return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
|
||||
|
||||
|
||||
class MCPSourceChannel(StrEnum):
|
||||
"""MCP来源渠道枚举"""
|
||||
ALIYUN_BAILIAN = "aliyun_bailian" # 阿里云百炼
|
||||
MODELSCOPE = "modelscope" # ModelScope
|
||||
TOKENFLUX = "tokenflux" # TokenFlux
|
||||
LANGENG = "langeng" # 蓝耕科技
|
||||
AI_302 = "302ai" # 302.AI
|
||||
MCP_ROUTER = "mcp_router" # MCP Router
|
||||
SELF_HOSTED = "self_hosted" # 自建
|
||||
|
||||
|
||||
class MCPToolConfig(Base):
|
||||
"""MCP工具配置模型"""
|
||||
__tablename__ = "mcp_tool_configs"
|
||||
@@ -170,6 +181,13 @@ class MCPToolConfig(Base):
|
||||
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
|
||||
server_url = Column(String(1000), nullable=False) # MCP服务器URL
|
||||
connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息)
|
||||
|
||||
# 来源渠道
|
||||
source_channel = Column(String(50), default=MCPSourceChannel.SELF_HOSTED,
|
||||
server_default=text(f"'{MCPSourceChannel.SELF_HOSTED}'"), nullable=False, comment="来源渠道")
|
||||
market_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场id")
|
||||
market_config_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场配置id")
|
||||
mcp_service_id = Column(String(255), nullable=True, comment="mcp服务id")
|
||||
|
||||
# 服务状态
|
||||
last_health_check = Column(DateTime)
|
||||
|
||||
@@ -5,13 +5,15 @@ Implicit Emotions Storage Repository
|
||||
事务由调用方控制,仓储层只使用 flush/refresh
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, date, timezone, timedelta
|
||||
from typing import Optional, Generator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, not_, exists
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from typing import Generator, Optional
|
||||
|
||||
import redis
|
||||
from sqlalchemy import exists, not_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -111,6 +113,96 @@ class ImplicitEmotionsStorageRepository:
|
||||
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def get_users_needing_refresh(self, redis_client: Optional[redis.StrictRedis], batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取需要刷新隐性记忆/情绪数据的存量用户ID。
|
||||
|
||||
筛选逻辑:
|
||||
- 查询 implicit_emotions_storage 中所有用户的 end_user_id 和 updated_at
|
||||
- 从 Redis 读取 write_message:last_done:{end_user_id} 的时间戳
|
||||
- 若 Redis 中无记录(该用户从未写入过记忆),跳过
|
||||
- 若 last_done > updated_at,说明上次刷新后又有新记忆写入,需要刷新
|
||||
- 若 last_done <= updated_at,说明已是最新,跳过
|
||||
|
||||
如果 redis_client 为 None,则降级为返回所有用户(禁用时间过滤)。
|
||||
|
||||
Args:
|
||||
redis_client: 同步 redis.StrictRedis 实例(连接 CELERY_BACKEND DB),如果为 None 则禁用时间过滤
|
||||
batch_size: 每批次加载的数量
|
||||
|
||||
Yields:
|
||||
需要刷新的用户ID字符串
|
||||
"""
|
||||
from datetime import timezone
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
# 如果 Redis 不可用,降级为处理所有用户
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
"Redis 客户端不可用,时间过滤已禁用,将处理所有存量用户"
|
||||
)
|
||||
yield from self.get_all_user_ids(batch_size)
|
||||
return
|
||||
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(ImplicitEmotionsStorage.end_user_id, ImplicitEmotionsStorage.updated_at)
|
||||
.order_by(ImplicitEmotionsStorage.end_user_id)
|
||||
.limit(batch_size)
|
||||
.offset(offset)
|
||||
)
|
||||
batch = self.db.execute(stmt).all()
|
||||
if not batch:
|
||||
break
|
||||
|
||||
# 批量获取当前批次所有用户的 last_done 时间戳(一次网络往返)
|
||||
keys = [f"write_message:last_done:{end_user_id}" for end_user_id, _ in batch]
|
||||
|
||||
try:
|
||||
raw_values = redis_client.mget(keys)
|
||||
except RedisError as e:
|
||||
logger.error(
|
||||
f"Redis mget 操作失败: {e},当前批次降级为处理所有用户",
|
||||
extra={"offset": offset, "batch_size": len(batch)}
|
||||
)
|
||||
# Redis 操作失败,降级为返回当前批次所有用户
|
||||
yield from (end_user_id for end_user_id, _ in batch)
|
||||
offset += batch_size
|
||||
continue
|
||||
|
||||
for (end_user_id, updated_at), raw in zip(batch, raw_values):
|
||||
if raw is None:
|
||||
continue
|
||||
try:
|
||||
CST = timezone(timedelta(hours=8))
|
||||
last_done = datetime.fromisoformat(raw)
|
||||
# 统一转为 CST naive 时间做比较
|
||||
if last_done.tzinfo is None:
|
||||
last_done = last_done.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None)
|
||||
else:
|
||||
last_done = last_done.astimezone(CST).replace(tzinfo=None)
|
||||
|
||||
if updated_at is None:
|
||||
yield end_user_id
|
||||
continue
|
||||
# updated_at 同样转为 CST naive
|
||||
if updated_at.tzinfo is None:
|
||||
updated_at_cst = updated_at.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None)
|
||||
else:
|
||||
updated_at_cst = updated_at.astimezone(CST).replace(tzinfo=None)
|
||||
|
||||
if last_done > updated_at_cst:
|
||||
yield end_user_id
|
||||
except Exception as e:
|
||||
logger.warning(f"解析 last_done 时间戳失败: end_user_id={end_user_id}, raw={raw}, error={e}")
|
||||
|
||||
offset += batch_size
|
||||
except Exception as e:
|
||||
logger.error(f"get_users_needing_refresh 分批查询失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID
|
||||
|
||||
@@ -124,7 +216,8 @@ class ImplicitEmotionsStorageRepository:
|
||||
Yields:
|
||||
用户ID字符串
|
||||
"""
|
||||
from sqlalchemy import cast, String as SAString
|
||||
from sqlalchemy import String as SAString
|
||||
from sqlalchemy import cast
|
||||
CST = timezone(timedelta(hours=8))
|
||||
now_cst = datetime.now(CST)
|
||||
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
@@ -233,6 +233,7 @@ class MemoryConfigRepository:
|
||||
config_desc=params.config_desc,
|
||||
workspace_id=params.workspace_id,
|
||||
scene_id=params.scene_id,
|
||||
pruning_scene=params.pruning_scene,
|
||||
llm_id=params.llm_id,
|
||||
embedding_id=params.embedding_id,
|
||||
rerank_id=params.rerank_id,
|
||||
|
||||
@@ -86,6 +86,7 @@ class ChatResponse(BaseModel):
|
||||
"""聊天响应(非流式)"""
|
||||
conversation_id: uuid.UUID
|
||||
message: str
|
||||
message_id: str
|
||||
usage: Optional[Dict[str, Any]] = None
|
||||
elapsed_time: Optional[float] = None
|
||||
|
||||
|
||||
@@ -417,6 +417,7 @@ class MemoryConfig:
|
||||
|
||||
# Ontology scene association
|
||||
scene_id: Optional[UUID] = None
|
||||
ontology_classes: Optional[list] = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
|
||||
@@ -232,14 +232,15 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
||||
# 本体场景关联(可选)
|
||||
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表")
|
||||
|
||||
# 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name,前端无需传入)
|
||||
pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充")
|
||||
|
||||
# 模型配置字段(可选,用于手动指定或自动填充)
|
||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
||||
|
||||
|
||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||
@@ -274,8 +275,8 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
|
||||
|
||||
# 剪枝配置:与 runtime.json 中 pruning 段对应
|
||||
pruning_enabled: Optional[bool] = Field(None, description="是否启动智能语义剪枝")
|
||||
pruning_scene: Optional[Literal["education", "online_service", "outbound"]] = Field(
|
||||
None, description="智能剪枝场景:education/online_service/outbound"
|
||||
pruning_scene: Optional[str] = Field(
|
||||
None, description="智能剪枝场景:education/online_service/outbound 或本体工程自定义场景"
|
||||
)
|
||||
pruning_threshold: Optional[float] = Field(
|
||||
None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)"
|
||||
|
||||
@@ -23,6 +23,7 @@ class ModelConfigBase(BaseModel):
|
||||
load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略")
|
||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||
is_omni: bool = Field(False, description="是否为Omni模型")
|
||||
model_id: Optional[uuid.UUID] = Field(None, description="基础模型ID")
|
||||
|
||||
|
||||
class ApiKeyCreateNested(BaseModel):
|
||||
|
||||
@@ -155,6 +155,10 @@ class MCPToolConfigSchema(BaseModel):
|
||||
health_status: str = "unknown"
|
||||
error_message: Optional[str] = None
|
||||
available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]")
|
||||
source_channel: Optional[str] = Field(None, description="来源渠道")
|
||||
market_id: Optional[str] = Field(None, description="渠道市场id")
|
||||
market_config_id: Optional[str] = Field(None, description="渠道市场配置id")
|
||||
mcp_service_id: Optional[str] = Field(None, description="mcp服务id")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -192,6 +196,10 @@ class ToolCreateRequest(BaseModel):
|
||||
tool_type: ToolType
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
source_channel: Optional[str] = Field(None, description="来源渠道(仅MCP工具)")
|
||||
market_id: Optional[str] = Field(None, description="渠道市场id(仅MCP工具)")
|
||||
market_config_id: Optional[str] = Field(None, description="渠道市场配置id(仅MCP工具)")
|
||||
mcp_service_id: Optional[str] = Field(None, description="mcp服务id(仅MCP工具)")
|
||||
|
||||
|
||||
class ToolUpdateRequest(BaseModel):
|
||||
|
||||
@@ -144,7 +144,7 @@ class AppChatService:
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.save_conversation_messages(
|
||||
message_id = self.conversation_service.save_conversation_messages(
|
||||
conversation_id=conversation_id,
|
||||
user_message=message,
|
||||
assistant_message=result["content"],
|
||||
@@ -163,6 +163,7 @@ class AppChatService:
|
||||
|
||||
return {
|
||||
"conversation_id": conversation_id,
|
||||
"message_id": str(message_id),
|
||||
"message": result["content"],
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
@@ -191,7 +192,11 @@ class AppChatService:
|
||||
try:
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
message_id = uuid.uuid4()
|
||||
yield f"event: start\ndata: {json.dumps({
|
||||
'conversation_id': str(conversation_id),
|
||||
"message_id": str(message_id)
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
|
||||
variables = self.agent_service.prepare_variables(variables, config.variables)
|
||||
# 获取模型配置ID
|
||||
@@ -296,6 +301,7 @@ class AppChatService:
|
||||
)
|
||||
|
||||
self.conversation_service.add_message(
|
||||
message_id=message_id,
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=full_content,
|
||||
@@ -373,7 +379,7 @@ class AppChatService:
|
||||
content=message
|
||||
)
|
||||
|
||||
self.conversation_service.add_message(
|
||||
ai_message = self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=result.get("message", ""),
|
||||
@@ -391,6 +397,7 @@ class AppChatService:
|
||||
return {
|
||||
"conversation_id": conversation_id,
|
||||
"message": result.get("message", ""),
|
||||
"message_id": str(ai_message.id),
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
@@ -419,9 +426,9 @@ class AppChatService:
|
||||
variables = {}
|
||||
|
||||
try:
|
||||
|
||||
message_id = uuid.uuid4()
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), "message_id": str(message_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
@@ -429,6 +436,7 @@ class AppChatService:
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
|
||||
# 3. 流式执行任务
|
||||
async for event in orchestrator.execute_stream(
|
||||
message=message,
|
||||
@@ -472,6 +480,7 @@ class AppChatService:
|
||||
)
|
||||
|
||||
self.conversation_service.add_message(
|
||||
message_id=message_id,
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=full_content,
|
||||
|
||||
@@ -703,7 +703,7 @@ class AppService:
|
||||
self.db.flush()
|
||||
|
||||
# 如果是 agent 类型,复制 AgentConfig
|
||||
if source_app.type == "agent":
|
||||
if source_app.type == AppType.AGENT:
|
||||
source_config = self.db.query(AgentConfig).filter(
|
||||
AgentConfig.app_id == source_app.id
|
||||
).first()
|
||||
@@ -725,6 +725,50 @@ class AppService:
|
||||
)
|
||||
self.db.add(new_config)
|
||||
|
||||
elif source_app.type == AppType.WORKFLOW:
|
||||
source_config = self.db.query(WorkflowConfig).filter(
|
||||
WorkflowConfig.app_id == source_app.id
|
||||
).first()
|
||||
|
||||
if source_config:
|
||||
new_config = WorkflowConfig(
|
||||
id=uuid.uuid4(),
|
||||
app_id=new_app.id,
|
||||
nodes=source_config.nodes.copy() if source_config.nodes else [],
|
||||
edges=source_config.edges.copy() if source_config.edges else [],
|
||||
variables=source_config.variables.copy() if source_config.variables else [],
|
||||
execution_config=source_config.execution_config.copy() if source_config.execution_config else {},
|
||||
triggers=source_config.triggers.copy() if source_config.triggers else [],
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
self.db.add(new_config)
|
||||
|
||||
elif source_app.type == AppType.MULTI_AGENT:
|
||||
source_config = self.db.query(MultiAgentConfig).filter(
|
||||
MultiAgentConfig.app_id == source_app.id
|
||||
).first()
|
||||
|
||||
if source_config:
|
||||
new_config = MultiAgentConfig(
|
||||
id=uuid.uuid4(),
|
||||
app_id=new_app.id,
|
||||
master_agent_id=source_config.master_agent_id,
|
||||
master_agent_name=source_config.master_agent_name,
|
||||
default_model_config_id=source_config.default_model_config_id,
|
||||
model_parameters=source_config.model_parameters,
|
||||
orchestration_mode=source_config.orchestration_mode,
|
||||
sub_agents=source_config.sub_agents.copy() if source_config.sub_agents else [],
|
||||
routing_rules=source_config.routing_rules.copy() if source_config.routing_rules else None,
|
||||
execution_config=source_config.execution_config.copy() if source_config.execution_config else {},
|
||||
aggregation_strategy=source_config.aggregation_strategy,
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
self.db.add(new_config)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(new_app)
|
||||
|
||||
|
||||
@@ -178,7 +178,8 @@ class ConversationService:
|
||||
conversation_id: uuid.UUID,
|
||||
role: str,
|
||||
content: str,
|
||||
meta_data: Optional[dict] = None
|
||||
meta_data: Optional[dict] = None,
|
||||
message_id: Optional[uuid.UUID] = None,
|
||||
) -> Message:
|
||||
"""
|
||||
Add a message to a conversation using UnitOfWork.
|
||||
@@ -188,6 +189,7 @@ class ConversationService:
|
||||
role (str): Role of the message sender ('user' or 'assistant').
|
||||
content (str): Message content.
|
||||
meta_data (Optional[dict]): Optional metadata.
|
||||
message_id (Optional[uuid.UUID]): Optional custom message UUID.
|
||||
|
||||
Returns:
|
||||
Message: Newly created Message instance.
|
||||
@@ -198,6 +200,7 @@ class ConversationService:
|
||||
)
|
||||
|
||||
message = Message(
|
||||
id=message_id if message_id else uuid.uuid4(),
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=content,
|
||||
@@ -317,7 +320,7 @@ class ConversationService:
|
||||
content=user_message
|
||||
)
|
||||
|
||||
self.add_message(
|
||||
ai_message = self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=assistant_message,
|
||||
@@ -332,6 +335,7 @@ class ConversationService:
|
||||
"assistant_message_length": len(assistant_message)
|
||||
}
|
||||
)
|
||||
return ai_message.id
|
||||
|
||||
def delete_conversation(
|
||||
self,
|
||||
|
||||
@@ -107,6 +107,40 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
)
|
||||
|
||||
|
||||
# 专门场景的内置 key 集合,直接从 SceneConfigRegistry 派生,避免重复维护
|
||||
# 使用懒加载函数避免模块级循环导入
|
||||
def _get_builtin_pruning_scenes() -> set:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import SceneConfigRegistry
|
||||
return set(SceneConfigRegistry.get_all_scenes())
|
||||
|
||||
|
||||
def _load_ontology_classes(db: Session, scene_id, pruning_scene: Optional[str]) -> Optional[list]:
|
||||
"""当 pruning_scene 不是内置场景时,从 ontology_class 表加载类型名称列表。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
scene_id: 本体场景 UUID
|
||||
pruning_scene: 语义剪枝场景名称
|
||||
|
||||
Returns:
|
||||
class_name 字符串列表,或 None(内置场景 / 无数据时)
|
||||
"""
|
||||
if not scene_id:
|
||||
return None
|
||||
# 内置场景走 SceneConfigRegistry,不需要注入类型列表
|
||||
if pruning_scene in _get_builtin_pruning_scenes():
|
||||
return None
|
||||
try:
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
repo = OntologyClassRepository(db)
|
||||
classes = repo.get_classes_by_scene(scene_id)
|
||||
names = [c.class_name for c in classes if c.class_name]
|
||||
return names if names else None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ontology classes for scene_id={scene_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class MemoryConfigService:
|
||||
"""
|
||||
Centralized service for memory configuration loading and validation.
|
||||
@@ -359,6 +393,7 @@ class MemoryConfigService:
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
# Ontology scene association
|
||||
scene_id=memory_config.scene_id,
|
||||
ontology_classes=_load_ontology_classes(self.db, memory_config.scene_id, memory_config.pruning_scene),
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
@@ -146,6 +146,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not params.emotion_model_id:
|
||||
params.emotion_model_id = params.llm_id
|
||||
|
||||
# 根据关联的本体场景推导 pruning_scene(语义剪枝场景与本体工程场景保持一致)
|
||||
if params.scene_id and not getattr(params, 'pruning_scene', None):
|
||||
params.pruning_scene = self._resolve_pruning_scene_from_scene_id(params.scene_id)
|
||||
|
||||
config = MemoryConfigRepository.create(self.db, params)
|
||||
self.db.commit()
|
||||
return {"affected": 1, "config_id": config.config_id}
|
||||
@@ -161,6 +165,23 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
def _resolve_pruning_scene_from_scene_id(self, scene_id) -> Optional[str]:
|
||||
"""根据本体场景ID获取对应的 scene_name,作为语义剪枝场景值
|
||||
|
||||
Args:
|
||||
scene_id: 本体场景UUID
|
||||
|
||||
Returns:
|
||||
scene_name 字符串,查询失败时返回 None
|
||||
"""
|
||||
try:
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
scene = self.db.query(OntologyScene).filter_by(scene_id=scene_id).first()
|
||||
return scene.scene_name if scene else None
|
||||
except Exception as e:
|
||||
logger.warning(f"_resolve_pruning_scene_from_scene_id failed for scene_id={scene_id}: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# --- Delete ---
|
||||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
|
||||
success = MemoryConfigRepository.delete(self.db, key.config_id)
|
||||
@@ -196,6 +217,19 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||||
|
||||
# 检查并修正 pruning_scene 与 scene_name 不一致的记录
|
||||
needs_commit = False
|
||||
for config, scene_name in results:
|
||||
if scene_name and config.pruning_scene != scene_name:
|
||||
logger.info(
|
||||
f"修正 pruning_scene: config_id={config.config_id} "
|
||||
f"'{config.pruning_scene}' -> '{scene_name}'"
|
||||
)
|
||||
config.pruning_scene = scene_name
|
||||
needs_commit = True
|
||||
if needs_commit:
|
||||
self.db.commit()
|
||||
|
||||
# 将 ORM 对象转换为字典列表
|
||||
data_list = []
|
||||
for config, scene_name in results:
|
||||
@@ -749,8 +783,37 @@ async def analytics_hot_memory_tags(
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
||||
stats, _msg = get_recent_activity_stats()
|
||||
async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""获取最近记忆提取活动统计。
|
||||
|
||||
优先从 Redis 缓存读取(按 workspace_id),缓存不存在时降级到日志文件解析。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID,用于从 Redis 读取对应缓存
|
||||
|
||||
Returns:
|
||||
包含 total、stats、latest_relative、source 的统计字典
|
||||
"""
|
||||
stats = None
|
||||
source = "log"
|
||||
|
||||
# 优先从 Redis 读取
|
||||
if workspace_id:
|
||||
try:
|
||||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||
cached = await ActivityStatsCache.get_activity_stats(workspace_id)
|
||||
if cached:
|
||||
stats = cached.get("stats", {})
|
||||
source = "redis"
|
||||
logger.info(f"[ANALYTICS] 从 Redis 读取活动统计: workspace_id={workspace_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ANALYTICS] 读取 Redis 活动统计失败,降级到日志: {e}")
|
||||
|
||||
# 降级:从日志文件解析
|
||||
if stats is None:
|
||||
stats, _msg = get_recent_activity_stats()
|
||||
source = "log"
|
||||
|
||||
total = (
|
||||
stats.get("chunk_count", 0)
|
||||
+ stats.get("statements_count", 0)
|
||||
@@ -758,26 +821,29 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
||||
+ stats.get("triplet_relations_count", 0)
|
||||
+ stats.get("temporal_count", 0)
|
||||
)
|
||||
# 精简:仅提供“最新一次活动多久前”
|
||||
latest_relative = None
|
||||
try:
|
||||
info = stats.get("log_path", "")
|
||||
idx = info.rfind("最新:")
|
||||
if idx != -1:
|
||||
latest_path = info[idx + 3 :].strip()
|
||||
if latest_path and os.path.exists(latest_path):
|
||||
import time
|
||||
diff = max(0.0, time.time() - os.path.getmtime(latest_path))
|
||||
m = int(diff // 60)
|
||||
if m < 1:
|
||||
latest_relative = "刚刚"
|
||||
elif m < 60:
|
||||
latest_relative = "一会前"
|
||||
else:
|
||||
latest_relative = "较早前"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
data = {"total": total, "stats": stats, "latest_relative": latest_relative}
|
||||
# 计算"最新一次活动多久前"(仅日志来源时有效)
|
||||
latest_relative = None
|
||||
if source == "log":
|
||||
try:
|
||||
info = stats.get("log_path", "")
|
||||
idx = info.rfind("最新:")
|
||||
if idx != -1:
|
||||
latest_path = info[idx + 3:].strip()
|
||||
if latest_path and os.path.exists(latest_path):
|
||||
import time
|
||||
diff = max(0.0, time.time() - os.path.getmtime(latest_path))
|
||||
m = int(diff // 60)
|
||||
if m < 1:
|
||||
latest_relative = "刚刚"
|
||||
elif m < 60:
|
||||
latest_relative = "一会前"
|
||||
else:
|
||||
latest_relative = "较早前"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -780,6 +780,7 @@ class ModelBaseService:
|
||||
"description": model_base.description,
|
||||
"capability": model_base.capability,
|
||||
"is_omni": model_base.is_omni,
|
||||
"is_active": False,
|
||||
"is_composite": False
|
||||
}
|
||||
model_config = ModelConfigRepository.create(db, model_config_data)
|
||||
|
||||
@@ -326,6 +326,25 @@ async def run_pilot_extraction(
|
||||
|
||||
logger.info("Pilot run completed: Skipping Neo4j save")
|
||||
|
||||
# 将提取统计写入 Redis,按 workspace_id 存储
|
||||
try:
|
||||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||
|
||||
stats_to_cache = {
|
||||
"chunk_count": len(chunk_nodes) if chunk_nodes else 0,
|
||||
"statements_count": len(statement_nodes) if statement_nodes else 0,
|
||||
"triplet_entities_count": len(entity_nodes) if entity_nodes else 0,
|
||||
"triplet_relations_count": len(entity_edges) if entity_edges else 0,
|
||||
"temporal_count": 0, # temporal 数据在日志中,此处暂置0
|
||||
}
|
||||
await ActivityStatsCache.set_activity_stats(
|
||||
workspace_id=str(memory_config.workspace_id),
|
||||
stats=stats_to_cache,
|
||||
)
|
||||
logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -85,7 +85,7 @@ class ToolService:
|
||||
"""检查工具名称是否重复"""
|
||||
query = self.db.query(ToolConfig).filter(
|
||||
ToolConfig.name == name,
|
||||
ToolConfig.tool_type == tool_type.value,
|
||||
ToolConfig.tool_type == tool_type,
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
)
|
||||
if exclude_id:
|
||||
@@ -910,7 +910,11 @@ class ToolService:
|
||||
config_data.update({
|
||||
"last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None,
|
||||
"health_status": mcp_config.health_status,
|
||||
"available_tools": available_tools_display
|
||||
"available_tools": available_tools_display,
|
||||
"source_channel": mcp_config.source_channel,
|
||||
"market_id": mcp_config.market_id,
|
||||
"market_config_id": mcp_config.market_config_id,
|
||||
"mcp_service_id": mcp_config.mcp_service_id
|
||||
})
|
||||
|
||||
return ToolInfo(
|
||||
@@ -965,7 +969,11 @@ class ToolService:
|
||||
id=tool_config.id,
|
||||
server_url=config.get("server_url"),
|
||||
connection_config=config.get("connection_config", {}),
|
||||
available_tools=config.get("available_tools", [])
|
||||
available_tools=config.get("available_tools", []),
|
||||
source_channel=config.get("source_channel", "self_hosted"),
|
||||
market_id=config.get("market_id"),
|
||||
market_config_id=config.get("market_config_id"),
|
||||
mcp_service_id=config.get("mcp_service_id"),
|
||||
)
|
||||
self.db.add(mcp_config)
|
||||
|
||||
@@ -1018,6 +1026,14 @@ class ToolService:
|
||||
mcp_config.server_url = config.get("server_url")
|
||||
mcp_config.connection_config = config.get("connection_config", {})
|
||||
mcp_config.available_tools = config.get("available_tools", [])
|
||||
if config.get("source_channel") is not None:
|
||||
mcp_config.source_channel = config.get("source_channel")
|
||||
if config.get("market_id") is not None:
|
||||
mcp_config.market_id = config.get("market_id")
|
||||
if config.get("market_config_id") is not None:
|
||||
mcp_config.market_config_id = config.get("market_config_id")
|
||||
if config.get("mcp_service_id") is not None:
|
||||
mcp_config.mcp_service_id = config.get("mcp_service_id")
|
||||
|
||||
@staticmethod
|
||||
def _determine_initial_status(tool_info: Dict[str, Any]) -> str:
|
||||
|
||||
@@ -56,7 +56,7 @@ class WorkflowImportService:
|
||||
success=False,
|
||||
temp_id=None,
|
||||
workflow_id=None,
|
||||
errors=[InvalidConfiguration()]
|
||||
errors=[InvalidConfiguration()] + adapter.errors
|
||||
)
|
||||
|
||||
workflow_config = adapter.parse_workflow()
|
||||
|
||||
@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
|
||||
WorkflowExecutionRepository,
|
||||
WorkflowNodeExecutionRepository
|
||||
)
|
||||
from app.schemas import DraftRunRequest, FileInput
|
||||
from app.schemas import DraftRunRequest, FileInput, FileType
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
@@ -496,6 +496,7 @@ class WorkflowService:
|
||||
"event": "start",
|
||||
"data": {
|
||||
"conversation_id": payload.get("conversation_id"),
|
||||
"message_id": payload.get("message_id")
|
||||
}
|
||||
}
|
||||
case "workflow_end":
|
||||
@@ -600,6 +601,7 @@ class WorkflowService:
|
||||
try:
|
||||
files = await self._handle_file_input(payload.files)
|
||||
input_data["files"] = files
|
||||
message_id = uuid.uuid4()
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
@@ -624,24 +626,45 @@ class WorkflowService:
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=payload.user_id
|
||||
)
|
||||
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
token_usage = result.get("token_usage", {}) or {}
|
||||
|
||||
final_messages = result.get("messages", [])[init_message_length:]
|
||||
human_message = ""
|
||||
assistant_message = ""
|
||||
for message in final_messages:
|
||||
if message["role"] == "user":
|
||||
if isinstance(message["content"], str):
|
||||
human_message += message["content"]
|
||||
elif isinstance(message["content"], list):
|
||||
for file in message["content"]:
|
||||
if file.get("type") == FileType.IMAGE:
|
||||
human_message += f"})"
|
||||
else:
|
||||
human_message += f"[{file.get('type')}]({file.get('url', '')})"
|
||||
if message["role"] == "assistant":
|
||||
assistant_message = message["content"]
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role="user",
|
||||
content=human_message,
|
||||
meta_data=None
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
message_id=message_id,
|
||||
conversation_id=conversation_id_uuid,
|
||||
role="assistant",
|
||||
content=assistant_message,
|
||||
meta_data={"usage": token_usage}
|
||||
)
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result,
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
final_messages = result.get("messages", [])[init_message_length:]
|
||||
for message in final_messages:
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
meta_data=None if message["role"] == "user" else {"usage": token_usage}
|
||||
)
|
||||
|
||||
logger.info(f"Workflow Run Success, "
|
||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||
else:
|
||||
@@ -650,6 +673,8 @@ class WorkflowService:
|
||||
"failed",
|
||||
error_message=result.get("error")
|
||||
)
|
||||
logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id},"
|
||||
f" error: {result.get('error')}")
|
||||
|
||||
# 返回增强的响应结构
|
||||
return {
|
||||
@@ -659,6 +684,7 @@ class WorkflowService:
|
||||
# "messages": result.get("messages"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"message": result.get("output"), # 最终输出(字符串)
|
||||
"message_id": str(message_id),
|
||||
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
"error_message": result.get("error"),
|
||||
@@ -756,7 +782,7 @@ class WorkflowService:
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
|
||||
message_id = uuid.uuid4()
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
@@ -765,24 +791,43 @@ class WorkflowService:
|
||||
user_id=payload.user_id,
|
||||
):
|
||||
if event.get("event") == "workflow_end":
|
||||
|
||||
status = event.get("data", {}).get("status")
|
||||
token_usage = event.get("data", {}).get("token_usage", {}) or {}
|
||||
if status == "completed":
|
||||
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
||||
human_message = ""
|
||||
assistant_message = ""
|
||||
for message in final_messages:
|
||||
if message["role"] == "user":
|
||||
if isinstance(message["content"], str):
|
||||
human_message += message["content"]
|
||||
elif isinstance(message["content"], list):
|
||||
for file in message["content"]:
|
||||
if file.get("type") == FileType.IMAGE:
|
||||
human_message += f"})"
|
||||
else:
|
||||
human_message += f"[{file.get('type')}]({file.get('url', '')})"
|
||||
if message["role"] == "assistant":
|
||||
assistant_message = message["content"]
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role="user",
|
||||
content=human_message,
|
||||
meta_data=None
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
message_id=message_id,
|
||||
conversation_id=conversation_id_uuid,
|
||||
role="assistant",
|
||||
content=assistant_message,
|
||||
meta_data={"usage": token_usage}
|
||||
)
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=event.get("data"),
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
||||
for message in final_messages:
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
meta_data=None if message["role"] == "user" else {"usage": token_usage}
|
||||
)
|
||||
logger.info(f"Workflow Run Success, "
|
||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||
elif status == "failed":
|
||||
@@ -793,6 +838,8 @@ class WorkflowService:
|
||||
)
|
||||
else:
|
||||
logger.error(f"unexpect workflow run status, status: {status}")
|
||||
elif event.get("event") == "workflow_start":
|
||||
event["data"]["message_id"] = str(message_id)
|
||||
event = self._emit(public, event)
|
||||
if event:
|
||||
yield event
|
||||
|
||||
@@ -130,6 +130,7 @@ def _create_workspace_only(
|
||||
business_logger.error(f"创建工作空间失败: {workspace.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_workspace(
|
||||
db: Session, workspace: WorkspaceCreate, user: User, language: str = "zh"
|
||||
) -> Workspace:
|
||||
@@ -152,6 +153,7 @@ def create_workspace(
|
||||
|
||||
# Initialize default ontology scenes for the workspace (先创建本体场景)
|
||||
default_scene_id = None
|
||||
default_scene_name = None
|
||||
try:
|
||||
initializer = DefaultOntologyInitializer(db)
|
||||
success, error_msg = initializer.initialize_default_scenes(
|
||||
@@ -163,7 +165,7 @@ def create_workspace(
|
||||
f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})"
|
||||
)
|
||||
|
||||
# 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
|
||||
# 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
|
||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
from app.config.default_ontology_config import (
|
||||
ONLINE_EDUCATION_SCENE,
|
||||
@@ -179,6 +181,7 @@ def create_workspace(
|
||||
|
||||
if education_scene:
|
||||
default_scene_id = education_scene.scene_id
|
||||
default_scene_name = education_scene.scene_name
|
||||
business_logger.info(
|
||||
f"获取到教育场景ID用于默认记忆配置: {default_scene_id} (scene_name={education_scene_name})"
|
||||
)
|
||||
@@ -189,6 +192,7 @@ def create_workspace(
|
||||
|
||||
if companion_scene:
|
||||
default_scene_id = companion_scene.scene_id
|
||||
default_scene_name = companion_scene.scene_name
|
||||
business_logger.info(
|
||||
f"教育场景不存在,使用情感陪伴场景ID用于默认记忆配置: {default_scene_id} (scene_name={companion_scene_name})"
|
||||
)
|
||||
@@ -219,6 +223,7 @@ def create_workspace(
|
||||
embedding_id=embedding,
|
||||
rerank_id=rerank,
|
||||
scene_id=default_scene_id, # 传入默认场景ID(优先教育场景,其次情感陪伴场景)
|
||||
pruning_scene_name=default_scene_name, # 传入场景名称作为语义剪枝场景值
|
||||
)
|
||||
business_logger.info(
|
||||
f"为工作空间 {db_workspace.id} 创建默认记忆配置成功 (scene_id={default_scene_id})"
|
||||
@@ -962,6 +967,125 @@ def update_workspace_models_configs(
|
||||
raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
def _fill_workspace_configs_model_defaults(
|
||||
db: Session,
|
||||
workspace: Workspace
|
||||
) -> None:
|
||||
"""Fill empty model fields for all memory configs in a workspace.
|
||||
|
||||
Updates llm_id, embedding_id, rerank_id, reflection_model_id, and emotion_model_id
|
||||
if they are None, using the corresponding workspace default models.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
workspace: The workspace containing default model settings
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
# Get all configs for this workspace
|
||||
configs = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.workspace_id == workspace.id
|
||||
).all()
|
||||
|
||||
if not configs:
|
||||
return
|
||||
|
||||
# Map of memory_config field -> workspace field
|
||||
model_field_mappings = [
|
||||
("llm_id", "llm"),
|
||||
("embedding_id", "embedding"),
|
||||
("rerank_id", "rerank"),
|
||||
("reflection_model_id", "llm"), # reflection uses LLM
|
||||
("emotion_model_id", "llm"), # emotion uses LLM
|
||||
]
|
||||
|
||||
configs_updated = 0
|
||||
|
||||
for memory_config in configs:
|
||||
updated_fields = []
|
||||
|
||||
for config_field, workspace_field in model_field_mappings:
|
||||
config_value = getattr(memory_config, config_field, None)
|
||||
workspace_value = getattr(workspace, workspace_field, None)
|
||||
|
||||
if not config_value and workspace_value:
|
||||
setattr(memory_config, config_field, workspace_value)
|
||||
updated_fields.append(config_field)
|
||||
|
||||
if updated_fields:
|
||||
configs_updated += 1
|
||||
business_logger.debug(
|
||||
f"Updated memory config {memory_config.config_id} fields: {updated_fields}"
|
||||
)
|
||||
|
||||
if configs_updated > 0:
|
||||
try:
|
||||
db.commit()
|
||||
business_logger.info(
|
||||
f"Updated {configs_updated} memory configs in workspace {workspace.id} with default models"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(
|
||||
f"Failed to update memory configs in workspace {workspace.id}: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def _create_default_memory_config(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_name: str,
|
||||
llm_id: Optional[uuid.UUID] = None,
|
||||
embedding_id: Optional[uuid.UUID] = None,
|
||||
rerank_id: Optional[uuid.UUID] = None,
|
||||
scene_id: Optional[uuid.UUID] = None,
|
||||
pruning_scene_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Create a default memory config for a newly created workspace.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
workspace_id: The workspace ID
|
||||
workspace_name: The workspace name (used for config naming)
|
||||
llm_id: Optional LLM model ID
|
||||
embedding_id: Optional embedding model ID
|
||||
rerank_id: Optional rerank model ID
|
||||
scene_id: Optional ontology scene ID (默认关联教育场景)
|
||||
pruning_scene_name: Optional pruning scene name,取自 ontology_scene.scene_name
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
config_id = uuid.uuid4()
|
||||
|
||||
default_config = MemoryConfig(
|
||||
config_id=config_id,
|
||||
config_name=f"{workspace_name} 默认配置",
|
||||
config_desc="工作空间创建时自动生成的默认记忆配置",
|
||||
workspace_id=workspace_id,
|
||||
llm_id=str(llm_id) if llm_id else None,
|
||||
embedding_id=str(embedding_id) if embedding_id else None,
|
||||
rerank_id=str(rerank_id) if rerank_id else None,
|
||||
scene_id=scene_id, # 关联本体场景ID(默认为"在线教育"场景)
|
||||
pruning_scene=pruning_scene_name, # 语义剪枝场景直接使用 scene_name
|
||||
state=True, # Active by default
|
||||
is_default=True, # Mark as workspace default
|
||||
)
|
||||
|
||||
db.add(default_config)
|
||||
db.flush() # 使用 flush 而不是 commit,让调用者统一提交
|
||||
|
||||
business_logger.info(
|
||||
"Created default memory config for workspace",
|
||||
extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"config_id": str(config_id),
|
||||
"config_name": default_config.config_name,
|
||||
"scene_id": str(scene_id) if scene_id else None,
|
||||
}
|
||||
)
|
||||
|
||||
# ==================== 检查配置相关服务 ====================
|
||||
|
||||
def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
"""Ensure a workspace has a default memory config, creating one if missing.
|
||||
|
||||
@@ -1041,70 +1165,6 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
_fill_workspace_configs_model_defaults(db, workspace)
|
||||
|
||||
|
||||
def _fill_workspace_configs_model_defaults(
|
||||
db: Session,
|
||||
workspace: Workspace
|
||||
) -> None:
|
||||
"""Fill empty model fields for all memory configs in a workspace.
|
||||
|
||||
Updates llm_id, embedding_id, rerank_id, reflection_model_id, and emotion_model_id
|
||||
if they are None, using the corresponding workspace default models.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
workspace: The workspace containing default model settings
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
# Get all configs for this workspace
|
||||
configs = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.workspace_id == workspace.id
|
||||
).all()
|
||||
|
||||
if not configs:
|
||||
return
|
||||
|
||||
# Map of memory_config field -> workspace field
|
||||
model_field_mappings = [
|
||||
("llm_id", "llm"),
|
||||
("embedding_id", "embedding"),
|
||||
("rerank_id", "rerank"),
|
||||
("reflection_model_id", "llm"), # reflection uses LLM
|
||||
("emotion_model_id", "llm"), # emotion uses LLM
|
||||
]
|
||||
|
||||
configs_updated = 0
|
||||
|
||||
for memory_config in configs:
|
||||
updated_fields = []
|
||||
|
||||
for config_field, workspace_field in model_field_mappings:
|
||||
config_value = getattr(memory_config, config_field, None)
|
||||
workspace_value = getattr(workspace, workspace_field, None)
|
||||
|
||||
if not config_value and workspace_value:
|
||||
setattr(memory_config, config_field, workspace_value)
|
||||
updated_fields.append(config_field)
|
||||
|
||||
if updated_fields:
|
||||
configs_updated += 1
|
||||
business_logger.debug(
|
||||
f"Updated memory config {memory_config.config_id} fields: {updated_fields}"
|
||||
)
|
||||
|
||||
if configs_updated > 0:
|
||||
try:
|
||||
db.commit()
|
||||
business_logger.info(
|
||||
f"Updated {configs_updated} memory configs in workspace {workspace.id} with default models"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(
|
||||
f"Failed to update memory configs in workspace {workspace.id}: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
|
||||
"""Ensure a workspace has default ontology scenes, creating them if missing.
|
||||
|
||||
@@ -1150,53 +1210,3 @@ def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
|
||||
f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def _create_default_memory_config(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_name: str,
|
||||
llm_id: Optional[uuid.UUID] = None,
|
||||
embedding_id: Optional[uuid.UUID] = None,
|
||||
rerank_id: Optional[uuid.UUID] = None,
|
||||
scene_id: Optional[uuid.UUID] = None,
|
||||
) -> None:
|
||||
"""Create a default memory config for a newly created workspace.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
workspace_id: The workspace ID
|
||||
workspace_name: The workspace name (used for config naming)
|
||||
llm_id: Optional LLM model ID
|
||||
embedding_id: Optional embedding model ID
|
||||
rerank_id: Optional rerank model ID
|
||||
scene_id: Optional ontology scene ID (默认关联教育场景)
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
config_id = uuid.uuid4()
|
||||
|
||||
default_config = MemoryConfig(
|
||||
config_id=config_id,
|
||||
config_name=f"{workspace_name} 默认配置",
|
||||
config_desc="工作空间创建时自动生成的默认记忆配置",
|
||||
workspace_id=workspace_id,
|
||||
llm_id=str(llm_id) if llm_id else None,
|
||||
embedding_id=str(embedding_id) if embedding_id else None,
|
||||
rerank_id=str(rerank_id) if rerank_id else None,
|
||||
scene_id=scene_id, # 关联本体场景ID
|
||||
state=True, # Active by default
|
||||
is_default=True, # Mark as workspace default
|
||||
)
|
||||
|
||||
db.add(default_config)
|
||||
db.flush() # 使用 flush 而不是 commit,让调用者统一提交
|
||||
|
||||
business_logger.info(
|
||||
"Created default memory config for workspace",
|
||||
extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"config_id": str(config_id),
|
||||
"config_name": default_config.config_name,
|
||||
"scene_id": str(scene_id) if scene_id else None,
|
||||
}
|
||||
)
|
||||
|
||||
261
api/app/tasks.py
261
api/app/tasks.py
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -14,6 +15,62 @@ from uuid import UUID
|
||||
|
||||
import redis
|
||||
import requests
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
|
||||
# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致
|
||||
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连
|
||||
_sync_redis_pool: redis.ConnectionPool = None
|
||||
|
||||
def _get_or_create_redis_pool() -> redis.ConnectionPool:
|
||||
"""获取或创建 Redis 连接池(懒初始化)"""
|
||||
global _sync_redis_pool
|
||||
if _sync_redis_pool is None:
|
||||
try:
|
||||
_sync_redis_pool = redis.ConnectionPool(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB_CELERY_BACKEND,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
decode_responses=True,
|
||||
max_connections=10,
|
||||
socket_connect_timeout=5,
|
||||
socket_timeout=5,
|
||||
retry_on_timeout=True,
|
||||
health_check_interval=30,
|
||||
)
|
||||
logger.info("Redis connection pool created for Celery tasks")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Redis connection pool: {e}", exc_info=True)
|
||||
return None
|
||||
return _sync_redis_pool
|
||||
|
||||
def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
"""获取同步 Redis 客户端(使用连接池)
|
||||
|
||||
使用连接池提供的客户端,支持自动重连和健康检查。
|
||||
如果 Redis 不可用,返回 None,调用方应优雅降级。
|
||||
|
||||
Returns:
|
||||
redis.StrictRedis: Redis 客户端实例,如果连接失败则返回 None
|
||||
"""
|
||||
try:
|
||||
pool = _get_or_create_redis_pool()
|
||||
if pool is None:
|
||||
return None
|
||||
|
||||
client = redis.StrictRedis(connection_pool=pool)
|
||||
# 验证连接可用性
|
||||
client.ping()
|
||||
return client
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis connection failed: {e}", exc_info=True)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
@@ -1090,6 +1147,22 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||
|
||||
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
|
||||
try:
|
||||
_r = get_sync_redis_client()
|
||||
if _r is not None:
|
||||
from datetime import timedelta as _td
|
||||
from datetime import timezone as _tz
|
||||
_CST = _tz(_td(hours=8))
|
||||
_now_cst = datetime.now(_CST).replace(tzinfo=None).isoformat()
|
||||
_r.set(
|
||||
f"write_message:last_done:{end_user_id}",
|
||||
_now_cst,
|
||||
ex=86400 * 30,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"result": result,
|
||||
@@ -2149,12 +2222,15 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from sqlalchemy import select, func
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
|
||||
@@ -2167,18 +2243,20 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有已存储数据的用户ID(分批次处理)
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
|
||||
|
||||
# 先统计总数用于日志
|
||||
from sqlalchemy import func
|
||||
total_users = db.execute(
|
||||
select(func.count()).select_from(ImplicitEmotionsStorage)
|
||||
).scalar() or 0
|
||||
logger.info(f"找到 {total_users} 个需要更新的用户")
|
||||
logger.info(f"表中存量用户总数: {total_users},开始时间轴筛选")
|
||||
|
||||
# 遍历每个用户并更新数据(分批次,避免一次性加载所有ID)
|
||||
for end_user_id in repo.get_all_user_ids(batch_size=100):
|
||||
# 构建 Redis 同步客户端,用于时间轴筛选
|
||||
_redis_client = get_sync_redis_client()
|
||||
|
||||
# 只处理 last_done > updated_at 的用户(有新记忆写入的用户)
|
||||
for end_user_id in repo.get_users_needing_refresh(_redis_client, batch_size=100):
|
||||
logger.info(f"开始处理用户: {end_user_id}")
|
||||
user_start_time = time.time()
|
||||
|
||||
@@ -2264,10 +2342,10 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
user_results.append(error_info)
|
||||
logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}")
|
||||
|
||||
# ---- 处理增量用户(当天新增、尚未初始化的用户)----
|
||||
# ---- 当天新增用户兜底初始化 ----
|
||||
new_users_initialized = 0
|
||||
new_users_failed = 0
|
||||
logger.info("开始处理当天新增的增量用户初始化")
|
||||
logger.info("开始处理当天新增用户的兜底初始化")
|
||||
|
||||
for end_user_id in repo.get_new_user_ids_today(batch_size=100):
|
||||
logger.info(f"开始初始化新用户: {end_user_id}")
|
||||
@@ -2281,35 +2359,27 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||
await implicit_service.save_profile_cache(
|
||||
end_user_id=end_user_id,
|
||||
profile_data=profile_data,
|
||||
db=db
|
||||
end_user_id=end_user_id, profile_data=profile_data, db=db
|
||||
)
|
||||
implicit_success = True
|
||||
logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像")
|
||||
except Exception as e:
|
||||
error_msg = f"隐性记忆初始化失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"新用户 {end_user_id} {error_msg}")
|
||||
errors.append(f"隐性记忆初始化失败: {str(e)}")
|
||||
logger.error(f"新用户 {end_user_id} 隐性记忆初始化失败: {e}")
|
||||
|
||||
try:
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=end_user_id,
|
||||
db=db,
|
||||
language="zh"
|
||||
end_user_id=end_user_id, db=db, language="zh"
|
||||
)
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=end_user_id,
|
||||
suggestions_data=suggestions_data,
|
||||
db=db
|
||||
end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
|
||||
)
|
||||
emotion_success = True
|
||||
logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议")
|
||||
except Exception as e:
|
||||
error_msg = f"情绪建议初始化失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"新用户 {end_user_id} {error_msg}")
|
||||
errors.append(f"情绪建议初始化失败: {str(e)}")
|
||||
logger.error(f"新用户 {end_user_id} 情绪建议初始化失败: {e}")
|
||||
|
||||
if implicit_success or emotion_success:
|
||||
new_users_initialized += 1
|
||||
@@ -2319,7 +2389,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
user_elapsed = time.time() - user_start_time
|
||||
user_results.append({
|
||||
"end_user_id": end_user_id,
|
||||
"type": "init",
|
||||
"type": "new_user_init",
|
||||
"implicit_success": implicit_success,
|
||||
"emotion_success": emotion_success,
|
||||
"errors": errors,
|
||||
@@ -2331,7 +2401,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
user_elapsed = time.time() - user_start_time
|
||||
user_results.append({
|
||||
"end_user_id": end_user_id,
|
||||
"type": "init",
|
||||
"type": "new_user_init",
|
||||
"implicit_success": False,
|
||||
"emotion_success": False,
|
||||
"errors": [str(e)],
|
||||
@@ -2339,27 +2409,24 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
})
|
||||
logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}")
|
||||
|
||||
logger.info(
|
||||
f"增量用户初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}"
|
||||
)
|
||||
# ---- 增量用户处理结束 ----
|
||||
logger.info(f"当天新增用户兜底初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}")
|
||||
# ---- 新增用户兜底初始化结束 ----
|
||||
|
||||
# 记录总体统计信息
|
||||
logger.info(
|
||||
f"隐性记忆和情绪数据更新定时任务完成: "
|
||||
f"存量用户总数={total_users}, "
|
||||
f"隐性记忆成功={successful_implicit}, "
|
||||
f"情绪建议成功={successful_emotion}, "
|
||||
f"存量失败={failed}, "
|
||||
f"增量初始化成功={new_users_initialized}, "
|
||||
f"增量初始化失败={new_users_failed}"
|
||||
f"新增用户初始化成功={new_users_initialized}, "
|
||||
f"新增用户初始化失败={new_users_failed}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": (
|
||||
f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;"
|
||||
f"增量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
|
||||
f"当天新增用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
|
||||
),
|
||||
"total_users": total_users,
|
||||
"successful_implicit": successful_implicit,
|
||||
@@ -2367,7 +2434,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
"failed": failed,
|
||||
"new_users_initialized": new_users_initialized,
|
||||
"new_users_failed": new_users_failed,
|
||||
"user_results": user_results[:50] # 只保留前50个用户的详细结果
|
||||
"user_results": user_results[:50]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -2416,3 +2483,125 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.init_implicit_emotions_for_users",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
# 触发型任务标识,区别于 periodic_tasks 队列中的定时任务
|
||||
triggered=True,
|
||||
)
|
||||
def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
||||
"""事件触发任务:对指定用户列表做存在性检查,无记录则执行首次初始化。
|
||||
|
||||
由 /dashboard/end_users 接口触发,已有数据的用户直接跳过。
|
||||
存量用户的数据刷新由定时任务 update_implicit_emotions_storage 负责。
|
||||
|
||||
Args:
|
||||
end_user_ids: 需要检查的用户ID列表
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
|
||||
with get_db_context() as db:
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
|
||||
for end_user_id in end_user_ids:
|
||||
existing = repo.get_by_end_user_id(end_user_id)
|
||||
if existing is not None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {end_user_id} 无记录,开始初始化")
|
||||
implicit_ok = False
|
||||
emotion_ok = False
|
||||
try:
|
||||
try:
|
||||
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||
await implicit_service.save_profile_cache(
|
||||
end_user_id=end_user_id, profile_data=profile_data, db=db
|
||||
)
|
||||
implicit_ok = True
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {end_user_id} 隐性记忆初始化失败: {e}")
|
||||
|
||||
try:
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=end_user_id, db=db, language="zh"
|
||||
)
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
|
||||
)
|
||||
emotion_ok = True
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {end_user_id} 情绪建议初始化失败: {e}")
|
||||
|
||||
if implicit_ok or emotion_ok:
|
||||
initialized += 1
|
||||
else:
|
||||
failed += 1
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"用户 {end_user_id} 初始化异常: {e}")
|
||||
|
||||
logger.info(f"按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"initialized": initialized,
|
||||
"skipped": skipped,
|
||||
"failed": failed,
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
result["task_id"] = self.request.id
|
||||
return result
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
@@ -1,4 +1,36 @@
|
||||
{
|
||||
"v0.2.6": {
|
||||
"introduction": {
|
||||
"codeName": "听剑",
|
||||
"releaseDate": "2026-3-6",
|
||||
"upgradePosition": "🐻 多模态交互全面升级,记忆剪枝与工作流迁移双线并进,锋芒初露,兼收并蓄",
|
||||
"coreUpgrades": [
|
||||
"1. 工作流与应用框架<br>* 工作流导入适配(Dify):支持 Dify 工作流定义无缝迁移<br>* 字段字数限制与校验规则:可配置字符限制与产品级校验<br>* 应用复制(Agent、工作流、集群):一键复制完整应用配置<br>* 对话变量(调试+分享):支持有状态多轮交互<br>* Chat 接口流式输出 message_id:流式响应包含消息追踪标识",
|
||||
"2. 多模态与交互 💬<br>* 音频输入与输出:应用支持音频模态<br>* 文件类型输入支持:扩展支持语音、文件、视频上传",
|
||||
"3. 模型与智能 🧠<br>* 模型视觉与 Omni 区分:精确区分视觉与 Omni 模型能力<br>* 教育记忆与陪伴玩具场景预设:垂直领域本体配置开箱即用<br>* 本体配置默认标识:支持基线配置标记<br>* 记忆配置默认标识:自动应用默认记忆设置",
|
||||
"4. 记忆智能 🔬<br>* 记忆剪枝模块:智能裁剪冗余低价值记忆<br>* RAG 快速检索集成记忆:深度思考与正常回复双模式检索",
|
||||
"5. 稳健性与缺陷修复 🔧<br>* 模型管理:修复自定义模型 API Key 批量配置错误<br>* 知识库管理:修复非源文档下载原始内容接口错误,更新分享停用提示文案<br>* 用户记忆:优化档案提取准确性(姓名、职业、兴趣分布)<br>* 长期记忆:修复情景记忆卡片重复和用户归属错误<br>* 工作空间首页:修复知识库数量、应用数量、总记忆容量、API 调用次数、知识库类型分布等数据不一致问题<br>* 基础设施:修正 Celery 环境变量配置,修复数据库连接池 idle-in-transaction 泄漏",
|
||||
"<br>",
|
||||
"v0.2.6 标志着 MemoryBear 在多模态交互、跨平台工作流迁移和智能记忆管理方面的重要突破。下一版本将聚焦 A2A 协议支持实现多智能体协作、多模态记忆能力扩展至语音与视觉领域,以及应用导入导出功能支持跨环境便携部署。",
|
||||
"MemoryBear,让记忆有熊力 🐻✨"
|
||||
]
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": "TingJian",
|
||||
"releaseDate": "2026-3-6",
|
||||
"upgradePosition": "🐻 Full multimodal interaction upgrade with memory pruning and workflow migration — sharpened edge, broader reach",
|
||||
"coreUpgrades": [
|
||||
"1. Workflow & Application Framework<br>* Workflow Import Adaptation (Dify): Seamless Dify workflow migration<br>* Field Character Limits & Validation: Configurable limits with product-defined rules<br>* Application Cloning (Agent, Workflow, Cluster): One-click full config duplication<br>* Conversation Variables (Debug + Share): Stateful multi-turn interactions<br>* Streaming message_id in Chat API: Message tracking in streaming responses",
|
||||
"2. Multimodal & Interaction 💬<br>* Audio Input & Output: Audio modality support for applications<br>* File Type Input Support: Voice, file, and video upload support",
|
||||
"3. Model & Intelligence 🧠<br>* Model Vision & Omni Differentiation: Precise capability routing<br>* Education Memory & Companion Toy Presets: Domain-specific ontology configs<br>* Ontology Default Identifier: Baseline configuration flagging<br>* Memory Configuration Default Identifier: Auto-apply default settings",
|
||||
"4. Memory Intelligence 🔬<br>* Memory Pruning Module: Intelligent trimming of redundant memories<br>* RAG Quick Retrieval with Memory: Deep think and normal reply dual-mode retrieval",
|
||||
"5. Robustness & Bug Fixes 🔧<br>* Model Management: Fixed custom model API key batch configuration error<br>* Knowledge Base: Fixed download original content API error for non-source documents, updated share disable prompt text<br>* User Memory: Improved profile extraction accuracy (name, occupation, interests)<br>* Long-Term Memory: Fixed duplicate episodic memory cards and wrong user attribution<br>* Dashboard: Fixed data inconsistencies in knowledge count, app count, memory capacity, API calls, and knowledge type distribution<br>* Infrastructure: Corrected Celery environment variables, fixed database connection pool idle-in-transaction leak",
|
||||
"<br>",
|
||||
"v0.2.6 marks a significant milestone for MemoryBear in multimodal interaction, cross-platform workflow migration, and intelligent memory management. The next release will focus on A2A protocol support for multi-agent collaboration, multimodal memory extending extraction to voice and visual domains, and application import/export for portable cross-environment deployment.",
|
||||
"MemoryBear, Memory with Bear Power 🐻✨"
|
||||
]
|
||||
}
|
||||
},
|
||||
"v0.2.5": {
|
||||
"introduction": {
|
||||
"codeName": "行云",
|
||||
|
||||
@@ -49,7 +49,7 @@ services:
|
||||
networks:
|
||||
- celery
|
||||
|
||||
# Periodic worker - Scheduled/beat tasks (prefork, low concurrency)
|
||||
# Periodic worker - Scheduled/beat tasks + API-triggered tasks (prefork, low concurrency)
|
||||
worker-periodic:
|
||||
image: redbear-mem-open:latest
|
||||
container_name: worker-periodic
|
||||
|
||||
36
api/migrations/versions/1ac07dc7366f_202603061644.py
Normal file
36
api/migrations/versions/1ac07dc7366f_202603061644.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""202603061644
|
||||
|
||||
Revision ID: 1ac07dc7366f
|
||||
Revises: 6a4641cf192b
|
||||
Create Date: 2026-03-06 16:51:10.152305
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1ac07dc7366f'
|
||||
down_revision: Union[str, None] = '6a4641cf192b'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('mcp_tool_configs', sa.Column('source_channel', sa.String(length=50), server_default=sa.text("'self_hosted'"), nullable=False, comment='来源渠道'))
|
||||
op.add_column('mcp_tool_configs', sa.Column('market_id', sa.UUID(), nullable=True, comment='渠道市场id'))
|
||||
op.add_column('mcp_tool_configs', sa.Column('market_config_id', sa.UUID(), nullable=True, comment='渠道市场配置id'))
|
||||
op.add_column('mcp_tool_configs', sa.Column('mcp_service_id', sa.String(length=255), nullable=True, comment='mcp服务id'))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('mcp_tool_configs', 'mcp_service_id')
|
||||
op.drop_column('mcp_tool_configs', 'market_config_id')
|
||||
op.drop_column('mcp_tool_configs', 'market_id')
|
||||
op.drop_column('mcp_tool_configs', 'source_channel')
|
||||
# ### end Alembic commands ###
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2025-12-10 16:46:14
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-04 18:42:49
|
||||
* @Last Modified time: 2026-03-06 13:36:20
|
||||
*/
|
||||
import { type FC, useEffect, useMemo } from 'react'
|
||||
import { Flex, Input, Form } from 'antd'
|
||||
@@ -50,13 +50,17 @@ const ChatInput: FC<ChatInputProps> = ({
|
||||
|
||||
|
||||
const handleDelete = (file: any) => {
|
||||
fileChange?.(fileList?.filter(item => item.uid !== file.uid) || [])
|
||||
fileChange?.(fileList?.filter(item => {
|
||||
return item.thumbUrl && file.thumbUrl ? item.thumbUrl !== file.thumbUrl
|
||||
: item.url && file.url ? item.url !== file.url
|
||||
: item.uid !== file.uid
|
||||
}) || [])
|
||||
}
|
||||
// Convert file object to preview URL
|
||||
const previewFileList = useMemo(() => {
|
||||
return fileList?.map(file => ({
|
||||
...file,
|
||||
url: file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : file.thumbUrl)
|
||||
url: file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
||||
})) || []
|
||||
}, [fileList])
|
||||
|
||||
@@ -72,7 +76,7 @@ const ChatInput: FC<ChatInputProps> = ({
|
||||
{previewFileList.map((file) => {
|
||||
if (file.type.includes('image')) {
|
||||
return (
|
||||
<div key={file.uid} className="rb:inline-block rb:group rb:relative rb:rounded-lg">
|
||||
<div key={file.url || file.uid} className="rb:inline-block rb:group rb:relative rb:rounded-lg">
|
||||
<img src={file.url} alt={file.name} className="rb:size-12! rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||
<div
|
||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||
@@ -83,7 +87,7 @@ const ChatInput: FC<ChatInputProps> = ({
|
||||
}
|
||||
if (file.type.includes('video')) {
|
||||
return (
|
||||
<div key={file.uid} className="rb:w-45 rb:h-16 rb:inline-block rb:group rb:relative rb:rounded-lg">
|
||||
<div key={file.url || file.uid} className="rb:w-45 rb:h-16 rb:inline-block rb:group rb:relative rb:rounded-lg">
|
||||
<video src={file.url} controls className="rb:w-45 rb:h-16 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||
<div
|
||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||
@@ -94,7 +98,7 @@ const ChatInput: FC<ChatInputProps> = ({
|
||||
}
|
||||
if (file.type.includes('audio')) {
|
||||
return (
|
||||
<div key={file.uid} className="rb:w-45 rb:h-16 rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2">
|
||||
<div key={file.url || file.uid} className="rb:w-45 rb:h-16 rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2">
|
||||
<audio src={file.url} controls className="rb:w-45 rb:h-16" />
|
||||
<div
|
||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||
@@ -104,7 +108,7 @@ const ChatInput: FC<ChatInputProps> = ({
|
||||
)
|
||||
}
|
||||
return (
|
||||
<div key={file.uid} className="rb:w-45 rb:text-[12px] rb:gap-2.5 rb:flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5">
|
||||
<div key={file.url || file.uid} className="rb:w-45 rb:text-[12px] rb:gap-2.5 rb:flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5">
|
||||
{(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
|
||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/word.svg')]"
|
||||
></div>}
|
||||
|
||||
@@ -1361,6 +1361,7 @@ export const en = {
|
||||
complex: 'Compatibility Analysis',
|
||||
sureInfo: 'Information Confirmation',
|
||||
completed: 'Import Completed',
|
||||
baseInfo: 'Basic Information',
|
||||
workflowName: 'Workflow Name',
|
||||
fileName: 'File Name',
|
||||
fileSize: 'File Size',
|
||||
@@ -1572,7 +1573,7 @@ export const en = {
|
||||
intelligentSemanticPruningFunction: 'Intelligent Semantic Pruning Function',
|
||||
intelligentSemanticPruningFunctionDesc: 'Whether to activate intelligent semantic pruning (true/false).',
|
||||
intelligentSemanticPruningScene: 'Intelligent Semantic Pruning Scene',
|
||||
intelligentSemanticPruningSceneDesc: 'Select intelligent semantic pruning scene (education, online_service, outbound).',
|
||||
intelligentSemanticPruningSceneDesc: 'Semantic pruning scenarios are consistent with ontology engineering scenarios',
|
||||
intelligentSemanticPruningThreshold: 'Intelligent Semantic Pruning Threshold',
|
||||
intelligentSemanticPruningThresholdDesc: 'Set intelligent semantic pruning threshold (0-0.9).',
|
||||
reflectionEngine: 'Self-Reflexion Engine',
|
||||
|
||||
@@ -96,7 +96,7 @@ export const zh = {
|
||||
createMemorySummary: '创建记忆摘要',
|
||||
memoryManagement: '记忆管理',
|
||||
spaceManagement: '空间管理',
|
||||
memoryExtractionEngine: '记忆提取引擎',
|
||||
memoryExtractionEngine: '记忆萃取引擎',
|
||||
forgettingEngine: '遗忘引擎',
|
||||
apiKeyManagement: 'API KEY管理',
|
||||
knowledgePrivate: '详情',
|
||||
@@ -1283,7 +1283,7 @@ export const zh = {
|
||||
createConfiguration: '创建配置',
|
||||
editConfiguration: '编辑配置',
|
||||
desc: '描述',
|
||||
memoryExtractionEngine: '记忆提取引擎',
|
||||
memoryExtractionEngine: '记忆萃取引擎',
|
||||
forgottenEngine: '遗忘引擎',
|
||||
active: '活跃',
|
||||
inactive: '不活跃',
|
||||
@@ -1571,7 +1571,7 @@ export const zh = {
|
||||
intelligentSemanticPruningFunction: '智能语义修剪功能',
|
||||
intelligentSemanticPruningFunctionDesc: '是否激活智能语义修剪(true/false)。',
|
||||
intelligentSemanticPruningScene: '智能语义修剪场景',
|
||||
intelligentSemanticPruningSceneDesc: '选择智能语义修剪场景(education、online_service、outbound)。',
|
||||
intelligentSemanticPruningSceneDesc: '语义剪枝场景与本体工程场景一致',
|
||||
intelligentSemanticPruningThreshold: '智能语义修剪阈值',
|
||||
intelligentSemanticPruningThresholdDesc: '设置智能语义修剪阈值(0-0.9)。',
|
||||
reflectionEngine: '自我反思引擎',
|
||||
|
||||
@@ -356,12 +356,11 @@ export const request = {
|
||||
* Get parent domain for cookie setting
|
||||
* @returns Parent domain or IP address
|
||||
*/
|
||||
const isIp = (hostname: string) => /^\d+\.\d+\.\d+\.\d+$/.test(hostname)
|
||||
|
||||
const getParentDomain = () => {
|
||||
const hostname = window.location.hostname
|
||||
// Check if it's an IP address
|
||||
if (/^\d+\.\d+\.\d+\.\d+$/.test(hostname)) {
|
||||
return hostname
|
||||
}
|
||||
if (isIp(hostname)) return hostname
|
||||
const parts = hostname.split('.')
|
||||
return parts.length > 2 ? `.${parts.slice(-2).join('.')}` : hostname
|
||||
}
|
||||
@@ -371,7 +370,10 @@ const getParentDomain = () => {
|
||||
*/
|
||||
export const cookieUtils = {
|
||||
set: (name: string, value: string, domain = getParentDomain()) => {
|
||||
document.cookie = `${name}=${value}; domain=${domain}; path=/; secure; samesite=strict`
|
||||
const ip = isIp(window.location.hostname)
|
||||
const domainPart = ip ? '' : `; domain=${domain}`
|
||||
const securePart = window.location.protocol === 'https:' ? '; secure' : ''
|
||||
document.cookie = `${name}=${value}${domainPart}; path=/${securePart}; samesite=strict`
|
||||
},
|
||||
get: (name: string) => {
|
||||
const value = `; ${document.cookie}`
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-28 14:08:14
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-02 17:39:49
|
||||
* @Last Modified time: 2026-03-06 12:05:46
|
||||
*/
|
||||
/**
|
||||
* UploadWorkflowModal Component
|
||||
@@ -101,6 +101,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
||||
formData.append('platform', values.platform);
|
||||
formData.append('file', values.file[0]);
|
||||
|
||||
setLoading(true)
|
||||
// Call import workflow API
|
||||
importWorkflow(formData)
|
||||
.then(res => {
|
||||
@@ -114,21 +115,24 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
||||
} else {
|
||||
setCurrent(2);
|
||||
// Pre-fill form with file information
|
||||
const fileNameSplit = values.file[0].name.split('.')
|
||||
form.setFieldsValue({
|
||||
name: values.file[0].name.split('.')[0],
|
||||
name: fileNameSplit.slice(0, fileNameSplit.length - 1).join('.'),
|
||||
platform: values.platform,
|
||||
fileName: values.file[0].name,
|
||||
fileSize: values.file[0].size,
|
||||
});
|
||||
}
|
||||
});
|
||||
})
|
||||
.finally(() => setLoading(false));
|
||||
break;
|
||||
case 1: // Step 2: Error/warning display
|
||||
if (firstFormData) {
|
||||
const { file, platform } = firstFormData;
|
||||
const fileNameSplit = firstFormData.file[0].name.split('.')
|
||||
// Pre-fill form with file information
|
||||
form.setFieldsValue({
|
||||
name: file[0].name.split('.')[0],
|
||||
name: fileNameSplit.slice(0, fileNameSplit.length - 1).join('.'),
|
||||
platform: platform,
|
||||
fileName: file[0].name,
|
||||
fileSize: file[0].size,
|
||||
@@ -138,6 +142,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
||||
break;
|
||||
case 2: // Step 3: Confirm information
|
||||
if (data) {
|
||||
setLoading(true);
|
||||
// Complete import workflow
|
||||
completeImportWorkflow({
|
||||
temp_id: data.temp_id,
|
||||
@@ -148,7 +153,8 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
||||
const response = res as { id: string };
|
||||
setCurrent(3);
|
||||
setAppId(response.id);
|
||||
});
|
||||
})
|
||||
.finally(() => setLoading(false));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
@@ -175,7 +181,9 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
||||
}
|
||||
|
||||
// Reset form if not going back to error/warning step
|
||||
if (newStep !== 1) {
|
||||
if (newStep === 0) {
|
||||
form.setFieldsValue(firstFormData || {})
|
||||
} else if (newStep !== 1) {
|
||||
form.resetFields();
|
||||
}
|
||||
setCurrent(newStep);
|
||||
@@ -186,14 +194,16 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
||||
* @param {string} type - Navigation type ('detail' or 'list')
|
||||
*/
|
||||
const handleJump = (type: string) => {
|
||||
switch(type) {
|
||||
case 'detail':
|
||||
// Open application detail page in new tab
|
||||
window.open(`/#/application/config/${appId}`, '_blank');
|
||||
break;
|
||||
}
|
||||
refresh();
|
||||
handleClose();
|
||||
refresh();
|
||||
setTimeout(() => {
|
||||
switch (type) {
|
||||
case 'detail':
|
||||
// Open application detail page in new tab
|
||||
window.open(`/#/application/config/${appId}`, '_blank');
|
||||
break;
|
||||
}
|
||||
}, 100)
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -235,7 +245,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
||||
</Button>
|
||||
];
|
||||
}
|
||||
}, [current]);
|
||||
}, [current, loading]);
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
@@ -350,7 +360,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
||||
title={t('application.importSuccess')}
|
||||
subTitle={t('application.importSuccessDesc')}
|
||||
extra={[
|
||||
<Button key="back" onClick={() => handleJump('list')}>
|
||||
<Button key="back" onClick={() => handleJump('list')}>
|
||||
{t('application.gotoList')}
|
||||
</Button>,
|
||||
<Button
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-06 21:09:42
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-05 15:09:22
|
||||
* @Last Modified time: 2026-03-06 12:20:43
|
||||
*/
|
||||
/**
|
||||
* File Upload Component
|
||||
@@ -208,6 +208,7 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
|
||||
newFileList.map(file => {
|
||||
const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document'
|
||||
file.type = type
|
||||
file.thumbUrl = file.thumbUrl || URL.createObjectURL(file.originFileObj as Blob)
|
||||
})
|
||||
setFileList(newFileList);
|
||||
if (onChange) {
|
||||
|
||||
@@ -672,9 +672,17 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
|
||||
{currentType !== 'Folder' && dynamicTypeList.map((tp) => {
|
||||
const fieldKey = typeToFieldKey(tp);
|
||||
// When tp is 'llm', merge llm and chat options
|
||||
const options = tp.toLowerCase() === 'llm' || tp.toLowerCase() === 'image2text'
|
||||
let options = tp.toLowerCase() === 'llm' || tp.toLowerCase() === 'image2text'
|
||||
? [...(modelOptionsByType['llm'] || []), ...(modelOptionsByType['chat'] || [])]
|
||||
: modelOptionsByType[tp] || [];
|
||||
|
||||
// When tp is 'image2text', filter to only include models with 'vision' capability
|
||||
if (tp.toLowerCase() === 'image2text') {
|
||||
options = options.filter((opt: any) => {
|
||||
const model = models?.items?.find((m: any) => m.id === opt.value);
|
||||
return model?.capability?.includes('vision');
|
||||
});
|
||||
}
|
||||
return (
|
||||
<Form.Item
|
||||
key={tp}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 17:30:06
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-04 10:09:45
|
||||
* @Last Modified time: 2026-03-06 13:49:00
|
||||
*/
|
||||
/**
|
||||
* Memory Extraction Engine Configuration Constants
|
||||
@@ -140,13 +140,8 @@ export const configList: ConfigVo[] = [
|
||||
{
|
||||
label: 'intelligentSemanticPruningScene',
|
||||
variableName: 'pruning_scene',
|
||||
control: 'select',
|
||||
control: 'text',
|
||||
type: 'enum',
|
||||
options: [
|
||||
{ label: 'education', value: 'education' },
|
||||
{ label: 'online_service', value: 'online_service' },
|
||||
{ label: 'outbound', value: 'outbound' },
|
||||
],
|
||||
meaning: 'intelligentSemanticPruningSceneDesc',
|
||||
},
|
||||
// Intelligent semantic pruning阈值
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 17:30:02
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 17:30:02
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-06 13:50:05
|
||||
*/
|
||||
/**
|
||||
* Memory Extraction Engine Configuration Page
|
||||
@@ -13,7 +13,7 @@
|
||||
import { type FC, useState, useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useParams } from 'react-router-dom'
|
||||
import { Row, Col, Space, Select, InputNumber, Slider, App, Form } from 'antd'
|
||||
import { Row, Col, Space, Select, InputNumber, Slider, App, Form, Input } from 'antd'
|
||||
import clsx from 'clsx'
|
||||
|
||||
import Card from './components/Card'
|
||||
@@ -35,15 +35,15 @@ const keys = [
|
||||
/**
|
||||
* Configuration description component
|
||||
*/
|
||||
const ConfigDesc: FC<{ config: Variable, className?: string }> = ({config, className}) => {
|
||||
const ConfigDesc: FC<{ config: Variable, className?: string; onlyMeaning?: boolean; }> = ({ config, className, onlyMeaning = false}) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<div className={className}>
|
||||
<Space size={8} className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}>
|
||||
{!onlyMeaning && <Space size={8} className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}>
|
||||
{config.variableName && <span className="rb:font-regular">{t('memoryExtractionEngine.variableName')}: {config.variableName}</span>}
|
||||
{config.control && <span className="rb:font-regular">{t('memoryExtractionEngine.control')}: {t(`memoryExtractionEngine.${config.control}`)}</span>}
|
||||
{config.type && <span className="rb:font-regular">{t('memoryExtractionEngine.type')}: {config.type}</span>}
|
||||
</Space>
|
||||
</Space>}
|
||||
{config.meaning && <div className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}>{t('memoryExtractionEngine.Meaning')}: {t(`memoryExtractionEngine.${config.meaning}`)}</div>}
|
||||
</div>
|
||||
)
|
||||
@@ -253,6 +253,21 @@ const MemoryExtractionEngine: FC = () => {
|
||||
</div>
|
||||
</>
|
||||
}
|
||||
{config.control === 'text' &&
|
||||
<>
|
||||
<div className="rb:text-[14px] rb:font-medium rb:leading-5 rb:mt-6 rb:mb-2">
|
||||
-{t(`memoryExtractionEngine.${config.label}`)}
|
||||
</div>
|
||||
<div className="rb:pl-2">
|
||||
<Form.Item
|
||||
name={config.variableName}
|
||||
>
|
||||
<Input placeholder={t('common.pleaseEnter')} disabled />
|
||||
</Form.Item>
|
||||
<ConfigDesc config={config} onlyMeaning={true} className="rb:-mt-4!" />
|
||||
</div>
|
||||
</>
|
||||
}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 17:33:15
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-05 16:28:58
|
||||
* @Last Modified time: 2026-03-06 13:53:53
|
||||
*/
|
||||
/**
|
||||
* Memory Management Page
|
||||
@@ -154,10 +154,10 @@ const MemoryManagement: React.FC = () => {
|
||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
|
||||
onClick={() => handleEdit(item)}
|
||||
></div>
|
||||
<div
|
||||
{!item.is_system_default && <div
|
||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
||||
onClick={() => handleDelete(item)}
|
||||
></div>
|
||||
></div>}
|
||||
</Space>
|
||||
</div>
|
||||
</RbCard>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:49:45
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-04 11:50:47
|
||||
* @Last Modified time: 2026-03-06 12:26:12
|
||||
*/
|
||||
/**
|
||||
* Model List Detail Drawer
|
||||
@@ -144,7 +144,7 @@ const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({
|
||||
{item.name[0]}
|
||||
</div>
|
||||
}
|
||||
extra={<Switch defaultChecked={item.is_active} disabled={loading} onChange={() => handleChange(item)} />}
|
||||
extra={<Switch checked={item.is_active} disabled={loading} onChange={() => handleChange(item)} />}
|
||||
bodyClassName="rb:relative rb:pb-[64px]! rb:h-[calc(100%-64px)]!"
|
||||
>
|
||||
<Tooltip title={item.description}>
|
||||
@@ -153,7 +153,7 @@ const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({
|
||||
<div className="rb:absolute rb:bottom-4 rb:left-6 rb:right-6">
|
||||
<Row gutter={12}>
|
||||
<Col span={12}>
|
||||
<Button block onClick={() => handleEdit(item)}>{t('modelNew.modelConfiguration')}</Button>
|
||||
{!item.model_id && <Button block onClick={() => handleEdit(item)}>{t('modelNew.modelConfiguration')}</Button>}
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Button type="primary" ghost block onClick={() => handleKeyConfig(item)}>{t('modelNew.keyConfig')}</Button>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:50:18
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-04 11:39:20
|
||||
* @Last Modified time: 2026-03-06 12:26:11
|
||||
*/
|
||||
/**
|
||||
* Type definitions for Model Management
|
||||
@@ -121,6 +121,7 @@ export interface ModelApiKey {
|
||||
* Model list item data structure
|
||||
*/
|
||||
export interface ModelListItem {
|
||||
model_id?: string;
|
||||
/** Model name */
|
||||
model_name?: string;
|
||||
/** Associated model config IDs */
|
||||
|
||||
@@ -102,7 +102,7 @@ const Detail: FC = () => {
|
||||
<PageHeader
|
||||
name={<Space>
|
||||
{data.scene_name}
|
||||
<Tag color="warning">{t('common.default')}</Tag>
|
||||
{data.is_system_default ? <Tag color="warning">{t('common.default')}</Tag> : undefined}
|
||||
</Space>}
|
||||
subTitle={<Tooltip title={data.scene_description}><div className="rb:h-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{data.scene_description}</div></Tooltip>}
|
||||
extra={data.is_system_default ? undefined : (<Space>
|
||||
|
||||
@@ -25,6 +25,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
||||
const textContent = root.getTextContent();
|
||||
if (textContent !== prevValueRef.current) {
|
||||
isUserInputRef.current = true;
|
||||
prevValueRef.current = textContent;
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -33,7 +34,13 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
||||
}, [editor]);
|
||||
|
||||
useEffect(() => {
|
||||
if ((value !== prevValueRef.current || enableLineNumbers !== prevEnableLineNumbersRef.current) && !isUserInputRef.current) {
|
||||
if (value !== prevValueRef.current || enableLineNumbers !== prevEnableLineNumbersRef.current) {
|
||||
// Skip reset if the change was triggered by user input (avoid cursor jump)
|
||||
if (isUserInputRef.current && enableLineNumbers === prevEnableLineNumbersRef.current) {
|
||||
prevValueRef.current = value;
|
||||
isUserInputRef.current = false;
|
||||
return;
|
||||
}
|
||||
queueMicrotask(() => {
|
||||
editor.update(() => {
|
||||
const root = $getRoot();
|
||||
|
||||
@@ -35,7 +35,8 @@ const NODE_VARIABLES = {
|
||||
],
|
||||
'http-request': [
|
||||
{ label: 'body', dataType: 'string', field: 'body' },
|
||||
{ label: 'status_code', dataType: 'number', field: 'status_code' }
|
||||
{ label: 'status_code', dataType: 'number', field: 'status_code' },
|
||||
{ label: 'headers', dataType: 'object', field: 'headers' },
|
||||
],
|
||||
'question-classifier': [{ label: 'class_name', dataType: 'string', field: 'class_name' }],
|
||||
'memory-read': [
|
||||
@@ -390,11 +391,6 @@ export const useVariableList = (
|
||||
addVariable(list, keys, `${pid}_item`, 'item', itemType, `${pid}.item`, pd);
|
||||
addVariable(list, keys, `${pid}_index`, 'index', 'number', `${pid}.index`, pd);
|
||||
} else if (pd.type === 'iteration' && !pd.config.input.defaultValue) {
|
||||
let itemType = 'object';
|
||||
const iv = list.find(v => `{{${v.value}}}` === pd.config.input.defaultValue);
|
||||
if (iv?.dataType.startsWith('array[')) {
|
||||
itemType = iv.dataType.replace(/^array\[(.+)\]$/, '$1');
|
||||
}
|
||||
addVariable(list, keys, `${pid}_item`, 'item', 'string', `${pid}.item`, pd);
|
||||
addVariable(list, keys, `${pid}_index`, 'index', 'number', `${pid}.index`, pd);
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
initialValue[key] = config[key].defaultValue
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
form.setFieldsValue({
|
||||
type,
|
||||
id: selectedNode.id,
|
||||
@@ -114,16 +114,16 @@ const Properties: FC<PropertiesProps> = ({
|
||||
*/
|
||||
const updateNodeLabel = (newLabel: string) => {
|
||||
if (selectedNode && form) {
|
||||
const nodeData = selectedNode.data as NodeProperties;
|
||||
const nodeData = selectedNode.getData() as NodeProperties;
|
||||
selectedNode.setAttrByPath('text/text', `${nodeData.icon} ${newLabel}`);
|
||||
selectedNode.setData({ ...selectedNode.data, name: newLabel });
|
||||
selectedNode.setData({ ...selectedNode.getData(), name: newLabel });
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (values && selectedNode) {
|
||||
const { id, knowledge_retrieval, group, group_variables, ...rest } = values
|
||||
const { knowledge_bases = [], ...restKnowledgeConfig } = (knowledge_retrieval as any) || {}
|
||||
const { knowledge_bases = [], name: _name, description: _description, ...restKnowledgeConfig } = (knowledge_retrieval as any) || {}
|
||||
|
||||
let allRest = {
|
||||
...rest,
|
||||
@@ -136,21 +136,23 @@ const Properties: FC<PropertiesProps> = ({
|
||||
}))
|
||||
}
|
||||
|
||||
const nodeData = selectedNode.getData()
|
||||
|
||||
Object.keys(values).forEach(key => {
|
||||
if (selectedNode.data?.config?.[key]) {
|
||||
if (nodeData?.config?.[key]) {
|
||||
// Create a deep copy to avoid reference sharing between nodes
|
||||
if (!selectedNode.data.config[key]) {
|
||||
selectedNode.data.config[key] = {};
|
||||
if (!nodeData.config[key]) {
|
||||
nodeData.config[key] = {};
|
||||
}
|
||||
selectedNode.data.config[key] = {
|
||||
...selectedNode.data.config[key],
|
||||
nodeData.config[key] = {
|
||||
...nodeData.config[key],
|
||||
defaultValue: values[key]
|
||||
};
|
||||
}
|
||||
})
|
||||
|
||||
selectedNode?.setData({
|
||||
...selectedNode.data,
|
||||
...nodeData,
|
||||
...allRest,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -529,6 +529,10 @@ export const unknownNode = {
|
||||
type: 'unknown',
|
||||
icon: unknownIcon
|
||||
}
|
||||
export const noteNode = {
|
||||
type: 'notes',
|
||||
icon: unknownIcon
|
||||
}
|
||||
|
||||
export const nodeWidth = 240;
|
||||
/**
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 15:17:48
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-28 17:59:34
|
||||
* @Last Modified time: 2026-03-07 15:23:39
|
||||
*/
|
||||
import { useRef, useEffect, useState } from 'react';
|
||||
import { useParams } from 'react-router-dom';
|
||||
@@ -12,7 +12,7 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '
|
||||
import { register } from '@antv/x6-react-shape';
|
||||
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
||||
|
||||
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode } from '../constant';
|
||||
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, noteNode } from '../constant';
|
||||
import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types';
|
||||
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
|
||||
|
||||
@@ -128,7 +128,7 @@ export const useWorkflowGraph = ({
|
||||
if (nodes.length) {
|
||||
const nodeList = nodes.map(node => {
|
||||
const { id, type, name, position, config = {} } = node
|
||||
let nodeLibraryConfig = [...nodeLibrary, { nodes: [unknownNode] }]
|
||||
let nodeLibraryConfig = [...nodeLibrary, { nodes: [unknownNode, noteNode] }]
|
||||
.flatMap(category => category.nodes)
|
||||
.find(n => n.type === type)
|
||||
nodeLibraryConfig = JSON.parse(JSON.stringify({ config: {}, ...nodeLibraryConfig })) as NodeProperties
|
||||
@@ -715,6 +715,8 @@ export const useWorkflowGraph = ({
|
||||
panning: isHandMode,
|
||||
mousewheel: {
|
||||
enabled: true,
|
||||
factor: 0.1,
|
||||
modifiers: null,
|
||||
},
|
||||
connecting: {
|
||||
connector: {
|
||||
|
||||
Reference in New Issue
Block a user