Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
This commit is contained in:
@@ -60,14 +60,14 @@ def list_apps(
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
if ids is not None:
|
||||
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
|
||||
|
||||
# 正常分页查询
|
||||
items_orm, total = app_service.list_apps(
|
||||
db,
|
||||
|
||||
@@ -30,7 +30,7 @@ from sqlalchemy.orm import Session
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/emotion",
|
||||
prefix="/memory/emotion-memory",
|
||||
tags=["Emotion Analysis"],
|
||||
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ from app.services.memory_forget_service import MemoryForgetService
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/forget",
|
||||
prefix="/memory/forget-memory",
|
||||
tags=["Memory Forgetting Engine"],
|
||||
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
|
||||
)
|
||||
|
||||
@@ -842,7 +842,7 @@ async def run_hybrid_search(
|
||||
|
||||
if search_type in ["keyword", "hybrid"]:
|
||||
# Keyword-based search
|
||||
logger.info("Starting keyword search...")
|
||||
logger.info("[PERF] Starting keyword search...")
|
||||
keyword_start = time.time()
|
||||
keyword_task = asyncio.create_task(
|
||||
search_graph(
|
||||
@@ -856,7 +856,7 @@ async def run_hybrid_search(
|
||||
|
||||
if search_type in ["embedding", "hybrid"]:
|
||||
# Embedding-based search
|
||||
logger.info("Starting embedding search...")
|
||||
logger.info("[PERF] Starting embedding search...")
|
||||
embedding_start = time.time()
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
@@ -872,13 +872,13 @@ async def run_hybrid_search(
|
||||
type="llm"
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"Config loading took {config_load_time:.4f}s")
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
|
||||
# Init embedder
|
||||
embedder_init_start = time.time()
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"Embedder init took {embedder_init_time:.4f}s")
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
@@ -895,7 +895,7 @@ async def run_hybrid_search(
|
||||
keyword_results = await keyword_task
|
||||
keyword_latency = time.time() - keyword_start
|
||||
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
|
||||
logger.info(f"Keyword search completed in {keyword_latency:.4f}s")
|
||||
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
|
||||
if search_type == "keyword":
|
||||
results = keyword_results
|
||||
else:
|
||||
@@ -905,7 +905,7 @@ async def run_hybrid_search(
|
||||
embedding_results = await embedding_task
|
||||
embedding_latency = time.time() - embedding_start
|
||||
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
|
||||
logger.info(f"Embedding search completed in {embedding_latency:.4f}s")
|
||||
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
|
||||
if search_type == "embedding":
|
||||
results = embedding_results
|
||||
else:
|
||||
@@ -922,17 +922,21 @@ async def run_hybrid_search(
|
||||
|
||||
# Apply two-stage reranking with ACTR activation calculation
|
||||
rerank_start = time.time()
|
||||
logger.info("Using two-stage reranking with ACTR activation")
|
||||
logger.info("[PERF] Using two-stage reranking with ACTR activation")
|
||||
|
||||
# 加载遗忘引擎配置
|
||||
config_start = time.time()
|
||||
try:
|
||||
pc = get_pipeline_config(memory_config)
|
||||
forgetting_cfg = pc.forgetting_engine
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
config_time = time.time() - config_start
|
||||
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
|
||||
|
||||
# 统一使用激活度重排序(两阶段:检索 + ACTR计算)
|
||||
rerank_compute_start = time.time()
|
||||
reranked_results = rerank_with_activation(
|
||||
keyword_results=keyword_results,
|
||||
embedding_results=embedding_results,
|
||||
@@ -941,10 +945,12 @@ async def run_hybrid_search(
|
||||
forgetting_config=forgetting_cfg,
|
||||
activation_boost_factor=activation_boost_factor,
|
||||
)
|
||||
rerank_compute_time = time.time() - rerank_compute_start
|
||||
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
|
||||
|
||||
rerank_latency = time.time() - rerank_start
|
||||
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
|
||||
logger.info(f"Reranking completed in {rerank_latency:.4f}s")
|
||||
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
|
||||
|
||||
# Optional: apply reranker placeholder if enabled via config
|
||||
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
|
||||
@@ -985,8 +991,10 @@ async def run_hybrid_search(
|
||||
else:
|
||||
results["latency_metrics"] = latency_metrics
|
||||
|
||||
logger.info(f"Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"Latency breakdown: {latency_metrics}")
|
||||
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
|
||||
logger.info(f"[PERF] =========================================")
|
||||
|
||||
# Sanitize results: drop large/unused fields
|
||||
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
@@ -28,6 +29,118 @@ class MemorySummaryResponse(RobustLLMResponse):
|
||||
)
|
||||
|
||||
|
||||
async def generate_title_and_type_for_summary(
|
||||
content: str,
|
||||
llm_client
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
为MemorySummary生成标题和类型
|
||||
|
||||
此方法应该在创建MemorySummary节点时调用,生成title和type
|
||||
|
||||
Args:
|
||||
content: Summary的内容文本
|
||||
llm_client: LLM客户端实例
|
||||
|
||||
Returns:
|
||||
(标题, 类型)元组
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
|
||||
|
||||
# 定义有效的类型集合
|
||||
VALID_TYPES = {
|
||||
"conversation", # 对话
|
||||
"project_work", # 项目/工作
|
||||
"learning", # 学习
|
||||
"decision", # 决策
|
||||
"important_event" # 重要事件
|
||||
}
|
||||
DEFAULT_TYPE = "conversation" # 默认类型
|
||||
|
||||
try:
|
||||
if not content:
|
||||
logger.warning("content为空,无法生成标题和类型")
|
||||
return ("空内容", DEFAULT_TYPE)
|
||||
|
||||
# 1. 渲染Jinja2提示词模板
|
||||
prompt = await render_episodic_title_and_type_prompt(content)
|
||||
|
||||
# 2. 调用LLM生成标题和类型
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
# 3. 解析LLM响应
|
||||
content_response = response.content
|
||||
if isinstance(content_response, list):
|
||||
if len(content_response) > 0:
|
||||
if isinstance(content_response[0], dict):
|
||||
text = content_response[0].get('text', content_response[0].get('content', str(content_response[0])))
|
||||
full_response = str(text)
|
||||
else:
|
||||
full_response = str(content_response[0])
|
||||
else:
|
||||
full_response = ""
|
||||
elif isinstance(content_response, dict):
|
||||
full_response = str(content_response.get('text', content_response.get('content', str(content_response))))
|
||||
else:
|
||||
full_response = str(content_response) if content_response is not None else ""
|
||||
|
||||
# 4. 解析JSON响应
|
||||
try:
|
||||
# 尝试从响应中提取JSON
|
||||
# 移除可能的markdown代码块标记
|
||||
json_str = full_response.strip()
|
||||
if json_str.startswith("```json"):
|
||||
json_str = json_str[7:]
|
||||
if json_str.startswith("```"):
|
||||
json_str = json_str[3:]
|
||||
if json_str.endswith("```"):
|
||||
json_str = json_str[:-3]
|
||||
json_str = json_str.strip()
|
||||
|
||||
result_data = json.loads(json_str)
|
||||
title = result_data.get("title", "未知标题")
|
||||
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
|
||||
|
||||
# 5. 校验和归一化类型
|
||||
# 将类型转换为小写并去除空格
|
||||
episodic_type_normalized = str(episodic_type_raw).lower().strip()
|
||||
|
||||
# 检查是否在有效类型集合中
|
||||
if episodic_type_normalized in VALID_TYPES:
|
||||
episodic_type = episodic_type_normalized
|
||||
else:
|
||||
# 尝试映射常见的中文类型到英文
|
||||
type_mapping = {
|
||||
"对话": "conversation",
|
||||
"项目": "project_work",
|
||||
"工作": "project_work",
|
||||
"项目/工作": "project_work",
|
||||
"学习": "learning",
|
||||
"决策": "decision",
|
||||
"重要事件": "important_event",
|
||||
"事件": "important_event"
|
||||
}
|
||||
episodic_type = type_mapping.get(episodic_type_raw, DEFAULT_TYPE)
|
||||
logger.warning(
|
||||
f"LLM返回的类型 '{episodic_type_raw}' 不在有效集合中,"
|
||||
f"已归一化为 '{episodic_type}'"
|
||||
)
|
||||
|
||||
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
|
||||
return (title, episodic_type)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无法解析LLM响应为JSON: {full_response}")
|
||||
return ("解析失败", DEFAULT_TYPE)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
|
||||
return ("错误", DEFAULT_TYPE)
|
||||
|
||||
async def _process_chunk_summary(
|
||||
dialog: DialogData,
|
||||
chunk,
|
||||
@@ -63,10 +176,9 @@ async def _process_chunk_summary(
|
||||
title = None
|
||||
episodic_type = None
|
||||
try:
|
||||
from app.services.user_memory_service import UserMemoryService
|
||||
title, episodic_type = await UserMemoryService.generate_title_and_type_for_summary(
|
||||
title, episodic_type = await generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
end_user_id=dialog.group_id
|
||||
llm_client=llm_client
|
||||
)
|
||||
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -8,14 +8,16 @@ Classes:
|
||||
AccessHistoryManager: 访问历史管理器,提供并发安全的访问记录和一致性检查
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
|
||||
ACTRCalculator,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -188,30 +190,43 @@ class AccessHistoryManager:
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 成功更新的节点列表
|
||||
"""
|
||||
import time
|
||||
batch_start = time.time()
|
||||
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
|
||||
tasks = []
|
||||
for node_id in node_ids:
|
||||
task = self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id,
|
||||
current_time=current_time
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all tasks in parallel
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Collect successful results and count failures
|
||||
results = []
|
||||
failed_count = 0
|
||||
|
||||
for node_id in node_ids:
|
||||
try:
|
||||
updated_node = await self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id,
|
||||
current_time=current_time
|
||||
)
|
||||
results.append(updated_node)
|
||||
except Exception as e:
|
||||
for node_id, result in zip(node_ids, task_results):
|
||||
if isinstance(result, Exception):
|
||||
failed_count += 1
|
||||
logger.warning(
|
||||
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(result)}"
|
||||
)
|
||||
else:
|
||||
results.append(result)
|
||||
|
||||
batch_duration = time.time() - batch_start
|
||||
logger.info(
|
||||
f"批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
|
||||
f"失败 {failed_count}"
|
||||
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
|
||||
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -531,7 +546,10 @@ class AccessHistoryManager:
|
||||
Dict[str, Any]: 更新数据,包含所有需要更新的字段
|
||||
"""
|
||||
access_history = node_data.get('access_history') or []
|
||||
importance_score = node_data.get('importance_score', 0.5)
|
||||
# Handle None importance_score - default to 0.5
|
||||
importance_score = node_data.get('importance_score')
|
||||
if importance_score is None:
|
||||
importance_score = 0.5
|
||||
|
||||
# 追加新的访问时间
|
||||
new_access_history = access_history + [current_time_iso]
|
||||
@@ -620,34 +638,52 @@ class AccessHistoryManager:
|
||||
new_version = current_version + 1
|
||||
|
||||
# 步骤2:使用乐观锁更新节点
|
||||
# 只有当版本号匹配时才更新
|
||||
update_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
# 根据节点类型构建完整的查询语句
|
||||
content_field_map = {
|
||||
'Statement': 'n.statement as statement',
|
||||
'MemorySummary': 'n.content as content',
|
||||
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤
|
||||
}
|
||||
|
||||
# 显式检查节点类型,不支持的类型抛出错误
|
||||
if node_label not in content_field_map:
|
||||
raise ValueError(
|
||||
f"Unsupported node_label: {node_label}. "
|
||||
f"Supported labels are: {list(content_field_map.keys())}"
|
||||
)
|
||||
|
||||
content_field = content_field_map[node_label]
|
||||
|
||||
# 构建 WHERE 子句
|
||||
where_conditions = []
|
||||
if group_id:
|
||||
update_query += " WHERE n.group_id = $group_id"
|
||||
where_conditions.append("n.group_id = $group_id")
|
||||
|
||||
# 添加版本检查
|
||||
if current_version > 0:
|
||||
update_query += " AND n.version = $current_version"
|
||||
where_conditions.append("n.version = $current_version")
|
||||
else:
|
||||
# 如果节点没有版本号,检查是否为首次更新
|
||||
update_query += " AND (n.version IS NULL OR n.version = 0)"
|
||||
where_conditions.append("(n.version IS NULL OR n.version = 0)")
|
||||
|
||||
update_query += """
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "true"
|
||||
|
||||
# 构建完整的更新查询
|
||||
update_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
WHERE {where_clause}
|
||||
SET n.activation_value = $activation_value,
|
||||
n.access_history = $access_history,
|
||||
n.last_access_time = $last_access_time,
|
||||
n.access_count = $access_count,
|
||||
n.version = $new_version
|
||||
RETURN n.id as id,
|
||||
n.statement as statement,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score,
|
||||
n.version as version
|
||||
n.version as version,
|
||||
{content_field}
|
||||
"""
|
||||
|
||||
update_params = {
|
||||
@@ -671,7 +707,11 @@ class AccessHistoryManager:
|
||||
f"Expected version {current_version}, but node was modified by another transaction."
|
||||
)
|
||||
|
||||
return dict(updated_node)
|
||||
# 转换为字典并移除占位符字段
|
||||
result_dict = dict(updated_node)
|
||||
result_dict.pop('content_placeholder', None)
|
||||
|
||||
return result_dict
|
||||
|
||||
# 执行事务
|
||||
try:
|
||||
|
||||
@@ -260,17 +260,32 @@ class ForgettingStrategy:
|
||||
)
|
||||
|
||||
# 生成标题和类型(使用LLM)
|
||||
from app.services.user_memory_service import UserMemoryService
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import generate_title_and_type_for_summary
|
||||
|
||||
# 获取 LLM 客户端
|
||||
llm_client = None
|
||||
if config_id is not None and db is not None:
|
||||
try:
|
||||
llm_client = await self._get_llm_client(db, config_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 LLM 客户端失败: {str(e)}")
|
||||
|
||||
# 生成标题和类型
|
||||
try:
|
||||
title, episodic_type = await UserMemoryService.generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
end_user_id=group_id
|
||||
)
|
||||
logger.info(f"成功为MemorySummary生成标题和类型: title={title}, type={episodic_type}")
|
||||
if llm_client is not None:
|
||||
title, episodic_type = await generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
llm_client=llm_client
|
||||
)
|
||||
logger.info(f"成功为MemorySummary生成标题和类型: title={title}, type={episodic_type}")
|
||||
else:
|
||||
logger.warning("LLM 客户端不可用,使用默认标题和类型")
|
||||
title = "未命名"
|
||||
episodic_type = "conversation"
|
||||
except Exception as e:
|
||||
logger.error(f"生成标题和类型失败,使用默认值: {str(e)}")
|
||||
title = "未命名"
|
||||
episodic_type = "其他"
|
||||
episodic_type = "conversation"
|
||||
|
||||
# 计算继承的激活值和重要性(取较高值)
|
||||
inherited_activation = max(statement_activation, entity_activation)
|
||||
|
||||
@@ -3,13 +3,11 @@
|
||||
|
||||
基于 LangGraph 的工作流执行引擎。
|
||||
"""
|
||||
|
||||
# import uuid
|
||||
import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
@@ -55,6 +53,12 @@ class WorkflowExecutor:
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
self.checkpoint_config = {
|
||||
"configurable": {
|
||||
"thread_id": uuid.uuid4(),
|
||||
}
|
||||
}
|
||||
|
||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||
"""准备初始状态(注入系统变量和会话变量)
|
||||
|
||||
@@ -95,7 +99,7 @@ class WorkflowExecutor:
|
||||
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
|
||||
conversation_vars[var_name] = []
|
||||
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||
|
||||
conversation_vars = conversation_vars | input_data.get("conv", {})
|
||||
# 构建分层的变量结构
|
||||
variables = {
|
||||
"sys": {
|
||||
@@ -110,7 +114,7 @@ class WorkflowExecutor:
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [HumanMessage(content=user_message)],
|
||||
"messages": [('user', user_message)],
|
||||
"variables": variables,
|
||||
"node_outputs": {},
|
||||
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||
@@ -196,6 +200,28 @@ class WorkflowExecutor:
|
||||
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
|
||||
return prefixes, adjacent_and_referenced
|
||||
|
||||
def _build_final_output(self, result, elapsed_time):
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
conversation_id = None
|
||||
for node_id, node_output in node_outputs.items():
|
||||
if node_output.get("node_type") == "start":
|
||||
conversation_id = node_output.get("output", {}).get("conversation_id")
|
||||
break
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
"variables": result.get("variables", {}),
|
||||
}
|
||||
|
||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
@@ -236,40 +262,16 @@ class WorkflowExecutor:
|
||||
|
||||
# 3. 执行工作流
|
||||
try:
|
||||
result = await graph.ainvoke(initial_state)
|
||||
|
||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
# 提取节点输出(现在包含 start 和 end 节点)
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
|
||||
# 提取最终输出(从最后一个非 start/end 节点)
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
|
||||
# 聚合 token 使用情况
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
|
||||
# 提取 conversation_id(从 start 节点输出)
|
||||
conversation_id = None
|
||||
for node_id, node_output in node_outputs.items():
|
||||
if node_output.get("node_type") == "start":
|
||||
conversation_id = node_output.get("output", {}).get("conversation_id")
|
||||
break
|
||||
|
||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error")
|
||||
}
|
||||
return self._build_final_output(result, elapsed_time)
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
@@ -331,11 +333,11 @@ class WorkflowExecutor:
|
||||
# 3. Execute workflow
|
||||
try:
|
||||
chunk_count = 0
|
||||
final_state = None
|
||||
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||
config=self.checkpoint_config
|
||||
):
|
||||
# event should be a tuple: (mode, data)
|
||||
# But let's handle both cases
|
||||
@@ -411,12 +413,11 @@ class WorkflowExecutor:
|
||||
elif mode == "updates":
|
||||
# Handle state updates - store final state
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
|
||||
final_state = data
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s"
|
||||
@@ -425,12 +426,7 @@ class WorkflowExecutor:
|
||||
# 发送 workflow_end 事件
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"status": "completed",
|
||||
"elapsed_time": elapsed_time,
|
||||
"timestamp": end_time.isoformat()
|
||||
}
|
||||
"data": self._build_final_output(result, elapsed_time)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||
from langgraph.graph import START, END
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
@@ -249,4 +250,5 @@ class GraphBuilder:
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.add_edges() # 添加边必须在添加节点之后
|
||||
return self.graph.compile()
|
||||
checkpointer = InMemorySaver()
|
||||
return self.graph.compile(checkpointer=checkpointer)
|
||||
|
||||
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||
class AssignerNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
self.typed_config: AssignerNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
@@ -28,6 +28,7 @@ class AssignerNode(BaseNode):
|
||||
None or the result of the assignment operation.
|
||||
"""
|
||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} 开始执行")
|
||||
pool = VariablePool(state)
|
||||
for assignment in self.typed_config.assignments:
|
||||
|
||||
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[AnyMessage], add]
|
||||
messages: Annotated[list[tuple[str, str]], add]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
@@ -203,6 +203,7 @@ class BaseNode(ABC):
|
||||
# 返回包装后的输出和运行时变量
|
||||
return {
|
||||
**wrapped_output,
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
},
|
||||
@@ -355,6 +356,7 @@ class BaseNode(ABC):
|
||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
},
|
||||
|
||||
@@ -30,7 +30,6 @@ class CycleGraphNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
|
||||
|
||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
||||
|
||||
@@ -32,7 +32,7 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||
self.typed_config: HttpRequestNodeConfig | None = None
|
||||
|
||||
def _build_timeout(self) -> Timeout:
|
||||
"""
|
||||
@@ -181,6 +181,7 @@ class HttpRequestNode(BaseNode):
|
||||
- dict: Serialized HttpRequestNodeOutput on success
|
||||
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
||||
"""
|
||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||
async with httpx.AsyncClient(
|
||||
verify=self.typed_config.verify_ssl,
|
||||
timeout=self._build_timeout(),
|
||||
|
||||
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
||||
class IfElseNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = IfElseNodeConfig(**self.config)
|
||||
self.typed_config: IfElseNodeConfig | None= None
|
||||
|
||||
@staticmethod
|
||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||
@@ -109,6 +109,7 @@ class IfElseNode(BaseNode):
|
||||
Returns:
|
||||
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
||||
"""
|
||||
self.typed_config = IfElseNodeConfig(**self.config)
|
||||
expressions = self.evaluate_conditional_edge_expressions(state)
|
||||
# TODO: 变量类型及文本类型解析
|
||||
for i in range(len(expressions)):
|
||||
|
||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
class JinjaRenderNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = JinjaRenderNodeConfig(**self.config)
|
||||
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
@@ -34,6 +34,7 @@ class JinjaRenderNode(BaseNode):
|
||||
RuntimeError: If Jinja2 template rendering fails due to invalid template
|
||||
syntax or missing variables.
|
||||
"""
|
||||
self.typed_config = JinjaRenderNodeConfig(**self.config)
|
||||
render = TemplateRenderer(strict=False)
|
||||
|
||||
context = {}
|
||||
|
||||
@@ -44,8 +44,8 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
description="Knowledge base config"
|
||||
)
|
||||
|
||||
reranker_id: UUID = Field(
|
||||
default="",
|
||||
reranker_id: UUID | None = Field(
|
||||
default=None,
|
||||
description="Reranker top k"
|
||||
)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
|
||||
@staticmethod
|
||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||
@@ -171,6 +171,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
Raises:
|
||||
RuntimeError: If no valid knowledge base is found or access is denied.
|
||||
"""
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
query = self._render_template(self.typed_config.query, state)
|
||||
with get_db_read() as db:
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
|
||||
@@ -68,7 +68,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
self.typed_config: LLMNodeConfig | None = None
|
||||
|
||||
def _render_context(self, message, state):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||
@@ -164,6 +164,7 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
@@ -10,9 +10,10 @@ from app.services.memory_agent_service import MemoryAgentService
|
||||
class MemoryReadNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
self.typed_config: MemoryReadNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
with get_db_read() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
class ParameterExtractorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = ParameterExtractorNodeConfig(**self.config)
|
||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||
|
||||
@staticmethod
|
||||
def _get_prompt():
|
||||
@@ -145,6 +145,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
Raises:
|
||||
BusinessException: If LLM output cannot be parsed as valid JSON.
|
||||
"""
|
||||
self.typed_config = ParameterExtractorNodeConfig(**self.config)
|
||||
llm = self._get_llm_instance()
|
||||
system_prompt, user_prompt = self._get_prompt()
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ class QuestionClassifierNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||
self.category_to_case_map = self._build_category_case_map()
|
||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||
self.category_to_case_map = {}
|
||||
|
||||
def _get_llm_instance(self) -> RedBearLLM:
|
||||
"""获取LLM实例"""
|
||||
@@ -67,6 +67,8 @@ class QuestionClassifierNode(BaseNode):
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict:
|
||||
"""执行问题分类"""
|
||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||
self.category_to_case_map = self._build_category_case_map()
|
||||
question = self.typed_config.input_variable
|
||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||
categories = self.typed_config.categories or []
|
||||
|
||||
@@ -7,6 +7,7 @@ Start 节点实现
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
|
||||
@@ -34,7 +35,7 @@ class StartNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
|
||||
# 解析并验证配置
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
self.typed_config: StartNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行 start 节点业务逻辑
|
||||
@@ -47,6 +48,7 @@ class StartNode(BaseNode):
|
||||
Returns:
|
||||
包含系统参数、会话变量和自定义变量的字典
|
||||
"""
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||
|
||||
# 创建变量池实例(在方法内复用)
|
||||
@@ -113,6 +115,18 @@ class StartNode(BaseNode):
|
||||
logger.debug(
|
||||
f"变量 '{var_name}' 使用默认值: {var_def.default}"
|
||||
)
|
||||
else:
|
||||
match var_def.type:
|
||||
case VariableType.STRING:
|
||||
processed[var_name] = ""
|
||||
case VariableType.NUMBER:
|
||||
processed[var_name] = 0
|
||||
case VariableType.OBJECT:
|
||||
processed[var_name] = {}
|
||||
case VariableType.BOOLEAN:
|
||||
processed[var_name] = False
|
||||
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
|
||||
processed[var_name] = []
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
@@ -19,10 +19,11 @@ class ToolNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = ToolNodeConfig(**self.config)
|
||||
self.typed_config: ToolNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行工具"""
|
||||
self.typed_config = ToolNodeConfig(**self.config)
|
||||
# 获取租户ID和用户ID
|
||||
tenant_id = self.get_variable("sys.tenant_id", state)
|
||||
user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
class VariableAggregatorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = VariableAggregatorNodeConfig(**self.config)
|
||||
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||
|
||||
@staticmethod
|
||||
def _get_express(variable_string: str) -> Any:
|
||||
@@ -37,6 +37,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
- str: In non-group mode, returns the first non-None variable value.
|
||||
- dict: In group mode, returns a mapping of group_name -> first non-None variable value.
|
||||
"""
|
||||
self.typed_config = VariableAggregatorNodeConfig(**self.config)
|
||||
if not self.typed_config.group:
|
||||
# --------------------------
|
||||
# Non-group mode
|
||||
|
||||
@@ -66,24 +66,38 @@ async def _update_activation_values_batch(
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
# 提取节点ID列表
|
||||
node_ids = [node.get('id') for node in nodes if node.get('id')]
|
||||
# 提取节点ID列表并去重(保持原始顺序)
|
||||
seen_ids = set()
|
||||
unique_node_ids = []
|
||||
for node in nodes:
|
||||
node_id = node.get('id')
|
||||
if node_id and node_id not in seen_ids:
|
||||
seen_ids.add(node_id)
|
||||
unique_node_ids.append(node_id)
|
||||
|
||||
if not node_ids:
|
||||
if not unique_node_ids:
|
||||
logger.warning(f"批量更新激活值:没有有效的节点ID")
|
||||
return nodes
|
||||
|
||||
# 记录去重信息(仅针对具有有效 ID 的节点)
|
||||
id_nodes_count = sum(1 for n in nodes if n.get("id"))
|
||||
if len(unique_node_ids) < id_nodes_count:
|
||||
logger.info(
|
||||
f"批量更新激活值:检测到重复节点,具有有效ID的节点数量={id_nodes_count}, "
|
||||
f"去重后唯一ID数量={len(unique_node_ids)}"
|
||||
)
|
||||
|
||||
# 批量记录访问
|
||||
try:
|
||||
updated_nodes = await access_manager.record_batch_access(
|
||||
node_ids=node_ids,
|
||||
node_ids=unique_node_ids,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"批量更新激活值成功: {node_label}, "
|
||||
f"更新数量={len(updated_nodes)}/{len(node_ids)}"
|
||||
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
|
||||
)
|
||||
|
||||
return updated_nodes
|
||||
@@ -153,19 +167,38 @@ async def _update_search_results_activation(
|
||||
original_nodes = results[key]
|
||||
updated_nodes = update_result
|
||||
|
||||
# 创建 ID 到原始节点的映射(用于快速查找 score)
|
||||
original_map = {node.get('id'): node for node in original_nodes if node.get('id')}
|
||||
# 创建 ID 到更新节点的映射(用于快速查找激活值数据)
|
||||
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
|
||||
|
||||
# 合并数据:激活值来自更新结果,score 来自原始结果
|
||||
# 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充
|
||||
merged_nodes = []
|
||||
for updated_node in updated_nodes:
|
||||
node_id = updated_node.get('id')
|
||||
if node_id and node_id in original_map:
|
||||
# 保留原始的 score 字段
|
||||
original_score = original_map[node_id].get('score')
|
||||
if original_score is not None:
|
||||
updated_node['score'] = original_score
|
||||
merged_nodes.append(updated_node)
|
||||
for original_node in original_nodes:
|
||||
node_id = original_node.get('id')
|
||||
if node_id and node_id in updated_map:
|
||||
# 从原始节点开始,用更新后的激活值数据覆盖
|
||||
merged_node = original_node.copy()
|
||||
|
||||
# 更新激活值相关字段
|
||||
activation_fields = {
|
||||
'activation_value',
|
||||
'access_history',
|
||||
'last_access_time',
|
||||
'access_count',
|
||||
'importance_score',
|
||||
'version',
|
||||
'statement', # Statement 节点的内容字段
|
||||
'content' # MemorySummary 节点的内容字段
|
||||
}
|
||||
|
||||
# 只更新激活值相关字段,保留原始节点的其他字段
|
||||
for field in activation_fields:
|
||||
if field in updated_map[node_id]:
|
||||
merged_node[field] = updated_map[node_id][field]
|
||||
|
||||
merged_nodes.append(merged_node)
|
||||
else:
|
||||
# 如果没有更新数据,保留原始节点
|
||||
merged_nodes.append(original_node)
|
||||
|
||||
updated_results[key] = merged_nodes
|
||||
else:
|
||||
|
||||
@@ -15,6 +15,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -22,6 +23,7 @@ from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
@@ -101,6 +103,14 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
)
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.read_message",
|
||||
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
result = task_service.get_task_memory_read_result(task.id)
|
||||
status = result.get("status")
|
||||
logger.info(f"读取任务状态:{status}")
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
logger.info(f'用户ID:Agent:{end_user_id}')
|
||||
|
||||
@@ -456,23 +456,36 @@ class MemoryAgentService:
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
|
||||
async with client.session('data_flow') as session:
|
||||
session_start = time.time()
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
|
||||
tools_start = time.time()
|
||||
tools = await load_mcp_tools(session)
|
||||
tools_time = time.time() - tools_start
|
||||
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
|
||||
|
||||
outputs = []
|
||||
intermediate_outputs = []
|
||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
||||
|
||||
# Pass memory_config to the graph workflow
|
||||
graph_start = time.time()
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
graph_init_time = time.time() - graph_start
|
||||
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
|
||||
|
||||
start = time.time()
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
|
||||
event_count = 0
|
||||
async for event in graph.astream(
|
||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
event_count += 1
|
||||
event_start = time.time()
|
||||
messages = event.get('messages')
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
@@ -525,9 +538,15 @@ class MemoryAgentService:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||
|
||||
event_time = time.time() - event_start
|
||||
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
|
||||
|
||||
workflow_duration = time.time() - start
|
||||
logger.info(f"Read graph workflow completed in {workflow_duration}s")
|
||||
session_duration = time.time() - session_start
|
||||
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
|
||||
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
|
||||
logger.info(f"[PERF] Total events processed: {event_count}")
|
||||
# Extract final answer
|
||||
final_answer = ""
|
||||
for messages in outputs:
|
||||
@@ -1186,8 +1205,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
ValueError: 当终端用户不存在或应用未发布时
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.end_user_model import EndUser
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
@@ -1266,8 +1285,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
对于查询失败的用户,value 包含 error 字段
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.end_user_model import EndUser
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Optional
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -109,3 +110,188 @@ class MemoryBaseService:
|
||||
except Exception as e:
|
||||
logger.error(f"提取情景记忆情绪时出错: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def get_episodic_memory_count(
|
||||
self,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
获取情景记忆数量
|
||||
|
||||
查询 MemorySummary 节点的数量。
|
||||
|
||||
Args:
|
||||
end_user_id: 可选的终端用户ID,用于过滤特定用户的节点
|
||||
|
||||
Returns:
|
||||
情景记忆的数量
|
||||
"""
|
||||
try:
|
||||
if end_user_id:
|
||||
query = """
|
||||
MATCH (n:MemorySummary)
|
||||
WHERE n.group_id = $group_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await self.neo4j_connector.execute_query(query, group_id=end_user_id)
|
||||
else:
|
||||
query = """
|
||||
MATCH (n:MemorySummary)
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await self.neo4j_connector.execute_query(query)
|
||||
|
||||
count = result[0]["count"] if result and len(result) > 0 else 0
|
||||
logger.debug(f"情景记忆数量: {count} (end_user_id={end_user_id})")
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取情景记忆数量时出错: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
|
||||
async def get_explicit_memory_count(
|
||||
self,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
获取显性记忆数量
|
||||
|
||||
显性记忆 = 情景记忆(MemorySummary)+ 语义记忆(ExtractedEntity with is_explicit_memory=true)
|
||||
|
||||
Args:
|
||||
end_user_id: 可选的终端用户ID,用于过滤特定用户的节点
|
||||
|
||||
Returns:
|
||||
显性记忆的数量
|
||||
"""
|
||||
try:
|
||||
# 1. 获取情景记忆数量
|
||||
episodic_count = await self.get_episodic_memory_count(end_user_id)
|
||||
|
||||
# 2. 获取语义记忆数量(ExtractedEntity 且 is_explicit_memory = true)
|
||||
if end_user_id:
|
||||
semantic_query = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.group_id = $group_id AND e.is_explicit_memory = true
|
||||
RETURN count(e) as count
|
||||
"""
|
||||
semantic_result = await self.neo4j_connector.execute_query(
|
||||
semantic_query,
|
||||
group_id=end_user_id
|
||||
)
|
||||
else:
|
||||
semantic_query = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.is_explicit_memory = true
|
||||
RETURN count(e) as count
|
||||
"""
|
||||
semantic_result = await self.neo4j_connector.execute_query(semantic_query)
|
||||
|
||||
semantic_count = semantic_result[0]["count"] if semantic_result and len(semantic_result) > 0 else 0
|
||||
|
||||
# 3. 计算总数
|
||||
explicit_count = episodic_count + semantic_count
|
||||
logger.debug(
|
||||
f"显性记忆数量: {explicit_count} "
|
||||
f"(情景={episodic_count}, 语义={semantic_count}, end_user_id={end_user_id})"
|
||||
)
|
||||
return explicit_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取显性记忆数量时出错: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
|
||||
async def get_emotional_memory_count(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
statement_count_fallback: int = 0
|
||||
) -> int:
|
||||
"""
|
||||
获取情绪记忆数量
|
||||
|
||||
通过 EmotionAnalyticsService 获取情绪标签统计总数。
|
||||
如果获取失败或没有指定 end_user_id,使用 statement_count_fallback 作为后备。
|
||||
|
||||
Args:
|
||||
end_user_id: 可选的终端用户ID
|
||||
statement_count_fallback: 后备方案的数量(通常是 statement 节点数量)
|
||||
|
||||
Returns:
|
||||
情绪记忆的数量
|
||||
"""
|
||||
try:
|
||||
if end_user_id:
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
|
||||
emotion_data = await emotion_service.get_emotion_tags(
|
||||
end_user_id=end_user_id,
|
||||
emotion_type=None,
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
limit=10
|
||||
)
|
||||
emotion_count = emotion_data.get("total_count", 0)
|
||||
logger.debug(f"情绪记忆数量: {emotion_count} (end_user_id={end_user_id})")
|
||||
return emotion_count
|
||||
else:
|
||||
# 如果没有指定 end_user_id,使用后备方案
|
||||
logger.debug(f"情绪记忆数量: {statement_count_fallback} (使用后备方案)")
|
||||
return statement_count_fallback
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取情绪记忆数量失败,使用后备方案: {str(e)}")
|
||||
return statement_count_fallback
|
||||
|
||||
async def get_forget_memory_count(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
forgetting_threshold: float = 0.3
|
||||
) -> int:
|
||||
"""
|
||||
获取遗忘记忆数量
|
||||
|
||||
统计激活值低于遗忘阈值的节点数量(low_activation_nodes)。
|
||||
查询范围包括:Statement、ExtractedEntity、MemorySummary、Chunk 节点。
|
||||
|
||||
Args:
|
||||
end_user_id: 可选的终端用户ID,用于过滤特定用户的节点
|
||||
forgetting_threshold: 遗忘阈值,默认 0.3
|
||||
|
||||
Returns:
|
||||
遗忘记忆的数量(激活值低于阈值的节点数)
|
||||
"""
|
||||
try:
|
||||
# 构建查询语句
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||
"""
|
||||
|
||||
if end_user_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
|
||||
query += """
|
||||
RETURN sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
|
||||
"""
|
||||
|
||||
# 设置查询参数
|
||||
params = {'threshold': forgetting_threshold}
|
||||
if end_user_id:
|
||||
params['group_id'] = end_user_id
|
||||
|
||||
# 执行查询
|
||||
result = await self.neo4j_connector.execute_query(query, **params)
|
||||
|
||||
# 提取结果
|
||||
forget_count = result[0]['low_activation_nodes'] if result and len(result) > 0 else 0
|
||||
forget_count = forget_count or 0 # 处理 None 值
|
||||
|
||||
logger.debug(
|
||||
f"遗忘记忆数量: {forget_count} "
|
||||
f"(threshold={forgetting_threshold}, end_user_id={end_user_id})"
|
||||
)
|
||||
return forget_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取遗忘记忆数量时出错: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
|
||||
@@ -401,5 +401,5 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
raise
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# 创建全局服务实例(供控制器层使用)
|
||||
memory_episodic_service = MemoryEpisodicService()
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_base_service import MemoryBaseService
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -1195,17 +1196,18 @@ async def analytics_memory_types(
|
||||
end_user_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
统计8种记忆类型的数量和百分比
|
||||
统计9种记忆类型的数量和百分比
|
||||
|
||||
计算规则:
|
||||
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity
|
||||
2. 工作记忆 (WORKING_MEMORY) = chunk + entity
|
||||
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk
|
||||
4. 长期记忆 (LONG_TERM_MEMORY) = entity
|
||||
5. 显性记忆 (EXPLICIT_MEMORY) = 1/2 * entity
|
||||
5. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
||||
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity
|
||||
7. 情绪记忆 (EMOTIONAL_MEMORY) = statement
|
||||
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary
|
||||
7. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
||||
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
||||
9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
@@ -1230,13 +1232,16 @@ async def analytics_memory_types(
|
||||
- IMPLICIT_MEMORY: 隐性记忆
|
||||
- EMOTIONAL_MEMORY: 情绪记忆
|
||||
- EPISODIC_MEMORY: 情景记忆
|
||||
- FORGET_MEMORY: 遗忘记忆
|
||||
"""
|
||||
# 定义需要查询的节点类型
|
||||
# 初始化基础服务
|
||||
base_service = MemoryBaseService()
|
||||
|
||||
# 定义需要查询的基础节点类型
|
||||
node_types = {
|
||||
"Statement": "Statement",
|
||||
"Entity": "ExtractedEntity",
|
||||
"Chunk": "Chunk",
|
||||
"MemorySummary": "MemorySummary"
|
||||
"Chunk": "Chunk"
|
||||
}
|
||||
|
||||
# 存储每种节点类型的计数
|
||||
@@ -1266,18 +1271,45 @@ async def analytics_memory_types(
|
||||
statement_count = node_counts.get("Statement", 0)
|
||||
entity_count = node_counts.get("Entity", 0)
|
||||
chunk_count = node_counts.get("Chunk", 0)
|
||||
memory_summary_count = node_counts.get("MemorySummary", 0)
|
||||
|
||||
# 按规则计算8种记忆类型的数量(使用英文枚举作为key)
|
||||
# 获取用户的遗忘阈值配置
|
||||
forgetting_threshold = 0.3 # 默认值
|
||||
if end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.core.memory.storage_services.forgetting_engine.config_utils import load_actr_config_from_db
|
||||
|
||||
# 获取用户关联的 config_id
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get('memory_config_id')
|
||||
|
||||
if config_id:
|
||||
# 从数据库加载配置
|
||||
config = load_actr_config_from_db(db, config_id)
|
||||
forgetting_threshold = config.get('forgetting_threshold', 0.3)
|
||||
logger.debug(f"使用用户配置的遗忘阈值: {forgetting_threshold} (end_user_id={end_user_id}, config_id={config_id})")
|
||||
else:
|
||||
logger.debug(f"用户未关联配置,使用默认遗忘阈值: {forgetting_threshold} (end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取用户遗忘阈值配置失败,使用默认值 {forgetting_threshold}: {str(e)}")
|
||||
|
||||
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
|
||||
episodic_count = await base_service.get_episodic_memory_count(end_user_id)
|
||||
explicit_count = await base_service.get_explicit_memory_count(end_user_id)
|
||||
emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count)
|
||||
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
|
||||
|
||||
# 按规则计算9种记忆类型的数量(使用英文枚举作为key)
|
||||
memory_counts = {
|
||||
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆
|
||||
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆
|
||||
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆
|
||||
"LONG_TERM_MEMORY": entity_count, # 长期记忆
|
||||
"EXPLICIT_MEMORY": entity_count // 2, # 显性记忆 (1/2 entity)
|
||||
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
|
||||
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity)
|
||||
"EMOTIONAL_MEMORY": statement_count, # 情绪记忆
|
||||
"EPISODIC_MEMORY": memory_summary_count # 情景记忆
|
||||
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
|
||||
"EPISODIC_MEMORY": episodic_count, # 情景记忆
|
||||
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
|
||||
}
|
||||
|
||||
# 计算总数
|
||||
|
||||
@@ -491,6 +491,17 @@ class WorkflowService:
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
if exec_res.status == "completed":
|
||||
last_state = exec_res.output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
break
|
||||
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
@@ -504,7 +515,7 @@ class WorkflowService:
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result.get("node_outputs", {})
|
||||
output_data=result
|
||||
)
|
||||
else:
|
||||
self.update_execution_status(
|
||||
@@ -517,6 +528,7 @@ class WorkflowService:
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"variables": result.get("variables"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
@@ -617,6 +629,16 @@ class WorkflowService:
|
||||
original_user_id=payload.user_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
if exec_res.status == "completed":
|
||||
last_state = exec_res.output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
break
|
||||
|
||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||
async for event in self._run_workflow_stream(
|
||||
@@ -827,6 +849,23 @@ class WorkflowService:
|
||||
user_id=user_id
|
||||
):
|
||||
# 直接转发事件(executor 已经返回正确格式)
|
||||
if event.get("event") == "workflow_end":
|
||||
|
||||
status = event.get("data", {}).get("status")
|
||||
if status == "completed":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"completed",
|
||||
output_data=event.get("data")
|
||||
)
|
||||
elif status == "failed":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"failed",
|
||||
output_data=event.get("data")
|
||||
)
|
||||
else:
|
||||
logger.error(f"unexpect workflow run status, status: {status}")
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -117,26 +117,26 @@ export const getRagContent = (end_user_id: string) => {
|
||||
}
|
||||
// 情感分布分析
|
||||
export const getWordCloud = (group_id: string) => {
|
||||
return request.post(`/memory/emotion/wordcloud`, { group_id, limit: 20 })
|
||||
return request.post(`/memory/emotion-memory/wordcloud`, { group_id, limit: 20 })
|
||||
}
|
||||
// 高频情绪关键词
|
||||
export const getEmotionTags = (group_id: string) => {
|
||||
return request.post(`/memory/emotion/tags`, { group_id, limit: 20 })
|
||||
return request.post(`/memory/emotion-memory/tags`, { group_id, limit: 20 })
|
||||
}
|
||||
// 情绪健康指数
|
||||
export const getEmotionHealth = (group_id: string) => {
|
||||
return request.post(`/memory/emotion/health`, { group_id, limit: 20 })
|
||||
return request.post(`/memory/emotion-memory/health`, { group_id, limit: 20 })
|
||||
}
|
||||
// 个性化建议
|
||||
export const getEmotionSuggestions = (group_id: string) => {
|
||||
return request.post(`/memory/emotion/suggestions`, { group_id, limit: 20 })
|
||||
return request.post(`/memory/emotion-memory/suggestions`, { group_id, limit: 20 })
|
||||
}
|
||||
export const analyticsRefresh = (end_user_id: string) => {
|
||||
return request.post('/memory-storage/analytics/generate_cache', { end_user_id })
|
||||
}
|
||||
// 遗忘
|
||||
export const getForgetStats = (group_id: string) => {
|
||||
return request.get(`/memory/forget/stats`, { group_id })
|
||||
return request.get(`/memory/forget-memory/stats`, { group_id })
|
||||
}
|
||||
// 隐性记忆-偏好
|
||||
export const getImplicitPreferences = (end_user_id: string) => {
|
||||
@@ -176,10 +176,10 @@ export const getPerceptualTimeline = (end_user: string) => {
|
||||
}
|
||||
// 情景记忆-总览
|
||||
export const getEpisodicOverview = (data: { end_user_id: string; time_range: string; episodic_type: string; } ) => {
|
||||
return request.post(`/memory-storage/classifications/episodic-memory`, data)
|
||||
return request.post(`/memory/episodic-memory/overview`, data)
|
||||
}
|
||||
export const getEpisodicDetail = (data: { end_user_id: string; summary_id: string; } ) => {
|
||||
return request.post(`/memory-storage/classifications/episodic-memory-details`, data)
|
||||
return request.post(`/memory/episodic-memory/details`, data)
|
||||
}
|
||||
// 关系演化
|
||||
export const getRelationshipEvolution = (data: { id: string; label: string; } ) => {
|
||||
@@ -190,10 +190,10 @@ export const getTimelineMemories = (data: { id: string; label: string; }) => {
|
||||
return request.get(`/memory-storage/memory_space/timeline_memories`, data)
|
||||
}
|
||||
export const getExplicitMemory = (end_user_id: string) => {
|
||||
return request.post(`/memory-storage/classifications/explicit-memory`, { end_user_id })
|
||||
return request.post(`/memory/explicit-memory/overview`, { end_user_id })
|
||||
}
|
||||
export const getExplicitMemoryDetails = (data: { end_user_id: string, memory_id: string; }) => {
|
||||
return request.post(`/memory-storage/classifications/explicit-memory-details`, data)
|
||||
return request.post(`/memory/explicit-memory/details`, data)
|
||||
}
|
||||
export const getConversations = (end_user: string) => {
|
||||
return request.get(`/memory/work/${end_user}/conversations`)
|
||||
@@ -205,7 +205,7 @@ export const getConversationDetail = (end_user: string, conversation_id: string)
|
||||
return request.get(`/memory/work/${end_user}/detail`, { conversation_id })
|
||||
}
|
||||
export const forgetTrigger = (data: { max_merge_batch_size: number; min_days_since_access: number; end_user_id: string;}) => {
|
||||
return request.post(`/memory/forget/trigger`, data)
|
||||
return request.post(`/memory/forget-memory/trigger`, data)
|
||||
}
|
||||
/*************** end 用户记忆 相关接口 ******************************/
|
||||
|
||||
@@ -229,11 +229,11 @@ export const deleteMemoryConfig = (config_id: number) => {
|
||||
}
|
||||
// 遗忘引擎-获取配置
|
||||
export const getMemoryForgetConfig = (config_id: number | string) => {
|
||||
return request.get('/memory/forget/read_config', { config_id })
|
||||
return request.get('/memory/forget-memory/read_config', { config_id })
|
||||
}
|
||||
// 遗忘引擎-更新配置
|
||||
export const updateMemoryForgetConfig = (values: ForgetConfigForm) => {
|
||||
return request.post('/memory/forget/update_config', values)
|
||||
return request.post('/memory/forget-memory/update_config', values)
|
||||
}
|
||||
// 记忆萃取引擎-获取配置
|
||||
export const getMemoryExtractionConfig = (config_id: number | string) => {
|
||||
|
||||
17
web/src/assets/images/menu/spaceConfig.svg
Normal file
17
web/src/assets/images/menu/spaceConfig.svg
Normal file
@@ -0,0 +1,17 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>模型 (1)</title>
|
||||
<g id="v0.2.0" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<g id="红熊空间-记忆管理" transform="translate(-24, -409)" stroke="#5B6167">
|
||||
<g id="记忆对话备份-2" transform="translate(12, 401)">
|
||||
<g id="模型-(1)" transform="translate(12, 8)">
|
||||
<g id="编组-21" transform="translate(1.5, 1.5)">
|
||||
<path d="M7,0.288675135 L11.6291651,2.96132487 C11.9385662,3.13995766 12.1291651,3.47008468 12.1291651,3.82735027 L12.1291651,9.17264973 C12.1291651,9.52991532 11.9385662,9.86004234 11.6291651,10.0386751 L7,12.7113249 C6.69059892,12.8899577 6.30940108,12.8899577 6,12.7113249 L1.37083488,10.0386751 C1.0614338,9.86004234 0.870834875,9.52991532 0.870834875,9.17264973 L0.870834875,3.82735027 C0.870834875,3.47008468 1.0614338,3.13995766 1.37083488,2.96132487 L6,0.288675135 C6.30940108,0.11004234 6.69059892,0.11004234 7,0.288675135 Z" id="多边形"></path>
|
||||
<polyline id="路径-15" points="0.931223827 3.37218958 6.5 6.5 6.5 12.8581283"></polyline>
|
||||
<line x1="6.5" y1="6.49748419" x2="12.0714286" y2="3.37218958" id="路径-16"></line>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.5 KiB |
17
web/src/assets/images/menu/spaceConfig_active.svg
Normal file
17
web/src/assets/images/menu/spaceConfig_active.svg
Normal file
@@ -0,0 +1,17 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>模型 (1)</title>
|
||||
<g id="v0.2.0" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<g id="红熊空间-记忆管理" transform="translate(-24, -409)" stroke="#212332">
|
||||
<g id="记忆对话备份-2" transform="translate(12, 401)">
|
||||
<g id="模型-(1)" transform="translate(12, 8)">
|
||||
<g id="编组-21" transform="translate(1.5, 1.5)">
|
||||
<path d="M7,0.288675135 L11.6291651,2.96132487 C11.9385662,3.13995766 12.1291651,3.47008468 12.1291651,3.82735027 L12.1291651,9.17264973 C12.1291651,9.52991532 11.9385662,9.86004234 11.6291651,10.0386751 L7,12.7113249 C6.69059892,12.8899577 6.30940108,12.8899577 6,12.7113249 L1.37083488,10.0386751 C1.0614338,9.86004234 0.870834875,9.52991532 0.870834875,9.17264973 L0.870834875,3.82735027 C0.870834875,3.47008468 1.0614338,3.13995766 1.37083488,2.96132487 L6,0.288675135 C6.30940108,0.11004234 6.69059892,0.11004234 7,0.288675135 Z" id="多边形"></path>
|
||||
<polyline id="路径-15" points="0.931223827 3.37218958 6.5 6.5 6.5 12.8581283"></polyline>
|
||||
<line x1="6.5" y1="6.49748419" x2="12.0714286" y2="3.37218958" id="路径-16"></line>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.5 KiB |
19
web/src/assets/images/userMemory/goto.svg
Normal file
19
web/src/assets/images/userMemory/goto.svg
Normal file
@@ -0,0 +1,19 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="28px" height="28px" viewBox="0 0 28 28" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>编组 13备份</title>
|
||||
<g id="V1.1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<g id="红熊空间-记忆管理" transform="translate(-947, -144)">
|
||||
<g id="1备份-2" transform="translate(651, 128)">
|
||||
<g id="编组-13备份" transform="translate(296, 16)">
|
||||
<rect id="矩形" stroke="#DFE4ED" x="0.5" y="0.5" width="27" height="27" rx="6"></rect>
|
||||
<g id="进入@2x" transform="translate(5.8333, 5.8333)">
|
||||
<g id="编组-11" transform="translate(2.0417, 2.5521)">
|
||||
<path d="M5.42385066,3.34516089 L8.15899029,5.47250014 C8.23746067,5.5335329 8.25159666,5.64662254 8.1905639,5.72509292 C8.1813906,5.73688711 8.17078448,5.74749323 8.15899029,5.75666652 L5.42385066,7.88400578 C5.34538028,7.94503854 5.23229064,7.93090256 5.17125788,7.85243218 C5.14668314,7.82083621 5.13334107,7.78195037 5.13334107,7.74192259 L5.13334107,6.2384308 L5.13334107,6.2384308 L0,6.2384308 L0,4.99073587 L5.13334107,4.99073587 L5.13334107,3.48724407 C5.13334107,3.38783282 5.21392981,3.30724407 5.31334107,3.30724407 C5.35336884,3.30724407 5.39225469,3.32058615 5.42385066,3.34516089 Z" id="路径" fill="#5B6167" fill-rule="nonzero"></path>
|
||||
<path d="M1.60417096,2.83745334 L1.60417096,0.9 C1.60417096,0.402943725 2.00711469,0 2.50417096,-1.11022302e-16 L10.3291667,-1.11022302e-16 C10.8262229,-2.22044605e-16 11.2291667,0.402943725 11.2291667,0.9 L11.2291667,10.3291667 C11.2291667,10.8262229 10.8262229,11.2291667 10.3291667,11.2291667 L2.50417096,11.2291667 C2.00711469,11.2291667 1.60417096,10.8262229 1.60417096,10.3291667 L1.60417096,8.46506778 L1.60417096,8.46506778" id="路径" stroke="#5B6167" stroke-width="1.1" stroke-linejoin="round"></path>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.1 KiB |
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useState, useCallback, useRef, type FC, type Key } from 'react';
|
||||
import { useEffect, useState, type FC, type Key } from 'react';
|
||||
import { Select } from 'antd'
|
||||
import type { SelectProps, DefaultOptionType } from 'antd/es/select'
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -26,7 +26,7 @@ interface CustomSelectProps extends Omit<SelectProps, 'filterOption'> {
|
||||
disabled?: boolean;
|
||||
style?: React.CSSProperties;
|
||||
className?: string;
|
||||
filterOption?: (inputValue: string, option: DefaultOptionType) => boolean;
|
||||
filterOption?: (inputValue: string, option?: DefaultOptionType) => boolean;
|
||||
}
|
||||
interface OptionType {
|
||||
[key: string]: Key | string | number;
|
||||
@@ -48,44 +48,27 @@ const CustomSelect: FC<CustomSelectProps> = ({
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const [options, setOptions] = useState<OptionType[]>([]);
|
||||
// 创建防抖定时器引用
|
||||
const debounceRef = useRef<number>();
|
||||
|
||||
// 防抖搜索函数
|
||||
const handleSearch = useCallback((value?: string) => {
|
||||
// 清除之前的定时器
|
||||
if (debounceRef.current) {
|
||||
clearTimeout(debounceRef.current);
|
||||
}
|
||||
|
||||
// 设置新的定时器
|
||||
debounceRef.current = window.setTimeout(() => {
|
||||
request.get<ApiResponse<OptionType>>(url, {...params, [optionFilterProp]: value}).then((res) => {
|
||||
const data = res;
|
||||
setOptions(Array.isArray(data) ? data || [] : Array.isArray(data?.items) ? data.items || [] : []);
|
||||
});
|
||||
}, 300); // 300毫秒防抖延迟
|
||||
}, [url, params, optionFilterProp]);
|
||||
|
||||
// 默认模糊搜索函数
|
||||
const defaultFilterOption = (inputValue: string, option?: DefaultOptionType) => {
|
||||
if (!option || !inputValue) return true;
|
||||
const label = String(option.children || option.label || '');
|
||||
return label.toLowerCase().includes(inputValue.toLowerCase());
|
||||
};
|
||||
// 组件挂载时获取初始数据
|
||||
useEffect(() => {
|
||||
handleSearch();
|
||||
|
||||
// 组件卸载时清除定时器
|
||||
return () => {
|
||||
if (debounceRef.current) {
|
||||
clearTimeout(debounceRef.current);
|
||||
}
|
||||
};
|
||||
}, [url, handleSearch]);
|
||||
request.get<ApiResponse<OptionType>>(url, params).then((res) => {
|
||||
const data = res;
|
||||
setOptions(Array.isArray(data) ? data || [] : Array.isArray(data?.items) ? data.items || [] : []);
|
||||
});
|
||||
}, []);
|
||||
return (
|
||||
<Select
|
||||
placeholder={placeholder ? placeholder : t('common.select')}
|
||||
onChange={onChange}
|
||||
defaultValue={hasAll ? null : undefined}
|
||||
showSearch={showSearch}
|
||||
onSearch={handleSearch}
|
||||
filterOption={filterOption || false} // 禁用本地过滤,使用服务器端过滤
|
||||
filterOption={filterOption || defaultFilterOption}
|
||||
{...props}
|
||||
>
|
||||
{hasAll && (<Select.Option>{allTitle || t('common.all')}</Select.Option>)}
|
||||
|
||||
@@ -40,6 +40,8 @@ import apiKeyIcon from '@/assets/images/menu/apiKey.png';
|
||||
import apiKeyActiveIcon from '@/assets/images/menu/apiKey_active.png';
|
||||
import pricingIcon from '@/assets/images/menu/pricing.svg'
|
||||
import pricingActiveIcon from '@/assets/images/menu/pricing_active.svg'
|
||||
import spaceConfigIcon from '@/assets/images/menu/spaceConfig.svg'
|
||||
import spaceConfigActiveIcon from '@/assets/images/menu/spaceConfig_active.svg'
|
||||
|
||||
// 图标路径映射表
|
||||
const iconPathMap: Record<string, string> = {
|
||||
@@ -68,7 +70,9 @@ const iconPathMap: Record<string, string> = {
|
||||
'apiKey': apiKeyIcon,
|
||||
'apiKeyActive': apiKeyActiveIcon,
|
||||
'pricing': pricingIcon,
|
||||
'pricingActive': pricingActiveIcon
|
||||
'pricingActive': pricingActiveIcon,
|
||||
'spaceConfig': spaceConfigIcon,
|
||||
'spaceConfigActive': spaceConfigActiveIcon,
|
||||
};
|
||||
|
||||
const { Sider } = Layout;
|
||||
|
||||
@@ -87,7 +87,7 @@ export const en = {
|
||||
modelManagement: 'Model Management',
|
||||
memoryStore: 'Memory Store',
|
||||
apiParameters: 'API Parameters',
|
||||
userMemory: 'User Memory',
|
||||
userMemory: 'Memory Store',
|
||||
memberManagement: 'Member Management',
|
||||
memorySummary: 'Memory Summary',
|
||||
memoryConversation: 'Memory Validation',
|
||||
@@ -110,6 +110,7 @@ export const en = {
|
||||
pricing: 'Pricing Management',
|
||||
orderPayment: 'Order Payment',
|
||||
orderHistory: 'Order History',
|
||||
spaceConfig: 'Space Configuration'
|
||||
},
|
||||
dashboard: {
|
||||
total_models: 'Total number of available models',
|
||||
@@ -1232,6 +1233,8 @@ export const en = {
|
||||
hire_date: 'Hire Date',
|
||||
memoryContent: 'Memory Content',
|
||||
created_at: 'Created At',
|
||||
updated_at: 'Updated At',
|
||||
fullScreen: 'Full Screen',
|
||||
|
||||
memoryWindow: "{{name}}'s Window of Memory",
|
||||
memory_insight: 'Overall Overview',
|
||||
@@ -1258,7 +1261,7 @@ export const en = {
|
||||
unix: 'items',
|
||||
completeMemory: 'Complete Memory',
|
||||
relationshipEvolution: 'Relationship Evolution',
|
||||
timelineMemories: 'Shared Memory Timeline',
|
||||
timelineMemories: 'Long-term Memory',
|
||||
emotionLine: 'Emotion Changes Over Time',
|
||||
interaction: 'Interaction Frequency & Relationship Stages',
|
||||
timelines_memory: 'All',
|
||||
@@ -1269,6 +1272,12 @@ export const en = {
|
||||
negative: 'Negative Emotion',
|
||||
neutral: 'Neutral Emotion',
|
||||
interactionCountData: 'Interaction Count',
|
||||
capacity: 'Capacity',
|
||||
type: 'Type',
|
||||
person: 'Personal',
|
||||
memoryNum: 'memories',
|
||||
memory_config_name: 'Memory Engine',
|
||||
searchPlaceholder: 'Search memory store name',
|
||||
},
|
||||
space: {
|
||||
createSpace: 'Create Space',
|
||||
@@ -1284,7 +1293,8 @@ export const en = {
|
||||
neo4jDesc: 'Based on knowledge graph, suitable for relational reasoning and path query',
|
||||
llmModel: 'LLM Model',
|
||||
embeddingModel: 'Embedding Model',
|
||||
rerankModel: 'Rerank Model'
|
||||
rerankModel: 'Rerank Model',
|
||||
configAlert: 'Space model configuration ensures that the space can correctly call the corresponding models to process business data during runtime.',
|
||||
},
|
||||
memoryExtractionEngine: {
|
||||
title: 'Memory Engine Module Configuration Center',
|
||||
@@ -1459,6 +1469,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
quickReply: 'Quick Reply',
|
||||
web_search: 'Online search',
|
||||
memory: 'Memory',
|
||||
memoryConversationAnalysisEmpty: 'There is currently no dialogue analysis content available',
|
||||
memoryConversationAnalysisEmptySubTitle: 'After entering your user ID, click on "Test Memory" to view the conversation memory',
|
||||
},
|
||||
login: {
|
||||
title: 'Red Bear Memory Science',
|
||||
@@ -1613,19 +1625,17 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
|
||||
JsonTool_desc: 'Data Format Conversion',
|
||||
JsonTool_features: 'JSON formatting, compression, validation and conversion functions',
|
||||
jsonFormat: 'JSON Formatting',
|
||||
jsonGzip: 'JSON Compression',
|
||||
jsonCheck: 'JSON Validation',
|
||||
jsonConversion: 'Format Conversion',
|
||||
jsonParse: 'JSON Parse',
|
||||
jsonInsert: 'JSON Insert',
|
||||
jsonReplace: 'JSON Validation',
|
||||
jsonDelete: 'JSON Delete',
|
||||
jsonEg: 'Example JSON',
|
||||
enterJson: 'Enter JSON',
|
||||
jsonPlaceholder: 'Enter JSON data, e.g.: {"name": "test", "value": 123}',
|
||||
clear: 'Clear',
|
||||
parse: 'Paste',
|
||||
format: 'Format',
|
||||
minify: 'Minify',
|
||||
validate: 'Validate',
|
||||
convert: 'Escape',
|
||||
paste: 'Paste',
|
||||
parse: 'Parse',
|
||||
json_path: 'JSON Path Parameters',
|
||||
outputResult: 'Output Result',
|
||||
validJosn: 'JSON format is correct, validation passed!',
|
||||
|
||||
@@ -1944,7 +1954,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
variableConfig: 'Variable Configuration',
|
||||
variableRequired: 'Required',
|
||||
addMessage: 'Add Message',
|
||||
answerDesc: 'Reply'
|
||||
answerDesc: 'Reply',
|
||||
addNode: 'Add Node',
|
||||
},
|
||||
emotionEngine: {
|
||||
emotionEngineConfig: 'Emotion Engine Configuration',
|
||||
|
||||
@@ -87,7 +87,7 @@ export const zh = {
|
||||
modelManagement: '模型管理',
|
||||
memoryStore: '记忆存储',
|
||||
apiParameters: 'API参数',
|
||||
userMemory: '用户记忆',
|
||||
userMemory: '记忆库',
|
||||
memberManagement: '成员管理',
|
||||
memorySummary: '记忆摘要',
|
||||
memoryConversation: '记忆验证',
|
||||
@@ -102,7 +102,7 @@ export const zh = {
|
||||
knowledgeShare: '详情',
|
||||
knowledgeCreateDataset: '新建数据集',
|
||||
knowledgeDocumentDetails: '详情',
|
||||
userMemoryDetail: '用户记忆详情',
|
||||
userMemoryDetail: '记忆库详情',
|
||||
toolManagement: '工具管理',
|
||||
emotionEngine: '情感引擎',
|
||||
statementDetail: '情绪记忆',
|
||||
@@ -110,6 +110,7 @@ export const zh = {
|
||||
pricing: '收费管理',
|
||||
orderPayment: '订单支付',
|
||||
orderHistory: '订单记录',
|
||||
spaceConfig: '空间配置'
|
||||
},
|
||||
knowledgeBase: {
|
||||
home: '首页',
|
||||
@@ -1313,7 +1314,7 @@ export const zh = {
|
||||
updated_at: '最后更新时间',
|
||||
fullScreen: '全屏',
|
||||
|
||||
memoryWindow: "{{name}}的记忆之窗",
|
||||
memoryWindow: "{{name}} 的记忆之窗",
|
||||
memory_insight: '总体概述',
|
||||
key_findings: '关键发现',
|
||||
behavior_pattern: '行为模式',
|
||||
@@ -1338,7 +1339,7 @@ export const zh = {
|
||||
unix: '个',
|
||||
completeMemory: '完整记忆',
|
||||
relationshipEvolution: '关系演化',
|
||||
timelineMemories: '共同记忆时间线',
|
||||
timelineMemories: '长期记忆',
|
||||
emotionLine: '情绪随时间变化',
|
||||
interaction: '互动频率 & 关系阶段',
|
||||
timelines_memory: '全部',
|
||||
@@ -1349,6 +1350,12 @@ export const zh = {
|
||||
negative: '负向情绪',
|
||||
neutral: '中性情绪',
|
||||
interactionCountData: '互动次数',
|
||||
capacity: '容量',
|
||||
type: '类型',
|
||||
person: '个人',
|
||||
memoryNum: '条记忆',
|
||||
memory_config_name: '记忆引擎',
|
||||
searchPlaceholder: '搜索记忆库名称',
|
||||
},
|
||||
space: {
|
||||
createSpace: '创建空间',
|
||||
@@ -1364,7 +1371,8 @@ export const zh = {
|
||||
neo4jDesc: '基于知识图谱,适合关系推理和路径查询',
|
||||
llmModel: 'LLM 模型',
|
||||
embeddingModel: 'Embedding 模型',
|
||||
rerankModel: 'Rerank 模型'
|
||||
rerankModel: 'Rerank 模型',
|
||||
configAlert: '空间模型配置为空间的模型模型,保障空间运行时能正确的调用到相应的模型来处理业务数据。',
|
||||
},
|
||||
memoryExtractionEngine: {
|
||||
title: '记忆引擎模块配置中心',
|
||||
@@ -1537,6 +1545,8 @@ export const zh = {
|
||||
quickReply: '快速回复',
|
||||
web_search: '联网搜索',
|
||||
memory: '记忆',
|
||||
memoryConversationAnalysisEmpty: '目前没有可用的对话分析内容',
|
||||
memoryConversationAnalysisEmptySubTitle: '输入您的用户ID后,点击"测试记忆"查看对话记忆',
|
||||
},
|
||||
login: {
|
||||
title: '红熊记忆科学',
|
||||
@@ -1711,19 +1721,17 @@ export const zh = {
|
||||
|
||||
JsonTool_desc: '数据格式转换',
|
||||
JsonTool_features: 'JSON格式化、压缩、验证和转换功能',
|
||||
jsonFormat: 'JSON格式化',
|
||||
jsonGzip: 'JSON压缩',
|
||||
jsonCheck: 'JSON验证',
|
||||
jsonConversion: '格式转换',
|
||||
jsonParse: 'JSON解析',
|
||||
jsonInsert: 'JSON插入',
|
||||
jsonReplace: 'JSON验证',
|
||||
jsonDelete: 'JSON删除',
|
||||
jsonEg: '示例JSON',
|
||||
enterJson: '输入JSON',
|
||||
jsonPlaceholder: '输入JSON数据,例如:{"name": "测试", "value": 123}',
|
||||
clear: '清空',
|
||||
parse: '粘贴',
|
||||
format: '格式化',
|
||||
minify: '压缩',
|
||||
validate: '验证',
|
||||
convert: '转义',
|
||||
paste: '粘贴',
|
||||
parse: '解析',
|
||||
json_path: 'JSON 路径参数',
|
||||
outputResult: '输出结果',
|
||||
validJosn: 'JSON格式正确,验证通过!',
|
||||
|
||||
@@ -2043,7 +2051,8 @@ export const zh = {
|
||||
variableConfig: '变量配置',
|
||||
variableRequired: '必填',
|
||||
addMessage: '添加消息',
|
||||
answerDesc: '回复'
|
||||
answerDesc: '回复',
|
||||
addNode: '添加节点',
|
||||
},
|
||||
emotionEngine: {
|
||||
emotionEngineConfig: '情感引擎配置',
|
||||
|
||||
@@ -66,6 +66,7 @@ const componentMap: Record<string, LazyExoticComponent<ComponentType<object>>> =
|
||||
OrderHistory: lazy(() => import('@/views/OrderHistory')),
|
||||
Pricing: lazy(() => import('@/views/Pricing')),
|
||||
ToolManagement: lazy(() => import('@/views/ToolManagement')),
|
||||
SpaceConfig: lazy(() => import('@/views/SpaceConfig')),
|
||||
Login: lazy(() => import('@/views/Login')),
|
||||
InviteRegister: lazy(() => import('@/views/InviteRegister')),
|
||||
NoPermission: lazy(() => import('@/views/NoPermission')),
|
||||
|
||||
@@ -33,6 +33,7 @@
|
||||
{ "path": "/api-key", "element": "ApiKeyManagement" },
|
||||
{ "path": "/emotion-engine/:id", "element": "EmotionEngine" },
|
||||
{ "path": "/reflection-engine/:id", "element": "SelfReflectionEngine" },
|
||||
{ "path": "/space-config", "element": "SpaceConfig" },
|
||||
{ "path": "/no-permission", "element": "NoPermission" },
|
||||
{ "path": "/*", "element": "NotFound" }
|
||||
]
|
||||
|
||||
@@ -376,6 +376,21 @@
|
||||
"icon": null,
|
||||
"iconActive": null,
|
||||
"subs": null
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"parent": 0,
|
||||
"code": "spaceConfig",
|
||||
"label": "空间配置",
|
||||
"i18nKey": "menu.spaceConfig",
|
||||
"path": "/space-config",
|
||||
"enable": true,
|
||||
"display": true,
|
||||
"level": 1,
|
||||
"sort": 0,
|
||||
"icon": null,
|
||||
"iconActive": null,
|
||||
"subs": null
|
||||
}
|
||||
]
|
||||
}
|
||||
118
web/src/views/SpaceConfig/index.tsx
Normal file
118
web/src/views/SpaceConfig/index.tsx
Normal file
@@ -0,0 +1,118 @@
|
||||
import { type FC, useEffect, useState } from 'react';
|
||||
import { Form, App, Button, Skeleton } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { SpaceConfigData } from './types'
|
||||
import { getWorkspaceModels, updateWorkspaceModels } from '@/api/workspaces'
|
||||
import { getModelListUrl } from '@/api/models'
|
||||
import CustomSelect from '@/components/CustomSelect'
|
||||
import RbAlert from '@/components/RbAlert';
|
||||
|
||||
const SpaceConfig: FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
const [pageLoading, setPageLoding] = useState(false)
|
||||
const [form] = Form.useForm<SpaceConfigData>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const values = Form.useWatch([], form);
|
||||
|
||||
useEffect(() => {
|
||||
setPageLoding(true)
|
||||
getWorkspaceModels().then((res) => {
|
||||
const { llm, embedding, rerank } = res as SpaceConfigData
|
||||
form.setFieldsValue({
|
||||
llm,
|
||||
embedding,
|
||||
rerank
|
||||
})
|
||||
})
|
||||
.finally(() => {
|
||||
setPageLoding(false)
|
||||
})
|
||||
}, [])
|
||||
// 封装保存方法,添加提交逻辑
|
||||
const handleSave = () => {
|
||||
form
|
||||
.validateFields()
|
||||
.then(() => {
|
||||
setLoading(true)
|
||||
updateWorkspaceModels(values)
|
||||
.then(() => {
|
||||
setLoading(false)
|
||||
message.success(t('common.updateSuccess'))
|
||||
})
|
||||
.catch(() => {
|
||||
setLoading(false)
|
||||
});
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log('err', err)
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rb:h-full rb:max-w-140 rb:mx-auto">
|
||||
{pageLoading
|
||||
? <Skeleton active />
|
||||
: <Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
>
|
||||
<Form.Item
|
||||
label={t('space.llmModel')}
|
||||
name="llm"
|
||||
rules={[{ required: true, message: t('common.pleaseSelect') }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'llm', pagesize: 100 }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
style={{width: '100%'}}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
label={t('space.embeddingModel')}
|
||||
name="embedding"
|
||||
rules={[{ required: true, message: t('common.pleaseSelect') }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'embedding', pagesize: 100 }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
style={{width: '100%'}}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
label={t('space.rerankModel')}
|
||||
name="rerank"
|
||||
rules={[{ required: true, message: t('common.pleaseSelect') }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'rerank', pagesize: 100 }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
style={{width: '100%'}}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<RbAlert>{t('space.configAlert')}</RbAlert>
|
||||
|
||||
<Form.Item className="rb:text-right">
|
||||
<Button type="primary" className="rb:mt-6" onClick={handleSave} loading={loading}>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default SpaceConfig;
|
||||
8
web/src/views/SpaceConfig/types.ts
Normal file
8
web/src/views/SpaceConfig/types.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
export interface SpaceConfigData {
|
||||
llm: string;
|
||||
embedding: string;
|
||||
rerank: string;
|
||||
}
|
||||
export interface SpaceConfigRef {
|
||||
handleOpen: () => void;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, Button, Space, Tree } from 'antd';
|
||||
import { Form, Input, Button, Space } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { TreeDataNode } from 'antd';
|
||||
|
||||
@@ -12,7 +12,7 @@ import { execute } from '@/api/tools';
|
||||
const JsonToolModal = forwardRef<JsonToolModalRef>((_props, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [form] = Form.useForm<{ json: string; }>();
|
||||
const [form] = Form.useForm<{ json: string; json_path: string; }>();
|
||||
const [data, setData] = useState<ToolItem>({} as ToolItem)
|
||||
const [formatValue, setFormatValue] = useState<string | Record<string, any> | null>(null)
|
||||
|
||||
@@ -60,44 +60,29 @@ const JsonToolModal = forwardRef<JsonToolModalRef>((_props, ref) => {
|
||||
}
|
||||
const handleOperate = (type: string) => {
|
||||
const json = form.getFieldValue('json')
|
||||
const json_path = form.getFieldValue('json_path')
|
||||
if (!json || !data.id) return
|
||||
let params: ExecuteData = {
|
||||
tool_id: data.id,
|
||||
parameters: {
|
||||
operation: type,
|
||||
input_data: json
|
||||
input_data: json,
|
||||
json_path
|
||||
}
|
||||
}
|
||||
if (type === 'format') {
|
||||
if (type === 'parse') {
|
||||
params = {
|
||||
...params,
|
||||
parameters: {
|
||||
...params.parameters,
|
||||
indent: 2,
|
||||
ensure_ascii: false,
|
||||
sort_keys: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
execute(params)
|
||||
.then(res => {
|
||||
const { data } = res as {data: {
|
||||
formatted_json: string;
|
||||
minified_json: string;
|
||||
is_valid: boolean;
|
||||
converted_json: string;
|
||||
error: string;
|
||||
structure: Record<string, string | number>
|
||||
}}
|
||||
switch (type) {
|
||||
case 'format':
|
||||
setFormatValue(data.formatted_json);
|
||||
break
|
||||
case 'minify':
|
||||
setFormatValue(data.minified_json)
|
||||
break
|
||||
}
|
||||
const { data } = res as { data: string; }
|
||||
setFormatValue(data);
|
||||
})
|
||||
}
|
||||
const clear = () => {
|
||||
@@ -126,15 +111,20 @@ const JsonToolModal = forwardRef<JsonToolModalRef>((_props, ref) => {
|
||||
label={<Space size={8}>
|
||||
{t('tool.enterJson')}
|
||||
<Button onClick={clear}>{t('tool.clear')}</Button>
|
||||
<Button onClick={handleParse}>{t('tool.parse')}</Button>
|
||||
<Button onClick={handleParse}>{t('tool.paste')}</Button>
|
||||
</Space>}
|
||||
>
|
||||
<Input.TextArea rows={10} placeholder={t('tool.jsonPlaceholder')} />
|
||||
</FormItem>
|
||||
<FormItem
|
||||
name="json_path"
|
||||
label={t('tool.json_path')}
|
||||
>
|
||||
<Input placeholder={t('common.pleaseEnter')} />
|
||||
</FormItem>
|
||||
|
||||
<Space size={8} className="rb:mb-3">
|
||||
<Button onClick={() => handleOperate('format')}>{t('tool.format')}</Button>
|
||||
<Button onClick={() => handleOperate('minify')}>{t('tool.minify')}</Button>
|
||||
<Button onClick={() => handleOperate('parse')}>{t('tool.parse')}</Button>
|
||||
</Space>
|
||||
<FormItem
|
||||
label={t('tool.outputResult')}
|
||||
|
||||
@@ -23,6 +23,7 @@ interface CurrentTimeObj {
|
||||
iso_format: string;
|
||||
timestamp: string;
|
||||
timestamp_ms: string;
|
||||
utc_datetime: string;
|
||||
}
|
||||
const TimeToolModal = forwardRef<TimeToolModalRef>((_props, ref) => {
|
||||
const { t } = useTranslation();
|
||||
@@ -88,8 +89,8 @@ const TimeToolModal = forwardRef<TimeToolModalRef>((_props, ref) => {
|
||||
}
|
||||
})
|
||||
.then(res => {
|
||||
const response = res as { data: CurrentTimeObj }
|
||||
setTimestampFormat(response.data.datetime)
|
||||
const response = res as { data: string }
|
||||
setTimestampFormat(response.data)
|
||||
})
|
||||
}
|
||||
const handleChangeFormatType = () => {
|
||||
@@ -149,7 +150,7 @@ const TimeToolModal = forwardRef<TimeToolModalRef>((_props, ref) => {
|
||||
<Input disabled value={currentTime?.datetime} />
|
||||
</FormItem>
|
||||
<FormItem label={t('tool.utcTime')} >
|
||||
<Input disabled value={currentTime?.iso_format} />
|
||||
<Input disabled value={currentTime?.utc_datetime} />
|
||||
</FormItem>
|
||||
<FormItem label={t('tool.secondsTimestamp')} >
|
||||
<Input disabled value={currentTime?.timestamp} />
|
||||
|
||||
@@ -10,10 +10,10 @@ export const InnerConfigData: Record<string, InnerConfigItem> = {
|
||||
},
|
||||
JsonTool: {
|
||||
features: [
|
||||
'jsonFormat',
|
||||
'jsonGzip',
|
||||
'jsonCheck',
|
||||
'jsonConversion'
|
||||
'jsonParse',
|
||||
'jsonInsert',
|
||||
'jsonReplace',
|
||||
'jsonDelete'
|
||||
],
|
||||
eg: '{"name":"工具","tool_class":"内置"}'
|
||||
},
|
||||
|
||||
@@ -130,6 +130,7 @@ export interface ExecuteData {
|
||||
ensure_ascii?: boolean;
|
||||
sort_keys?: boolean;
|
||||
input_data?: string;
|
||||
json_path?: string;
|
||||
}
|
||||
}
|
||||
export interface CustomToolModalRef {
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, App } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { ConfigModalData, ConfigModalRef } from '../types'
|
||||
import { getWorkspaceModels, updateWorkspaceModels } from '@/api/workspaces'
|
||||
import { getModelListUrl } from '@/api/models'
|
||||
import CustomSelect from '@/components/CustomSelect'
|
||||
import RbModal from '@/components/RbModal'
|
||||
|
||||
const ConfigModal = forwardRef<ConfigModalRef>((_props, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [form] = Form.useForm<ConfigModalData>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const values = Form.useWatch([], form);
|
||||
|
||||
// 封装取消方法,添加关闭弹窗逻辑
|
||||
const handleClose = () => {
|
||||
setVisible(false);
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
};
|
||||
|
||||
const handleOpen = () => {
|
||||
getWorkspaceModels().then((res) => {
|
||||
const { llm, embedding, rerank } = res as ConfigModalData
|
||||
form.setFieldsValue({
|
||||
llm,
|
||||
embedding,
|
||||
rerank
|
||||
})
|
||||
})
|
||||
setVisible(true);
|
||||
};
|
||||
// 封装保存方法,添加提交逻辑
|
||||
const handleSave = () => {
|
||||
form
|
||||
.validateFields()
|
||||
.then(() => {
|
||||
setLoading(true)
|
||||
updateWorkspaceModels(values)
|
||||
.then(() => {
|
||||
setLoading(false)
|
||||
handleClose()
|
||||
message.success(t('common.updateSuccess'))
|
||||
})
|
||||
.catch(() => {
|
||||
setLoading(false)
|
||||
});
|
||||
|
||||
handleClose()
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log('err', err)
|
||||
});
|
||||
}
|
||||
|
||||
// 暴露给父组件的方法
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
handleClose
|
||||
}));
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={t(`userMemory.editConfig`)}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t('common.save')}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
>
|
||||
<Form.Item
|
||||
label={t('space.llmModel')}
|
||||
name="llm"
|
||||
rules={[{ required: true, message: t('common.pleaseSelect') }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'llm', pagesize: 100 }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
style={{width: '100%'}}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
label={t('space.embeddingModel')}
|
||||
name="embedding"
|
||||
rules={[{ required: true, message: t('common.pleaseSelect') }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'embedding', pagesize: 100 }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
style={{width: '100%'}}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
label={t('space.rerankModel')}
|
||||
name="rerank"
|
||||
rules={[{ required: true, message: t('common.pleaseSelect') }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'rerank', pagesize: 100 }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
style={{width: '100%'}}
|
||||
/>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
export default ConfigModal;
|
||||
@@ -1,56 +1,28 @@
|
||||
import { useEffect, useState, useRef } from 'react';
|
||||
import { useEffect, useState, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useNavigate } from 'react-router-dom'
|
||||
import { Row, Col, Radio, Button, List, Skeleton, Space } from 'antd';
|
||||
import type { ColumnsType } from 'antd/es/table';
|
||||
import type { RadioChangeEvent } from 'antd';
|
||||
import { AppstoreOutlined, MenuOutlined } from '@ant-design/icons';
|
||||
import { Row, Col, List, Skeleton } from 'antd';
|
||||
import Empty from '@/components/Empty'
|
||||
|
||||
import type { Data, ConfigModalRef } from './types'
|
||||
import totalNum from '@/assets/images/memory/totalNum.svg'
|
||||
import onlineNum from '@/assets/images/memory/onlineNum.svg'
|
||||
import Table from '@/components/Table'
|
||||
import { getTotalEndUsers, userMemoryListUrl, getUserMemoryList } from '@/api/memory';
|
||||
import ConfigModal from './components/ConfigModal';
|
||||
import type { Data } from './types'
|
||||
import { getUserMemoryList } from '@/api/memory';
|
||||
import { useUser } from '@/store/user'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import SearchInput from '@/components/SearchInput';
|
||||
|
||||
const bgList = [
|
||||
'linear-gradient( 180deg, #F1F6FE 0%, #FBFDFF 100%)',
|
||||
'linear-gradient( 180deg, #F1F9FE 0%, #FBFDFF 100%)',
|
||||
'linear-gradient( 180deg, #FEFBF7 0%, #FBFDFF 100%)',
|
||||
'linear-gradient( 180deg, #F1F9FE 0%, #FBFDFF 100%)',
|
||||
]
|
||||
|
||||
const countList = [
|
||||
'total_num', 'online_num',
|
||||
]
|
||||
const IconList: Record<string, string> = {
|
||||
total_num: totalNum,
|
||||
online_num: onlineNum,
|
||||
}
|
||||
export default function UserMemory() {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate()
|
||||
const { storageType } = useUser()
|
||||
const configModalRef = useRef<ConfigModalRef>(null)
|
||||
const [loading, setLoading] = useState<boolean>(false);
|
||||
const [data, setData] = useState<Data[]>([]);
|
||||
const [countData, setCountData] = useState<Record<string, number>>({});
|
||||
const [layout, setLayout] = useState<'card' | 'list'>('card');
|
||||
const [search, setSearch] = useState<string | undefined>(undefined);
|
||||
|
||||
// 获取数据
|
||||
useEffect(() => {
|
||||
getCountData()
|
||||
getData()
|
||||
}, []);
|
||||
|
||||
// 用户记忆统计
|
||||
const getCountData = () => {
|
||||
getTotalEndUsers().then((res) => {
|
||||
setCountData(res as Record<string, number> || {})
|
||||
})
|
||||
}
|
||||
const getData = () => {
|
||||
setLoading(true)
|
||||
getUserMemoryList().then((res) => {
|
||||
@@ -60,7 +32,6 @@ export default function UserMemory() {
|
||||
setLoading(false)
|
||||
})
|
||||
}
|
||||
console.log('storageType', storageType)
|
||||
const handleViewDetail = (id: string | number) => {
|
||||
switch (storageType) {
|
||||
case 'neo4j':
|
||||
@@ -70,112 +41,77 @@ export default function UserMemory() {
|
||||
navigate(`/user-memory/${id}`)
|
||||
}
|
||||
}
|
||||
const handleChangeLayout = (e: RadioChangeEvent) => {
|
||||
const type = e.target.value
|
||||
setLayout(type)
|
||||
const handleViewMemoryConfig = () => {
|
||||
navigate(`/memory`)
|
||||
}
|
||||
// 表格列配置
|
||||
const columns: ColumnsType = [
|
||||
{
|
||||
title: t('userMemory.user'),
|
||||
dataIndex: 'end_user',
|
||||
key: 'end_user',
|
||||
render: (value) => value?.other_name && value?.other_name !== '' ? value?.other_name : value?.id || '-'
|
||||
},
|
||||
{
|
||||
title: t('userMemory.knowledgeEntryCount'),
|
||||
dataIndex: 'memory_num',
|
||||
key: 'memory_num',
|
||||
render: (value) => value?.total || 0
|
||||
},
|
||||
{
|
||||
title: t('common.operation'),
|
||||
key: 'action',
|
||||
render: (_, record) => (
|
||||
<Button
|
||||
type="link"
|
||||
onClick={() => handleViewDetail(record.end_user?.id)}
|
||||
>
|
||||
{t('common.viewDetail')}
|
||||
</Button>
|
||||
),
|
||||
},
|
||||
];
|
||||
|
||||
const filterData = useMemo(() => {
|
||||
if (search && search.trim() !== '') {
|
||||
return data.filter((item) => {
|
||||
const { end_user } = item as Data;
|
||||
const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id
|
||||
return name?.includes(search)
|
||||
})
|
||||
}
|
||||
|
||||
return data
|
||||
}, [search, data])
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Row gutter={16} className="rb:mb-4">
|
||||
{countList.map(key => (
|
||||
<Col key={key} span={6}>
|
||||
<div className="rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-xl rb:p-[18px_20px_20px_20px]">
|
||||
<div className="rb:text-[28px] rb:font-extrabold rb:leading-8.75 rb:flex rb:items-center rb:justify-between rb:mb-3">
|
||||
{countData[key] || 0}{key === 'avgInteractionTime' ? 's' : ''}
|
||||
<img className="rb:w-6 rb:h-6" src={IconList[key]} />
|
||||
</div>
|
||||
<div className="rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4">{t(`userMemory.${key}`)}</div>
|
||||
</div>
|
||||
</Col>
|
||||
))}
|
||||
<Col span={12} className="rb:text-right">
|
||||
<Space>
|
||||
<Button type="primary" onClick={() => configModalRef?.current?.handleOpen()}>{t('userMemory.chooseModel')}</Button>
|
||||
<Radio.Group value={layout} onChange={handleChangeLayout}>
|
||||
<Radio.Button value="card" disabled={layout === 'card'}><AppstoreOutlined /></Radio.Button>
|
||||
<Radio.Button value="list" disabled={layout === 'list'}><MenuOutlined /></Radio.Button>
|
||||
</Radio.Group>
|
||||
</Space>
|
||||
<Col span={8}>
|
||||
<SearchInput
|
||||
placeholder={t('userMemory.searchPlaceholder')}
|
||||
onSearch={(value) => setSearch(value)}
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
{layout === 'card' &&
|
||||
<>
|
||||
{loading ?
|
||||
<Skeleton active />
|
||||
: data.length > 0 ? (
|
||||
<List
|
||||
grid={{ gutter: 16, column: 4 }}
|
||||
dataSource={data}
|
||||
renderItem={(item, index) => {
|
||||
const { end_user, memory_num } = item as Data;
|
||||
const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id
|
||||
return (
|
||||
<List.Item key={index}>
|
||||
<div
|
||||
className="rb:p-5 rb:rounded-xl rb:border rb:border-[#DFE4ED] rb:cursor-pointer"
|
||||
style={{
|
||||
background: bgList[index % bgList.length],
|
||||
}}
|
||||
{loading ?
|
||||
<Skeleton active />
|
||||
: filterData.length > 0 ? (
|
||||
<List
|
||||
grid={{ gutter: 16, column: 3 }}
|
||||
dataSource={filterData}
|
||||
renderItem={(item, index) => {
|
||||
const { end_user, memory_num, memory_config } = item as Data;
|
||||
const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id
|
||||
return (
|
||||
<List.Item key={index}>
|
||||
<RbCard
|
||||
avatar={<div className="rb:w-12 rb:h-12 rb:text-center rb:font-semibold rb:text-[28px] rb:leading-12 rb:rounded-lg rb:text-[#FBFDFF] rb:bg-[#155EEF] rb:mr-2">{name[0]}</div>}
|
||||
title={name || '-'}
|
||||
extra={<div
|
||||
className="rb:w-7 rb:h-7 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/userMemory/goto.svg')]"
|
||||
onClick={() => handleViewDetail(end_user.id)}
|
||||
>
|
||||
<div className="rb:flex rb:items-center">
|
||||
<div className="rb:w-12 rb:h-12 rb:text-center rb:font-semibold rb:text-[28px] rb:leading-12 rb:rounded-lg rb:text-[#FBFDFF] rb:bg-[#155EEF]">{name[0]}</div>
|
||||
<div className="rb:max-w-[calc(100%-60px)] rb:text-base rb:font-medium rb:leading-6 rb:ml-3 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">
|
||||
{name || '-'}<br/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="rb:grid rb:grid-cols-1 rb:gap-3 rb:mt-7 rb:mb-7">
|
||||
<div className="rb:text-center">
|
||||
<div className="rb:text-[24px] rb:leading-7.5 rb:font-extrabold">{memory_num.total || 0}</div>
|
||||
<div className="rb:wrap-break-word">{t(`userMemory.knowledgeEntryCount`)}</div>
|
||||
</div>
|
||||
</div>
|
||||
></div>}
|
||||
>
|
||||
<div className="rb:flex rb:justify-between rb:items-center">
|
||||
<div>{t('userMemory.capacity')}</div>
|
||||
<div>{memory_num?.total || 0} {t('userMemory.memoryNum')}</div>
|
||||
</div>
|
||||
<div className="rb:flex rb:justify-between rb:items-center rb:mt-2.5">
|
||||
<div>{t('userMemory.type')}</div>
|
||||
<div>{t(`userMemory.${item.type || 'person'}`)}</div>
|
||||
</div>
|
||||
</List.Item>
|
||||
)
|
||||
}}
|
||||
/>
|
||||
) : <Empty />}
|
||||
</>
|
||||
}
|
||||
|
||||
{layout === 'list' &&
|
||||
<Table
|
||||
apiUrl={userMemoryListUrl}
|
||||
columns={columns}
|
||||
rowKey="end_user.id"
|
||||
pagination={false}
|
||||
/>
|
||||
<div className="rb:mt-3 rb:bg-[#F6F8FC] rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:py-2 rb:px-3" onClick={handleViewMemoryConfig}>
|
||||
<div className="rb:text-[#5B6167] rb:leading-5 rb:flex rb:justify-between rb:items-center">
|
||||
{t('userMemory.memory_config_name')}
|
||||
<div
|
||||
className="rb:w-7 rb:h-7 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/userMemory/arrow_right.svg')]"
|
||||
></div>
|
||||
</div>
|
||||
<div className="rb:font-medium rb:leading-5 rb:mt-1">{memory_config?.memory_config_name || '-'}</div>
|
||||
</div>
|
||||
</RbCard>
|
||||
</List.Item>
|
||||
)
|
||||
}}
|
||||
/>
|
||||
) : <Empty />
|
||||
}
|
||||
<ConfigModal ref={configModalRef} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -17,13 +17,10 @@ export interface Data {
|
||||
entity: number;
|
||||
}
|
||||
},
|
||||
memory_config: {
|
||||
memory_config_id: string;
|
||||
memory_config_name: string;
|
||||
},
|
||||
type: string;
|
||||
name?: string;
|
||||
}
|
||||
export interface ConfigModalData {
|
||||
llm: string;
|
||||
embedding: string;
|
||||
rerank: string;
|
||||
}
|
||||
export interface ConfigModalRef {
|
||||
handleOpen: () => void;
|
||||
}
|
||||
@@ -3,8 +3,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import ReactEcharts from 'echarts-for-react';
|
||||
import Empty from '@/components/Empty'
|
||||
import Loading from '@/components/Empty/Loading'
|
||||
import type { Emotion } from './GraphDetail'
|
||||
import { format } from 'echarts';
|
||||
import type { Emotion } from '../pages/GraphDetail'
|
||||
|
||||
interface EmotionLineProps {
|
||||
chartData: Emotion[];
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import ReactEcharts from 'echarts-for-react'
|
||||
import Empty from '@/components/Empty'
|
||||
import Loading from '@/components/Empty/Loading'
|
||||
import type { Interaction } from './GraphDetail'
|
||||
import type { Interaction } from '../pages/GraphDetail'
|
||||
|
||||
interface InteractionBarProps {
|
||||
chartData: Interaction[];
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import React, { type FC, useEffect, useState, useRef, useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useParams } from 'react-router-dom'
|
||||
import { useParams, useNavigate } from 'react-router-dom'
|
||||
import { Col, Row, Space, Button } from 'antd'
|
||||
import dayjs from 'dayjs'
|
||||
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import ReactEcharts from 'echarts-for-react'
|
||||
import detailEmpty from '@/assets/images/userMemory/detail_empty.png'
|
||||
import type { Node, Edge, GraphData, StatementNodeProperties, ExtractedEntityNodeProperties, GraphDetailRef } from '../types'
|
||||
import type { Node, Edge, GraphData, StatementNodeProperties, ExtractedEntityNodeProperties } from '../types'
|
||||
import {
|
||||
getMemorySearchEdges,
|
||||
} from '@/api/memory'
|
||||
import Empty from '@/components/Empty'
|
||||
import Tag from '@/components/Tag'
|
||||
import GraphDetail from '../components/GraphDetail'
|
||||
|
||||
const colors = ['#155EEF', '#369F21', '#4DA8FF', '#FF5D34', '#9C6FFF', '#FF8A4C', '#8BAEF7', '#FFB048']
|
||||
const RelationshipNetwork:FC = () => {
|
||||
@@ -26,7 +25,7 @@ const RelationshipNetwork:FC = () => {
|
||||
const [categories, setCategories] = useState<{ name: string }[]>([])
|
||||
const [selectedNode, setSelectedNode] = useState<Node | null>(null)
|
||||
// const [fullScreen, setFullScreen] = useState<boolean>(false)
|
||||
const graphDetailRef = useRef<GraphDetailRef>(null)
|
||||
const navigate = useNavigate()
|
||||
|
||||
console.log('categories', categories)
|
||||
// 关系网络
|
||||
@@ -133,15 +132,14 @@ const RelationshipNetwork:FC = () => {
|
||||
}
|
||||
}, [nodes])
|
||||
|
||||
// const handleFullScreen = () => {
|
||||
// setFullScreen(prev => !prev)
|
||||
// }
|
||||
|
||||
console.log('selectedNode', selectedNode)
|
||||
|
||||
const handleViewAll = () => {
|
||||
if (!selectedNode) return
|
||||
graphDetailRef.current?.handleOpen(selectedNode)
|
||||
const params = new URLSearchParams({
|
||||
nodeId: selectedNode.id,
|
||||
nodeLabel: selectedNode.label,
|
||||
nodeName: selectedNode.name || ''
|
||||
})
|
||||
navigate(`/user-memory/detail/${id}/GRAPH?${params.toString()}`)
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -336,8 +334,6 @@ const RelationshipNetwork:FC = () => {
|
||||
</div>
|
||||
</RbCard>
|
||||
</Col>
|
||||
|
||||
<GraphDetail ref={graphDetailRef} />
|
||||
</Row>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import { useState, forwardRef, useImperativeHandle, useMemo } from 'react'
|
||||
import { useState, forwardRef, useImperativeHandle, useMemo, useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useSearchParams } from 'react-router-dom'
|
||||
import { Row, Col, Tabs, Space, Skeleton } from 'antd'
|
||||
|
||||
import { getRelationshipEvolution, getTimelineMemories } from '@/api/memory'
|
||||
import type { Node, GraphDetailRef } from '../types'
|
||||
import RbDrawer from '@/components/RbDrawer'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import EmotionLine from './EmotionLine'
|
||||
import EmotionLine from '../components/EmotionLine'
|
||||
import { formatDateTime } from '@/utils/format'
|
||||
import Tag from '@/components/Tag'
|
||||
import InteractionBar from './InteractionBar'
|
||||
import InteractionBar from '../components/InteractionBar'
|
||||
import Empty from '@/components/Empty'
|
||||
import PageHeader from '../components/PageHeader'
|
||||
|
||||
export interface Emotion {
|
||||
emotion_intensity: number;
|
||||
@@ -35,7 +36,7 @@ interface Timeline {
|
||||
|
||||
const GraphDetail = forwardRef<GraphDetailRef>((_props, ref) => {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false);
|
||||
const [searchParams] = useSearchParams()
|
||||
const [vo, setVo] = useState<Node | null>(null)
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [emotionData, setEmotionData] = useState<Emotion[]>([])
|
||||
@@ -43,14 +44,23 @@ const GraphDetail = forwardRef<GraphDetailRef>((_props, ref) => {
|
||||
const [activeTab, setActiveTab] = useState('timelines_memory')
|
||||
const [timelineLoading, setTimelineLoading] = useState(false)
|
||||
const [timelineMemories, setTimelineMemories] = useState<Timeline>({ timelines_memory: [], MemorySummary: [], Statement: [], ExtractedEntity: []})
|
||||
useEffect(() => {
|
||||
const nodeId = searchParams.get('nodeId')
|
||||
const nodeLabel = searchParams.get('nodeLabel')
|
||||
const nodeName = searchParams.get('nodeName')
|
||||
|
||||
if (nodeId && nodeLabel) {
|
||||
const nodeFromUrl = {
|
||||
id: nodeId,
|
||||
label: nodeLabel,
|
||||
name: nodeName || nodeLabel
|
||||
}
|
||||
handleOpen(nodeFromUrl as Node)
|
||||
}
|
||||
}, [searchParams])
|
||||
|
||||
const handleCancel = () => {
|
||||
setVo(null)
|
||||
setOpen(false)
|
||||
}
|
||||
const handleOpen = (vo: Node) => {
|
||||
setActiveTab('timelines_memory')
|
||||
setOpen(true)
|
||||
setVo(vo)
|
||||
getRelationshipEvolutionData(vo)
|
||||
getTimelineMemoriesData(vo)
|
||||
@@ -85,56 +95,57 @@ const GraphDetail = forwardRef<GraphDetailRef>((_props, ref) => {
|
||||
}, [activeTab, timelineMemories])
|
||||
|
||||
return (
|
||||
<RbDrawer
|
||||
title={vo?.name}
|
||||
open={open}
|
||||
onClose={handleCancel}
|
||||
width={1000}
|
||||
>
|
||||
<div className="rb:text-[16px] rb:font-medium rb:leading-5.5 rb:mb-3">{t('userMemory.relationshipEvolution')}</div>
|
||||
<RbCard>
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<EmotionLine chartData={emotionData} loading={loading} />
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<InteractionBar chartData={interactionData} loading={loading} />
|
||||
</Col>
|
||||
</Row>
|
||||
</RbCard>
|
||||
<>
|
||||
<PageHeader
|
||||
name={vo?.name}
|
||||
source="node"
|
||||
/>
|
||||
<div className="rb:h-full rb:max-w-266 rb:mx-auto">
|
||||
<div className="rb:text-[16px] rb:font-medium rb:leading-5.5 rb:mb-3">{t('userMemory.relationshipEvolution')}</div>
|
||||
<RbCard>
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<EmotionLine chartData={emotionData} loading={loading} />
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<InteractionBar chartData={interactionData} loading={loading} />
|
||||
</Col>
|
||||
</Row>
|
||||
</RbCard>
|
||||
|
||||
<div className="rb:text-[16px] rb:font-medium rb:leading-5.5 rb:mb-3 rb:mt-6">{t('userMemory.timelineMemories')}</div>
|
||||
<RbCard>
|
||||
<Tabs
|
||||
activeKey={activeTab}
|
||||
items={['timelines_memory', 'ExtractedEntity', 'Statement', 'MemorySummary'].map(key => ({
|
||||
label: t(`userMemory.${key}`),
|
||||
key
|
||||
}))}
|
||||
onChange={(key: string) => setActiveTab(key)}
|
||||
/>
|
||||
{timelineLoading
|
||||
? <Skeleton active />
|
||||
: !activeContent || activeContent.length === 0
|
||||
? <Empty size={120} className="rb:mt-12 rb:mb-20.25" />
|
||||
: <Space size={16} direction="vertical" className="rb:w-full">
|
||||
{activeContent.map((vo, index) => (
|
||||
<RbCard
|
||||
key={index}
|
||||
headerType="borderL"
|
||||
headerClassName="rb:before:bg-[#155EEF]!"
|
||||
title={vo.text}
|
||||
>
|
||||
<div className="rb:text-[#A8A9AA] rb:text-[12px] rb:leading-4">{formatDateTime(vo.created_at)}</div>
|
||||
<Tag className="rb:mt-2">{vo.type}</Tag>
|
||||
</RbCard>
|
||||
))}
|
||||
</Space>
|
||||
}
|
||||
<div className="rb:text-[16px] rb:font-medium rb:leading-5.5 rb:mb-3 rb:mt-6">{t('userMemory.timelineMemories')}</div>
|
||||
<RbCard>
|
||||
<Tabs
|
||||
activeKey={activeTab}
|
||||
items={['timelines_memory', 'Statement', 'MemorySummary'].map(key => ({
|
||||
label: t(`userMemory.${key}`),
|
||||
key
|
||||
}))}
|
||||
onChange={(key: string) => setActiveTab(key)}
|
||||
/>
|
||||
{timelineLoading
|
||||
? <Skeleton active />
|
||||
: !activeContent || activeContent.length === 0
|
||||
? <Empty size={120} className="rb:mt-12 rb:mb-20.25" />
|
||||
: <Space size={16} direction="vertical" className="rb:w-full">
|
||||
{activeContent.map((vo, index) => (
|
||||
<RbCard
|
||||
key={index}
|
||||
headerType="borderL"
|
||||
headerClassName="rb:before:bg-[#155EEF]!"
|
||||
title={vo.text}
|
||||
>
|
||||
<div className="rb:text-[#A8A9AA] rb:text-[12px] rb:leading-4">{formatDateTime(vo.created_at)}</div>
|
||||
<Tag className="rb:mt-2">{vo.type}</Tag>
|
||||
</RbCard>
|
||||
))}
|
||||
</Space>
|
||||
}
|
||||
|
||||
|
||||
</RbCard>
|
||||
</RbDrawer>
|
||||
|
||||
</RbCard>
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
})
|
||||
export default GraphDetail
|
||||
@@ -1,7 +1,7 @@
|
||||
import { type FC, useEffect, useState, useMemo, useRef } from 'react'
|
||||
import { useParams, useNavigate } from 'react-router-dom'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Dropdown, Space, Button } from 'antd'
|
||||
import { Dropdown, Button } from 'antd'
|
||||
|
||||
import PageHeader from '../components/PageHeader'
|
||||
import StatementDetail from './StatementDetail'
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
getEndUserProfile,
|
||||
} from '@/api/memory'
|
||||
import refreshIcon from '@/assets/images/refresh_hover.svg'
|
||||
import GraphDetail from './GraphDetail'
|
||||
|
||||
const Detail: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
@@ -47,6 +48,10 @@ const Detail: FC = () => {
|
||||
forgetDetailRef.current?.handleRefresh()
|
||||
}
|
||||
|
||||
if (type === 'GRAPH') {
|
||||
return <GraphDetail />
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rb:h-full rb:w-full">
|
||||
<PageHeader
|
||||
|
||||
@@ -107,7 +107,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
<div style={{ maxHeight: '300px', overflowY: 'auto', minWidth: '240px' }}>
|
||||
{nodeLibrary.map((category, categoryIndex) => {
|
||||
const filteredNodes = category.nodes.filter(nodeType =>
|
||||
nodeType.type !== 'start' && nodeType.type !== 'end' && nodeType.type !== 'loop' && nodeType.type !== 'cycle-start'
|
||||
nodeType.type !== 'start' && nodeType.type !== 'end' && nodeType.type !== 'iteration' && nodeType.type !== 'loop' && nodeType.type !== 'cycle-start'
|
||||
);
|
||||
|
||||
if (filteredNodes.length === 0) return null;
|
||||
|
||||
@@ -33,7 +33,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
y: cycleStartBBox.y,
|
||||
data: {
|
||||
type: 'add-node',
|
||||
label: '添加节点',
|
||||
label: t('workflow.addNode'),
|
||||
icon: '+',
|
||||
parentId: node.id,
|
||||
cycle: data.id,
|
||||
@@ -61,7 +61,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
},
|
||||
},
|
||||
},
|
||||
zIndex: 3
|
||||
zIndex: 10
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -97,7 +97,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
y: centerY,
|
||||
data: {
|
||||
type: 'add-node',
|
||||
label: '添加节点',
|
||||
label: t('workflow.addNode'),
|
||||
icon: '+',
|
||||
parentId: node.id,
|
||||
cycle: data.id,
|
||||
@@ -128,7 +128,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
},
|
||||
},
|
||||
},
|
||||
zIndex: 3
|
||||
zIndex: 10
|
||||
}
|
||||
|
||||
graph.addEdge(edgeConfig)
|
||||
|
||||
@@ -151,11 +151,11 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
||||
|
||||
let filteredNodes;
|
||||
if (isChildOfLoop) {
|
||||
// Use same filtering as AddNode for child nodes of loop
|
||||
// Use same filtering as AddNode for child nodes of loop, but allow break
|
||||
filteredNodes = category.nodes.filter(nodeType => !['start', 'end', 'loop', 'cycle-start', 'iteration'].includes(nodeType.type));
|
||||
} else if (isChildOfIteration) {
|
||||
// Filter out loop and iteration nodes for children of iteration nodes
|
||||
filteredNodes = category.nodes.filter(nodeType => !['start', 'end', 'loop', 'break', 'cycle-start', 'iteration'].includes(nodeType.type));
|
||||
// Filter out loop and iteration nodes for children of iteration nodes, but allow break
|
||||
filteredNodes = category.nodes.filter(nodeType => !['start', 'end', 'loop', 'cycle-start', 'iteration'].includes(nodeType.type));
|
||||
} else {
|
||||
// Original filtering for non-loop child nodes
|
||||
filteredNodes = category.nodes.filter(nodeType => !['start', 'end', 'break', 'cycle-start'].includes(nodeType.type));
|
||||
|
||||
@@ -60,7 +60,7 @@ const AssignmentList: FC<AssignmentListProps> = ({
|
||||
>
|
||||
<VariableSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
options={options}
|
||||
options={options.filter(vo => vo.nodeData.type === 'loop' || vo.value.includes('conv.'))}
|
||||
popupMatchSelectWidth={false}
|
||||
onChange={() => {
|
||||
form.setFieldValue([parentName, name, 'operation'], undefined);
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
import { type FC } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Input, Button, Form, Space } from 'antd';
|
||||
import { PlusOutlined, CopyOutlined, DeleteOutlined, ExpandOutlined } from '@ant-design/icons';
|
||||
import { Button, Form, Space } from 'antd';
|
||||
import { DeleteOutlined } from '@ant-design/icons';
|
||||
import { Graph, Node } from '@antv/x6';
|
||||
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
||||
import Editor from '../../Editor';
|
||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||
|
||||
interface CategoryListProps {
|
||||
parentName: string;
|
||||
options: Suggestion[];
|
||||
selectedNode?: Node | null;
|
||||
graphRef?: React.MutableRefObject<Graph | undefined>;
|
||||
}
|
||||
|
||||
const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRef }) => {
|
||||
const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRef, options }) => {
|
||||
const { t } = useTranslation();
|
||||
const form = Form.useFormInstance();
|
||||
const formValues = Form.useWatch([parentName], form);
|
||||
@@ -167,9 +169,9 @@ const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRe
|
||||
name={[name, 'class_name']}
|
||||
noStyle
|
||||
>
|
||||
<Input.TextArea
|
||||
<Editor
|
||||
placeholder={t('common.pleaseEnter')}
|
||||
rows={2}
|
||||
options={options}
|
||||
/>
|
||||
</Form.Item>
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { type FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Form, Button, Select, Row, Col, InputNumber, Radio, type SelectProps } from 'antd'
|
||||
import { Form, Button, Select, Row, Col, InputNumber, Radio, Input, type SelectProps } from 'antd'
|
||||
import { DeleteOutlined } from '@ant-design/icons';
|
||||
|
||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||
@@ -114,7 +114,7 @@ const ConditionList: FC<CaseListProps> = ({
|
||||
<Col span={14}>
|
||||
<Form.Item name={[field.name, 'left']} noStyle>
|
||||
<VariableSelect
|
||||
options={options}
|
||||
options={options.filter(vo => vo.value.includes('sys.') || vo.value.includes('conv.') || vo.nodeData.type === 'loop')}
|
||||
size="small"
|
||||
allowClear={false}
|
||||
popupMatchSelectWidth={false}
|
||||
@@ -186,7 +186,7 @@ const ConditionList: FC<CaseListProps> = ({
|
||||
<Radio.Button value={true}>True</Radio.Button>
|
||||
<Radio.Button value={false}>False</Radio.Button>
|
||||
</Radio.Group>
|
||||
: <Editor options={options} />
|
||||
: <Input placeholder={t('common.pleaseEnter')} />
|
||||
}
|
||||
</Form.Item>
|
||||
</Col>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { type FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Form, Button, Select, Row, Col, Input } from 'antd'
|
||||
import { Form, Select, Row, Col, Input } from 'antd'
|
||||
import { DeleteOutlined, PlusOutlined } from '@ant-design/icons';
|
||||
import VariableSelect from '../VariableSelect'
|
||||
|
||||
@@ -36,7 +36,6 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
|
||||
value = [],
|
||||
options,
|
||||
parentName,
|
||||
onChange,
|
||||
selectedNode,
|
||||
graphRef
|
||||
}) => {
|
||||
@@ -139,12 +138,17 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
|
||||
<Form.Item name={[name, 'value']} noStyle>
|
||||
{currentInputType === 'variable' ? (
|
||||
<VariableSelect
|
||||
placeholder="选择变量"
|
||||
options={availableOptions}
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
options={availableOptions.filter(option => {
|
||||
const currentType = value?.[index]?.type;
|
||||
if (!currentType) return true;
|
||||
|
||||
return option.dataType === currentType
|
||||
})}
|
||||
/>
|
||||
) : (
|
||||
<Input.TextArea
|
||||
placeholder="输入值"
|
||||
placeholder={t('common.pleaseEnter')}
|
||||
rows={3}
|
||||
className="rb:w-full"
|
||||
/>
|
||||
|
||||
@@ -18,8 +18,22 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
|
||||
isCanAdd = false
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const form = Form.useFormInstance();
|
||||
const value = form.getFieldValue(name) || [];
|
||||
|
||||
console.log('GroupVariableList', value)
|
||||
|
||||
if (!isCanAdd) {
|
||||
// Filter options based on first variable's dataType if value exists
|
||||
let filteredOptions = options;
|
||||
if (value.length > 0) {
|
||||
const firstVariableValue = value[0];
|
||||
const firstVariable = options.find(opt => `{{${opt.value}}}` === firstVariableValue);
|
||||
if (firstVariable) {
|
||||
filteredOptions = options.filter(opt => opt.dataType === firstVariable.dataType);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rb:mb-4">
|
||||
<Row gutter={12} className="rb:mb-2!">
|
||||
@@ -38,7 +52,7 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
|
||||
>
|
||||
<VariableSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
options={options}
|
||||
options={filteredOptions}
|
||||
mode="multiple"
|
||||
/>
|
||||
</Form.Item>
|
||||
@@ -77,7 +91,18 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
|
||||
>
|
||||
<VariableSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
options={options}
|
||||
options={(() => {
|
||||
const currentGroupValue = value[name]?.value || [];
|
||||
if (currentGroupValue.length > 0) {
|
||||
const firstVariableValue = currentGroupValue[0];
|
||||
const firstVariable = options.find(opt => `{{${opt.value}}}` === firstVariableValue);
|
||||
if (firstVariable) {
|
||||
return options.filter(opt => opt.dataType === firstVariable.dataType);
|
||||
}
|
||||
}
|
||||
return options;
|
||||
})()
|
||||
}
|
||||
mode="multiple"
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
@@ -90,7 +90,7 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an
|
||||
</Col>
|
||||
<Col span={16}>
|
||||
<Form.Item name="url">
|
||||
<Editor options={options} variant="outlined" />
|
||||
<Editor options={options.filter(vo => vo.dataType === 'string' || vo.dataType === 'number')} variant="outlined" />
|
||||
</Form.Item>
|
||||
</Col>
|
||||
</Row>
|
||||
@@ -144,7 +144,7 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an
|
||||
<Form.Item name={['body', 'data']} noStyle>
|
||||
<EditableTable
|
||||
parentName={['body', 'data']}
|
||||
options={options}
|
||||
options={options.filter(vo => vo.dataType === 'string' || vo.dataType === 'number')}
|
||||
filterBooleanType={true}
|
||||
/>
|
||||
</Form.Item>
|
||||
@@ -154,7 +154,7 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an
|
||||
<MessageEditor
|
||||
key="json"
|
||||
parentName={['body', 'data']}
|
||||
options={options}
|
||||
options={options.filter(vo => vo.dataType === 'string' || vo.dataType === 'number')}
|
||||
isArray={false}
|
||||
title="JSON"
|
||||
/>
|
||||
|
||||
@@ -91,6 +91,7 @@ const VariableSelect: FC<VariableSelectProps> = ({
|
||||
showSearch
|
||||
allowClear={allowClear}
|
||||
filterOption={(input, option) => {
|
||||
if (input === '/') return true;
|
||||
if (option?.options) {
|
||||
return option.label?.toLowerCase().includes(input.toLowerCase()) ||
|
||||
option.options.some((opt: any) =>
|
||||
|
||||
@@ -22,6 +22,7 @@ import ConditionList from './ConditionList'
|
||||
import CycleVarsList from './CycleVarsList'
|
||||
import AssignmentList from './AssignmentList'
|
||||
import ToolConfig from './ToolConfig'
|
||||
// import { calculateVariableList } from './utils/variableListCalculator'
|
||||
|
||||
interface PropertiesProps {
|
||||
selectedNode?: Node | null;
|
||||
@@ -338,112 +339,35 @@ const Properties: FC<PropertiesProps> = ({
|
||||
const parentLoopNode = getParentLoopNode(selectedNode.id);
|
||||
|
||||
console.log('childNodeIds', selectedNode, childNodeIds)
|
||||
const allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds];
|
||||
let allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds];
|
||||
|
||||
// Add parent loop/iteration node variables if current node is a child
|
||||
// Add variables from nodes preceding the parent loop/iteration node if current node is a child
|
||||
if (parentLoopNode) {
|
||||
const parentData = parentLoopNode.getData();
|
||||
const parentNodeId = parentLoopNode.getData().id;
|
||||
|
||||
if (parentData.type === 'loop') {
|
||||
const cycleVars = parentData.cycle_vars || [];
|
||||
cycleVars.forEach((cycleVar: any) => {
|
||||
const key = `${parentNodeId}_cycle_${cycleVar.name}`;
|
||||
if (!addedKeys.has(key)) {
|
||||
addedKeys.add(key);
|
||||
variableList.push({
|
||||
key,
|
||||
label: cycleVar.name,
|
||||
type: 'variable',
|
||||
dataType: cycleVar.type || 'String',
|
||||
value: `${parentNodeId}.${cycleVar.name}`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
});
|
||||
} else if (parentData.type === 'iteration') {
|
||||
// Add item and index variables for iteration parent
|
||||
const itemKey = `${parentNodeId}_item`;
|
||||
const indexKey = `${parentNodeId}_index`;
|
||||
|
||||
if (!addedKeys.has(itemKey)) {
|
||||
addedKeys.add(itemKey);
|
||||
variableList.push({
|
||||
key: itemKey,
|
||||
label: 'item',
|
||||
type: 'variable',
|
||||
dataType: 'Object',
|
||||
value: `${parentNodeId}.item`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
|
||||
if (!addedKeys.has(indexKey)) {
|
||||
addedKeys.add(indexKey);
|
||||
variableList.push({
|
||||
key: indexKey,
|
||||
label: 'index',
|
||||
type: 'variable',
|
||||
dataType: 'Number',
|
||||
value: `${parentNodeId}.index`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check if parent loop/iteration is connected to http-request via ERROR connection
|
||||
if (parentData.type === 'loop' || parentData.type === 'iteration') {
|
||||
const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id);
|
||||
parentPreviousNodeIds.forEach(prevNodeId => {
|
||||
const prevNode = nodes.find(n => n.id === prevNodeId);
|
||||
if (!prevNode) return;
|
||||
|
||||
const prevNodeData = prevNode.getData();
|
||||
if (prevNodeData.type === 'http-request') {
|
||||
// Check if connected via ERROR connection point
|
||||
const errorEdges = edges.filter(edge => {
|
||||
return edge.getTargetCellId() === parentLoopNode.id &&
|
||||
edge.getSourceCellId() === prevNodeId &&
|
||||
edge.getSourcePortId() === 'ERROR'
|
||||
});
|
||||
|
||||
if (errorEdges.length > 0) {
|
||||
const errorMessageKey = `${prevNodeData.id}_error_message`;
|
||||
const errorTypeKey = `${prevNodeData.id}_error_type`;
|
||||
|
||||
if (!addedKeys.has(errorMessageKey)) {
|
||||
addedKeys.add(errorMessageKey);
|
||||
variableList.push({
|
||||
key: errorMessageKey,
|
||||
label: 'error_message',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${prevNodeData.id}.error_message`,
|
||||
nodeData: prevNodeData,
|
||||
});
|
||||
}
|
||||
|
||||
if (!addedKeys.has(errorTypeKey)) {
|
||||
addedKeys.add(errorTypeKey);
|
||||
variableList.push({
|
||||
key: errorTypeKey,
|
||||
label: 'error_type',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${prevNodeData.id}.error_type`,
|
||||
nodeData: prevNodeData,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Add variables from nodes preceding the parent loop/iteration node
|
||||
const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id);
|
||||
allRelevantNodeIds.push(...parentPreviousNodeIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Add conversation variables from global config
|
||||
const conversationVariables = workflowConfig?.variables || [];
|
||||
|
||||
conversationVariables.forEach((variable: any) => {
|
||||
const key = `CONVERSATION_${variable.name}`;
|
||||
if (!addedKeys.has(key)) {
|
||||
addedKeys.add(key);
|
||||
variableList.push({
|
||||
key,
|
||||
label: variable.name,
|
||||
type: 'variable',
|
||||
dataType: variable.type,
|
||||
value: `conv.${variable.name}`,
|
||||
nodeData: { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' },
|
||||
group: 'CONVERSATION'
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
allRelevantNodeIds.forEach(nodeId => {
|
||||
const node = nodes.find(n => n.id === nodeId);
|
||||
if (!node) return;
|
||||
@@ -496,7 +420,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
key: llmKey,
|
||||
label: 'output',
|
||||
type: 'variable',
|
||||
dataType: 'String',
|
||||
dataType: 'string',
|
||||
value: `${dataNodeId}.output`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
@@ -565,6 +489,17 @@ const Properties: FC<PropertiesProps> = ({
|
||||
const groupVariables = nodeData.config.group_variables.defaultValue || [];
|
||||
groupVariables?.forEach((groupVar: any) => {
|
||||
if (!groupVar || !groupVar.key) return;
|
||||
|
||||
// Determine dataType from first variable in the group
|
||||
let groupDataType = 'string';
|
||||
if (groupVar.value && Array.isArray(groupVar.value) && groupVar.value.length > 0) {
|
||||
const firstVariableValue = groupVar.value[0];
|
||||
const firstVariable = variableList.find(v => `{{${v.value}}}` === firstVariableValue);
|
||||
if (firstVariable) {
|
||||
groupDataType = firstVariable.dataType;
|
||||
}
|
||||
}
|
||||
|
||||
const groupVarKey = `${dataNodeId}_${groupVar.key}`;
|
||||
if (!addedKeys.has(groupVarKey)) {
|
||||
addedKeys.add(groupVarKey);
|
||||
@@ -572,14 +507,26 @@ const Properties: FC<PropertiesProps> = ({
|
||||
key: groupVarKey,
|
||||
label: groupVar.key,
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
dataType: groupDataType,
|
||||
value: `${dataNodeId}.${groupVar.key}`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// If group=false, add output variable
|
||||
// If group=false, add output variable with type from first group_variable
|
||||
const groupVariables = nodeData.config.group_variables.defaultValue || [];
|
||||
const firstVariable = groupVariables[0];
|
||||
let outputDataType: string = 'any';
|
||||
if (firstVariable) {
|
||||
const filterVo = [...variableList].find(v => {
|
||||
return `{{${v.value}}}` === firstVariable
|
||||
})
|
||||
if (filterVo) {
|
||||
outputDataType = filterVo?.dataType
|
||||
}
|
||||
}
|
||||
|
||||
const varAggregatorKey = `${dataNodeId}_output`;
|
||||
if (!addedKeys.has(varAggregatorKey)) {
|
||||
addedKeys.add(varAggregatorKey);
|
||||
@@ -587,7 +534,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
key: varAggregatorKey,
|
||||
label: 'output',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
dataType: outputDataType,
|
||||
value: `${dataNodeId}.output`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
@@ -684,21 +631,20 @@ const Properties: FC<PropertiesProps> = ({
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
if (!addedKeys.has(outputKey)) {
|
||||
addedKeys.add(outputKey);
|
||||
variableList.push({
|
||||
key: outputKey,
|
||||
label: 'output',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${dataNodeId}.output`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
// if (!addedKeys.has(outputKey)) {
|
||||
// addedKeys.add(outputKey);
|
||||
// variableList.push({
|
||||
// key: outputKey,
|
||||
// label: 'output',
|
||||
// type: 'variable',
|
||||
// dataType: 'string',
|
||||
// value: `${dataNodeId}.output`,
|
||||
// nodeData: nodeData,
|
||||
// });
|
||||
// }
|
||||
break
|
||||
case 'iteration':
|
||||
const iterationOutputKey = `${dataNodeId}_output`;
|
||||
const iterationItemKey = `${dataNodeId}_item`;
|
||||
if (!addedKeys.has(iterationOutputKey)) {
|
||||
addedKeys.add(iterationOutputKey);
|
||||
// Get the data type from the output configuration, default to string
|
||||
@@ -715,22 +661,11 @@ const Properties: FC<PropertiesProps> = ({
|
||||
key: iterationOutputKey,
|
||||
label: 'output',
|
||||
type: 'variable',
|
||||
dataType: outputDataType,
|
||||
dataType: `array[${outputDataType}]`,
|
||||
value: `${dataNodeId}.output`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
if (!addedKeys.has(iterationItemKey)) {
|
||||
addedKeys.add(iterationItemKey);
|
||||
variableList.push({
|
||||
key: iterationItemKey,
|
||||
label: 'item',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${dataNodeId}.item`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
break
|
||||
case 'loop':
|
||||
const cycleVars = nodeData.config.cycle_vars.defaultValue || [];
|
||||
@@ -760,47 +695,337 @@ const Properties: FC<PropertiesProps> = ({
|
||||
key: toolDataKey,
|
||||
label: 'data',
|
||||
type: 'variable',
|
||||
dataType: 'object',
|
||||
dataType: 'string',
|
||||
value: `${dataNodeId}.data`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
break
|
||||
case 'memory-read':
|
||||
const memoryReadAnswerKey = `${dataNodeId}_answer`;
|
||||
const memoryReadIntermediateOutputs = `${dataNodeId}_intermediate_outputs`;
|
||||
if (!addedKeys.has(memoryReadAnswerKey)) {
|
||||
addedKeys.add(memoryReadAnswerKey);
|
||||
variableList.push({
|
||||
key: memoryReadAnswerKey,
|
||||
label: 'answer',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${dataNodeId}.answer`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
if (!addedKeys.has(memoryReadIntermediateOutputs)) {
|
||||
addedKeys.add(memoryReadIntermediateOutputs);
|
||||
variableList.push({
|
||||
key: memoryReadIntermediateOutputs,
|
||||
label: 'intermediate_outputs',
|
||||
type: 'variable',
|
||||
dataType: 'array[object]',
|
||||
value: `${dataNodeId}.intermediate_outputs`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
break
|
||||
}
|
||||
});
|
||||
|
||||
// Add conversation variables from global config
|
||||
const conversationVariables = workflowConfig?.variables || [];
|
||||
|
||||
conversationVariables.forEach((variable: any) => {
|
||||
const key = `CONVERSATION_${variable.name}`;
|
||||
if (!addedKeys.has(key)) {
|
||||
addedKeys.add(key);
|
||||
variableList.push({
|
||||
key,
|
||||
label: variable.name,
|
||||
type: 'variable',
|
||||
dataType: variable.type,
|
||||
value: `conv.${variable.name}`,
|
||||
nodeData: { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' },
|
||||
group: 'CONVERSATION'
|
||||
|
||||
// Add parent loop/iteration node variables if current node is a child
|
||||
if (parentLoopNode) {
|
||||
const parentData = parentLoopNode.getData();
|
||||
const parentNodeId = parentLoopNode.getData().id;
|
||||
|
||||
if (parentData.type === 'loop') {
|
||||
const cycleVars = parentData.cycle_vars || [];
|
||||
cycleVars.forEach((cycleVar: any) => {
|
||||
const key = `${parentNodeId}_cycle_${cycleVar.name}`;
|
||||
if (!addedKeys.has(key)) {
|
||||
addedKeys.add(key);
|
||||
variableList.push({
|
||||
key,
|
||||
label: cycleVar.name,
|
||||
type: 'variable',
|
||||
dataType: cycleVar.type || 'String',
|
||||
value: `${parentNodeId}.${cycleVar.name}`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
});
|
||||
} else if (parentData.type === 'iteration') {
|
||||
// Add item and index variables for iteration parent only if input has value
|
||||
if (parentData.config.input.defaultValue) {
|
||||
const itemKey = `${parentNodeId}_item`;
|
||||
const indexKey = `${parentNodeId}_index`;
|
||||
|
||||
// Determine item dataType from input variable
|
||||
let itemDataType = 'object';
|
||||
const inputVariable = variableList.find(v => `{{${v.value}}}` === parentData.config.input.defaultValue);
|
||||
console.log('itemDataType defaultValue', parentData.config.input.defaultValue, variableList, inputVariable)
|
||||
if (inputVariable && inputVariable.dataType.startsWith('array[')) {
|
||||
itemDataType = inputVariable.dataType.replace(/^array\[(.+)\]$/, '$1');
|
||||
console.log('itemDataType', itemDataType)
|
||||
}
|
||||
|
||||
|
||||
if (!addedKeys.has(itemKey)) {
|
||||
addedKeys.add(itemKey);
|
||||
variableList.push({
|
||||
key: itemKey,
|
||||
label: 'item',
|
||||
type: 'variable',
|
||||
dataType: itemDataType,
|
||||
value: `${parentNodeId}.item`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
|
||||
if (!addedKeys.has(indexKey)) {
|
||||
addedKeys.add(indexKey);
|
||||
variableList.push({
|
||||
key: indexKey,
|
||||
label: 'index',
|
||||
type: 'variable',
|
||||
dataType: 'number',
|
||||
value: `${parentNodeId}.index`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if parent loop/iteration is connected to http-request via ERROR connection
|
||||
if (parentData.type === 'loop' || parentData.type === 'iteration') {
|
||||
const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id);
|
||||
parentPreviousNodeIds.forEach(prevNodeId => {
|
||||
const prevNode = nodes.find(n => n.id === prevNodeId);
|
||||
if (!prevNode) return;
|
||||
|
||||
const prevNodeData = prevNode.getData();
|
||||
if (prevNodeData.type === 'http-request') {
|
||||
// Check if connected via ERROR connection point
|
||||
const errorEdges = edges.filter(edge => {
|
||||
return edge.getTargetCellId() === parentLoopNode.id &&
|
||||
edge.getSourceCellId() === prevNodeId &&
|
||||
edge.getSourcePortId() === 'ERROR'
|
||||
});
|
||||
|
||||
if (errorEdges.length > 0) {
|
||||
const errorMessageKey = `${prevNodeData.id}_error_message`;
|
||||
const errorTypeKey = `${prevNodeData.id}_error_type`;
|
||||
|
||||
if (!addedKeys.has(errorMessageKey)) {
|
||||
addedKeys.add(errorMessageKey);
|
||||
variableList.push({
|
||||
key: errorMessageKey,
|
||||
label: 'error_message',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${prevNodeData.id}.error_message`,
|
||||
nodeData: prevNodeData,
|
||||
});
|
||||
}
|
||||
|
||||
if (!addedKeys.has(errorTypeKey)) {
|
||||
addedKeys.add(errorTypeKey);
|
||||
variableList.push({
|
||||
key: errorTypeKey,
|
||||
label: 'error_type',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${prevNodeData.id}.error_type`,
|
||||
nodeData: prevNodeData,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return variableList;
|
||||
}, [selectedNode, graphRef, workflowConfig?.variables]);
|
||||
|
||||
// Filter out boolean type variables for loop and llm nodes
|
||||
const getFilteredVariableList = (nodeType?: string) => {
|
||||
if (nodeType === 'loop' || nodeType === 'llm') {
|
||||
return variableList.filter(variable => variable.dataType !== 'boolean');
|
||||
const getFilteredVariableList = (nodeType?: string, key?: string) => {
|
||||
// Check if current node is a child of iteration node
|
||||
const parentIterationNode = selectedNode ? (() => {
|
||||
const nodes = graphRef.current?.getNodes() || [];
|
||||
const nodeData = selectedNode.getData();
|
||||
const cycle = nodeData?.cycle;
|
||||
|
||||
if (cycle) {
|
||||
const parentNode = nodes.find(n => n.getData().id === cycle);
|
||||
if (parentNode) {
|
||||
const parentData = parentNode.getData();
|
||||
if (parentData?.type === 'iteration') {
|
||||
return parentNode;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
})() : null;
|
||||
|
||||
// Helper function to add parent iteration variables
|
||||
const addParentIterationVars = (filteredList: any[]) => {
|
||||
if (parentIterationNode) {
|
||||
const parentData = parentIterationNode.getData();
|
||||
const parentNodeId = parentData.id;
|
||||
|
||||
if (parentData.config?.input?.defaultValue) {
|
||||
const itemKey = `${parentNodeId}_item`;
|
||||
const indexKey = `${parentNodeId}_index`;
|
||||
|
||||
const existingItemVar = filteredList.find(v => v.key === itemKey);
|
||||
const existingIndexVar = filteredList.find(v => v.key === indexKey);
|
||||
|
||||
if (!existingItemVar) {
|
||||
// Determine item dataType from input variable
|
||||
let itemDataType = 'object';
|
||||
const inputVariable = variableList.find(v => `{{${v.value}}}` === parentData.config.input.defaultValue);
|
||||
if (inputVariable && inputVariable.dataType.startsWith('array[')) {
|
||||
itemDataType = inputVariable.dataType.replace(/^array\[(.+)\]$/, '$1');
|
||||
}
|
||||
|
||||
filteredList.push({
|
||||
key: itemKey,
|
||||
label: 'item',
|
||||
type: 'variable',
|
||||
dataType: itemDataType,
|
||||
value: `${parentNodeId}.item`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
|
||||
if (!existingIndexVar) {
|
||||
filteredList.push({
|
||||
key: indexKey,
|
||||
label: 'index',
|
||||
type: 'variable',
|
||||
dataType: 'number',
|
||||
value: `${parentNodeId}.index`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
return filteredList;
|
||||
};
|
||||
|
||||
if (nodeType === 'llm') {
|
||||
// For LLM nodes that are children of iteration or loop nodes, include parent variables
|
||||
const parentLoopNode = selectedNode ? (() => {
|
||||
const nodes = graphRef.current?.getNodes() || [];
|
||||
const nodeData = selectedNode.getData();
|
||||
const cycle = nodeData?.cycle;
|
||||
|
||||
if (cycle) {
|
||||
const parentNode = nodes.find(n => n.getData().id === cycle);
|
||||
if (parentNode) {
|
||||
const parentData = parentNode.getData();
|
||||
if (parentData?.type === 'loop' || parentData?.type === 'iteration') {
|
||||
return parentNode;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
})() : null;
|
||||
|
||||
let filteredList = variableList.filter(variable => variable.dataType !== 'boolean');
|
||||
|
||||
// If this LLM node is a child of iteration/loop, ensure parent variables are included
|
||||
if (parentLoopNode) {
|
||||
const parentData = parentLoopNode.getData();
|
||||
const parentNodeId = parentData.id;
|
||||
|
||||
// Ensure parent loop/iteration variables are included
|
||||
if (parentData.type === 'loop') {
|
||||
const cycleVars = parentData.cycle_vars || [];
|
||||
cycleVars.forEach((cycleVar: any) => {
|
||||
const key = `${parentNodeId}_cycle_${cycleVar.name}`;
|
||||
const existingVar = filteredList.find(v => v.key === key);
|
||||
if (!existingVar && cycleVar.name && cycleVar.type !== 'boolean') {
|
||||
filteredList.push({
|
||||
key,
|
||||
label: cycleVar.name,
|
||||
type: 'variable',
|
||||
dataType: cycleVar.type || 'String',
|
||||
value: `${parentNodeId}.${cycleVar.name}`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
});
|
||||
} else if (parentData.type === 'iteration') {
|
||||
// Add item and index variables for iteration parent
|
||||
if (parentData.config?.input?.defaultValue) {
|
||||
const itemKey = `${parentNodeId}_item`;
|
||||
const indexKey = `${parentNodeId}_index`;
|
||||
|
||||
const existingItemVar = filteredList.find(v => v.key === itemKey);
|
||||
const existingIndexVar = filteredList.find(v => v.key === indexKey);
|
||||
|
||||
if (!existingItemVar) {
|
||||
// Determine item dataType from input variable
|
||||
let itemDataType = 'object';
|
||||
const inputVariable = variableList.find(v => `{{${v.value}}}` === parentData.config.input.defaultValue);
|
||||
if (inputVariable && inputVariable.dataType.startsWith('array[')) {
|
||||
itemDataType = inputVariable.dataType.replace(/^array\[(.+)\]$/, '$1');
|
||||
}
|
||||
|
||||
filteredList.push({
|
||||
key: itemKey,
|
||||
label: 'item',
|
||||
type: 'variable',
|
||||
dataType: itemDataType,
|
||||
value: `${parentNodeId}.item`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
|
||||
if (!existingIndexVar) {
|
||||
filteredList.push({
|
||||
key: indexKey,
|
||||
label: 'index',
|
||||
type: 'variable',
|
||||
dataType: 'Number',
|
||||
value: `${parentNodeId}.index`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return filteredList;
|
||||
}
|
||||
return variableList;
|
||||
if (nodeType === 'knowledge-retrieval' || nodeType === 'parameter-extractor' && key !== 'prompt' || nodeType === 'memory-read' || nodeType === 'memory-write' || nodeType === 'question-classifier') {
|
||||
let filteredList = variableList.filter(variable => variable.dataType === 'string');
|
||||
return addParentIterationVars(filteredList);
|
||||
}
|
||||
if (nodeType === 'parameter-extractor' && key === 'prompt') {
|
||||
let filteredList = variableList.filter(variable => variable.dataType === 'string' || variable.dataType === 'number');
|
||||
return addParentIterationVars(filteredList);
|
||||
}
|
||||
if (nodeType === 'iteration' && key === 'output') {
|
||||
return variableList.filter(variable => variable.value.includes('sys.'));
|
||||
}
|
||||
if (nodeType === 'iteration') {
|
||||
return variableList.filter(variable => variable.dataType.includes('array'));
|
||||
}
|
||||
if (nodeType === 'loop' && key === 'condition') {
|
||||
let filteredList = variableList.filter(variable => variable.nodeData.type !== 'loop');
|
||||
return addParentIterationVars(filteredList);
|
||||
}
|
||||
|
||||
// For all other node types, add parent iteration variables if applicable
|
||||
let baseList = variableList;
|
||||
return addParentIterationVars(baseList);
|
||||
};
|
||||
|
||||
// const defaultVariableList = calculateVariableList(selectedNode as Node, graphRef, workflowConfig )
|
||||
|
||||
console.log('values', values)
|
||||
console.log('variableList', variableList, selectedNode?.data)
|
||||
// console.log('variableList', variableList, defaultVariableList)
|
||||
|
||||
return (
|
||||
<div className="rb:w-75 rb:fixed rb:right-0 rb:top-16 rb:bottom-0 rb:p-3">
|
||||
@@ -901,11 +1126,10 @@ const Properties: FC<PropertiesProps> = ({
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Form.Item key={key} name={key}>
|
||||
<MessageEditor
|
||||
key={key}
|
||||
key={key}
|
||||
options={contextVariableList.filter(variable => variable.nodeData?.type !== 'knowledge-retrieval')}
|
||||
parentName={key}
|
||||
/>
|
||||
@@ -915,7 +1139,12 @@ const Properties: FC<PropertiesProps> = ({
|
||||
if (selectedNode?.data?.type === 'end' && key === 'output') {
|
||||
return (
|
||||
<Form.Item key={key} name={key}>
|
||||
<MessageEditor key={key} isArray={false} parentName={key} options={variableList} />
|
||||
<MessageEditor
|
||||
key={key}
|
||||
isArray={false}
|
||||
parentName={key}
|
||||
options={variableList.filter(variable => variable.nodeData?.type !== 'knowledge-retrieval')}
|
||||
/>
|
||||
</Form.Item>
|
||||
)
|
||||
}
|
||||
@@ -943,7 +1172,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
isArray={!!config.isArray}
|
||||
parentName={key}
|
||||
enableJinja2={config.enableJinja2 as boolean}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type)}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type, key)}
|
||||
/>
|
||||
</Form.Item>
|
||||
)
|
||||
@@ -964,7 +1193,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
<Form.Item key={key} name={key}>
|
||||
<GroupVariableList
|
||||
name={key}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type)}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type, key)}
|
||||
isCanAdd={!!(values as any)?.group}
|
||||
/>
|
||||
</Form.Item>
|
||||
@@ -976,7 +1205,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
<Form.Item key={key} name={key}>
|
||||
<CaseList
|
||||
name={key}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type)}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type, key)}
|
||||
selectedNode={selectedNode}
|
||||
graphRef={graphRef}
|
||||
/>
|
||||
@@ -989,7 +1218,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
<Form.Item key={key} name={key}
|
||||
label={t(`workflow.config.${selectedNode?.data?.type}.${key}`)}
|
||||
>
|
||||
<MappingList name={key} options={getFilteredVariableList(selectedNode?.data?.type)} />
|
||||
<MappingList name={key} options={getFilteredVariableList(selectedNode?.data?.type, key)} />
|
||||
</Form.Item>
|
||||
|
||||
)
|
||||
@@ -999,7 +1228,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
<Form.Item key={key} name={key}>
|
||||
<CycleVarsList
|
||||
parentName={key}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type)}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type, key)}
|
||||
/>
|
||||
</Form.Item>
|
||||
)
|
||||
@@ -1013,9 +1242,9 @@ const Properties: FC<PropertiesProps> = ({
|
||||
if (config.filterLoopIterationVars) {
|
||||
const loopIterationVars: Suggestion[] = [];
|
||||
|
||||
return [...getFilteredVariableList(selectedNode?.data?.type), ...loopIterationVars];
|
||||
return [...getFilteredVariableList(selectedNode?.data?.type, key), ...loopIterationVars];
|
||||
}
|
||||
return getFilteredVariableList(selectedNode?.data?.type);
|
||||
return getFilteredVariableList(selectedNode?.data?.type, key);
|
||||
})()
|
||||
}
|
||||
/>
|
||||
@@ -1060,7 +1289,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
? <VariableSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
options={(() => {
|
||||
const baseVariableList = getFilteredVariableList(selectedNode?.data?.type);
|
||||
const baseVariableList = getFilteredVariableList(selectedNode?.data?.type, key);
|
||||
// Apply filtering if specified in config
|
||||
if (config.filterNodeTypes || config.filterVariableNames) {
|
||||
return baseVariableList.filter(variable => {
|
||||
@@ -1068,7 +1297,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
(Array.isArray(config.filterNodeTypes) && config.filterNodeTypes.includes(variable.nodeData?.type));
|
||||
const variableNameMatch = !config.filterVariableNames ||
|
||||
(Array.isArray(config.filterVariableNames) && config.filterVariableNames.includes(variable.label));
|
||||
return nodeTypeMatch && variableNameMatch;
|
||||
return nodeTypeMatch || variableNameMatch;
|
||||
});
|
||||
}
|
||||
// Filter child nodes for iteration output
|
||||
@@ -1085,7 +1314,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
});
|
||||
|
||||
return baseVariableList.filter(variable =>
|
||||
childNodes.some(node => node.id === variable.nodeData?.id)
|
||||
childNodes.some(node => node.id === variable.nodeData?.id) || selectedNode?.data?.type === 'iteration' && key === 'output' && variable.value.includes('sys.')
|
||||
);
|
||||
}
|
||||
return baseVariableList;
|
||||
@@ -1095,7 +1324,12 @@ const Properties: FC<PropertiesProps> = ({
|
||||
: config.type === 'switch'
|
||||
? <Switch onChange={key === 'group' ? () => { form.setFieldValue('group_variables', []) } : undefined} />
|
||||
: config.type === 'categoryList'
|
||||
? <CategoryList parentName={key} selectedNode={selectedNode} graphRef={graphRef} />
|
||||
? <CategoryList
|
||||
parentName={key}
|
||||
selectedNode={selectedNode}
|
||||
graphRef={graphRef}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type, key)}
|
||||
/>
|
||||
: config.type === 'conditionList'
|
||||
? <ConditionList
|
||||
parentName={key}
|
||||
@@ -1109,18 +1343,9 @@ const Properties: FC<PropertiesProps> = ({
|
||||
value: `${selectedNode.getData().id}.${cycleVar.name}`,
|
||||
nodeData: selectedNode.getData(),
|
||||
}));
|
||||
return [...variableList.filter(variable => {
|
||||
// Keep conversation variables
|
||||
if (variable.group === 'CONVERSATION') return true;
|
||||
// Keep sys variables from start nodes
|
||||
if (variable.nodeData?.type === 'start' && variable.value?.startsWith('sys.')) return true;
|
||||
// Keep variables from non-start nodes
|
||||
if (variable.nodeData?.type !== 'start' && variable.nodeData?.type !== 'http-request' && variable.dataType !== 'boolean') return true;
|
||||
// Filter out custom variables from start nodes
|
||||
return false;
|
||||
}), ...cycleVarSuggestions];
|
||||
})()
|
||||
}
|
||||
|
||||
return [...getFilteredVariableList(selectedNode?.data?.type, key), ...cycleVarSuggestions];
|
||||
})()}
|
||||
selectedNode={selectedNode}
|
||||
graphRef={graphRef}
|
||||
addBtnText={t('workflow.config.addCase')}
|
||||
|
||||
@@ -270,7 +270,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
config: {
|
||||
input: {
|
||||
type: 'variableList',
|
||||
filterNodeTypes: ['knowledge-retrieval'],
|
||||
filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop'],
|
||||
filterVariableNames: ['message']
|
||||
},
|
||||
parallel: {
|
||||
@@ -334,8 +334,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
type: "assigner", icon: assignerIcon,
|
||||
{ type: "assigner", icon: assignerIcon,
|
||||
config: {
|
||||
assignments: {
|
||||
type: 'assignmentList',
|
||||
@@ -656,4 +655,114 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
|
||||
items: [{ group: 'left' }],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
export interface OutputVariable {
|
||||
default?: Array<{
|
||||
name: string;
|
||||
type: string;
|
||||
}>;
|
||||
define?: string[];
|
||||
sys?: Array<{
|
||||
name: string;
|
||||
type: string;
|
||||
}>;
|
||||
error?: Array<{
|
||||
name: string;
|
||||
type: string;
|
||||
}>;
|
||||
}
|
||||
export const outputVariable: { [key: string]: OutputVariable } = {
|
||||
start: {
|
||||
sys: [
|
||||
{ name: "message", type: "string" },
|
||||
{ name: "conversation_id", type: "string" },
|
||||
{ name: "execution_id", type: "string", },
|
||||
{ name: "workspace_id", type: "string" },
|
||||
{ name: "user_id", type: "string" },
|
||||
],
|
||||
define: ['variables']
|
||||
},
|
||||
end: {
|
||||
},
|
||||
llm: {
|
||||
default: [
|
||||
{ name: "output", type: "string" },
|
||||
]
|
||||
},
|
||||
'knowledge-retrieval': {
|
||||
default: [
|
||||
{ name: "output", type: "array[object]" },
|
||||
]
|
||||
},
|
||||
'parameter-extractor': {
|
||||
default: [
|
||||
{ name: "__is_success", type: "number" },
|
||||
{ name: "__reason", type: "string" },
|
||||
],
|
||||
define: ['params']
|
||||
},
|
||||
'memory-read': {
|
||||
default: [
|
||||
{ name: "answer", type: "string" },
|
||||
{ name: "intermediate_outputs", type: "array[object]" },
|
||||
],
|
||||
},
|
||||
'memory-write': {
|
||||
|
||||
},
|
||||
'if-else': {
|
||||
|
||||
},
|
||||
'question-classifier': {
|
||||
default: [
|
||||
{ name: "class_name", type: "string" },
|
||||
// { name: "output", type: "string" },
|
||||
],
|
||||
},
|
||||
'iteration': {
|
||||
default: [
|
||||
// { name: "item", type: "string" }, // 仅内部使用
|
||||
{ name: "output", type: "array[string]" },
|
||||
],
|
||||
},
|
||||
'loop': {
|
||||
define: ['cycle_vars']
|
||||
},
|
||||
'cycle-start': {
|
||||
|
||||
},
|
||||
'break': {
|
||||
|
||||
},
|
||||
'var-aggregator': {
|
||||
// default: [
|
||||
// { name: "output", type: "string" },
|
||||
// ],
|
||||
define: ['group_variables']
|
||||
},
|
||||
'assigner': {
|
||||
|
||||
},
|
||||
'http-request': {
|
||||
default: [
|
||||
{ name: "body", type: "string" },
|
||||
{ name: "status_code", type: "number" },
|
||||
],
|
||||
error: [
|
||||
{ name: "error_message", type: "string" },
|
||||
{ name: "error_type", type: "string" },
|
||||
]
|
||||
},
|
||||
'tool': {
|
||||
default: [
|
||||
{ name: "data", type: "string" },
|
||||
],
|
||||
},
|
||||
'jinja-render': {
|
||||
default: [
|
||||
{ name: "output", type: "string" },
|
||||
],
|
||||
},
|
||||
}
|
||||
Reference in New Issue
Block a user