Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

This commit is contained in:
yujiangping
2026-01-14 12:07:28 +08:00
68 changed files with 1697 additions and 800 deletions

View File

@@ -60,14 +60,14 @@ def list_apps(
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
# 当 ids 存在且不为 None 时,根据 ids 获取应用
if ids is not None:
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
return success(data=items)
# 正常分页查询
items_orm, total = app_service.list_apps(
db,

View File

@@ -30,7 +30,7 @@ from sqlalchemy.orm import Session
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/emotion",
prefix="/memory/emotion-memory",
tags=["Emotion Analysis"],
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
)

View File

@@ -39,7 +39,7 @@ from app.services.memory_forget_service import MemoryForgetService
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/forget",
prefix="/memory/forget-memory",
tags=["Memory Forgetting Engine"],
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
)

View File

@@ -842,7 +842,7 @@ async def run_hybrid_search(
if search_type in ["keyword", "hybrid"]:
# Keyword-based search
logger.info("Starting keyword search...")
logger.info("[PERF] Starting keyword search...")
keyword_start = time.time()
keyword_task = asyncio.create_task(
search_graph(
@@ -856,7 +856,7 @@ async def run_hybrid_search(
if search_type in ["embedding", "hybrid"]:
# Embedding-based search
logger.info("Starting embedding search...")
logger.info("[PERF] Starting embedding search...")
embedding_start = time.time()
# 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig
@@ -872,13 +872,13 @@ async def run_hybrid_search(
type="llm"
)
config_load_time = time.time() - config_load_start
logger.info(f"Config loading took {config_load_time:.4f}s")
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
# Init embedder
embedder_init_start = time.time()
embedder = OpenAIEmbedderClient(model_config=rb_config)
embedder_init_time = time.time() - embedder_init_start
logger.info(f"Embedder init took {embedder_init_time:.4f}s")
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
embedding_task = asyncio.create_task(
search_graph_by_embedding(
@@ -895,7 +895,7 @@ async def run_hybrid_search(
keyword_results = await keyword_task
keyword_latency = time.time() - keyword_start
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"Keyword search completed in {keyword_latency:.4f}s")
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword":
results = keyword_results
else:
@@ -905,7 +905,7 @@ async def run_hybrid_search(
embedding_results = await embedding_task
embedding_latency = time.time() - embedding_start
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"Embedding search completed in {embedding_latency:.4f}s")
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding":
results = embedding_results
else:
@@ -922,17 +922,21 @@ async def run_hybrid_search(
# Apply two-stage reranking with ACTR activation calculation
rerank_start = time.time()
logger.info("Using two-stage reranking with ACTR activation")
logger.info("[PERF] Using two-stage reranking with ACTR activation")
# 加载遗忘引擎配置
config_start = time.time()
try:
pc = get_pipeline_config(memory_config)
forgetting_cfg = pc.forgetting_engine
except Exception as e:
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
forgetting_cfg = ForgettingEngineConfig()
config_time = time.time() - config_start
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
# 统一使用激活度重排序(两阶段:检索 + ACTR计算
rerank_compute_start = time.time()
reranked_results = rerank_with_activation(
keyword_results=keyword_results,
embedding_results=embedding_results,
@@ -941,10 +945,12 @@ async def run_hybrid_search(
forgetting_config=forgetting_cfg,
activation_boost_factor=activation_boost_factor,
)
rerank_compute_time = time.time() - rerank_compute_start
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
rerank_latency = time.time() - rerank_start
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
logger.info(f"Reranking completed in {rerank_latency:.4f}s")
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
# Optional: apply reranker placeholder if enabled via config
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
@@ -985,8 +991,10 @@ async def run_hybrid_search(
else:
results["latency_metrics"] = latency_metrics
logger.info(f"Total search completed in {total_latency:.4f}s")
logger.info(f"Latency breakdown: {latency_metrics}")
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
logger.info(f"[PERF] =========================================")
# Sanitize results: drop large/unused fields
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs

View File

@@ -1,6 +1,7 @@
import asyncio
import json
from datetime import datetime
from typing import List, Optional
from typing import List, Optional, Tuple
from uuid import uuid4
from app.core.logging_config import get_memory_logger
@@ -28,6 +29,118 @@ class MemorySummaryResponse(RobustLLMResponse):
)
async def generate_title_and_type_for_summary(
content: str,
llm_client
) -> Tuple[str, str]:
"""
为MemorySummary生成标题和类型
此方法应该在创建MemorySummary节点时调用生成title和type
Args:
content: Summary的内容文本
llm_client: LLM客户端实例
Returns:
(标题, 类型)元组
"""
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
# 定义有效的类型集合
VALID_TYPES = {
"conversation", # 对话
"project_work", # 项目/工作
"learning", # 学习
"decision", # 决策
"important_event" # 重要事件
}
DEFAULT_TYPE = "conversation" # 默认类型
try:
if not content:
logger.warning("content为空无法生成标题和类型")
return ("空内容", DEFAULT_TYPE)
# 1. 渲染Jinja2提示词模板
prompt = await render_episodic_title_and_type_prompt(content)
# 2. 调用LLM生成标题和类型
messages = [
{"role": "user", "content": prompt}
]
response = await llm_client.chat(messages=messages)
# 3. 解析LLM响应
content_response = response.content
if isinstance(content_response, list):
if len(content_response) > 0:
if isinstance(content_response[0], dict):
text = content_response[0].get('text', content_response[0].get('content', str(content_response[0])))
full_response = str(text)
else:
full_response = str(content_response[0])
else:
full_response = ""
elif isinstance(content_response, dict):
full_response = str(content_response.get('text', content_response.get('content', str(content_response))))
else:
full_response = str(content_response) if content_response is not None else ""
# 4. 解析JSON响应
try:
# 尝试从响应中提取JSON
# 移除可能的markdown代码块标记
json_str = full_response.strip()
if json_str.startswith("```json"):
json_str = json_str[7:]
if json_str.startswith("```"):
json_str = json_str[3:]
if json_str.endswith("```"):
json_str = json_str[:-3]
json_str = json_str.strip()
result_data = json.loads(json_str)
title = result_data.get("title", "未知标题")
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
# 5. 校验和归一化类型
# 将类型转换为小写并去除空格
episodic_type_normalized = str(episodic_type_raw).lower().strip()
# 检查是否在有效类型集合中
if episodic_type_normalized in VALID_TYPES:
episodic_type = episodic_type_normalized
else:
# 尝试映射常见的中文类型到英文
type_mapping = {
"对话": "conversation",
"项目": "project_work",
"工作": "project_work",
"项目/工作": "project_work",
"学习": "learning",
"决策": "decision",
"重要事件": "important_event",
"事件": "important_event"
}
episodic_type = type_mapping.get(episodic_type_raw, DEFAULT_TYPE)
logger.warning(
f"LLM返回的类型 '{episodic_type_raw}' 不在有效集合中,"
f"已归一化为 '{episodic_type}'"
)
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
return (title, episodic_type)
except json.JSONDecodeError:
logger.error(f"无法解析LLM响应为JSON: {full_response}")
return ("解析失败", DEFAULT_TYPE)
except Exception as e:
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
return ("错误", DEFAULT_TYPE)
async def _process_chunk_summary(
dialog: DialogData,
chunk,
@@ -63,10 +176,9 @@ async def _process_chunk_summary(
title = None
episodic_type = None
try:
from app.services.user_memory_service import UserMemoryService
title, episodic_type = await UserMemoryService.generate_title_and_type_for_summary(
title, episodic_type = await generate_title_and_type_for_summary(
content=summary_text,
end_user_id=dialog.group_id
llm_client=llm_client
)
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
except Exception as e:

View File

@@ -8,14 +8,16 @@ Classes:
AccessHistoryManager: 访问历史管理器,提供并发安全的访问记录和一致性检查
"""
import asyncio
import logging
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
logger = logging.getLogger(__name__)
@@ -188,30 +190,43 @@ class AccessHistoryManager:
Returns:
List[Dict[str, Any]]: 成功更新的节点列表
"""
import time
batch_start = time.time()
if current_time is None:
current_time = datetime.now()
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
tasks = []
for node_id in node_ids:
task = self.record_access(
node_id=node_id,
node_label=node_label,
group_id=group_id,
current_time=current_time
)
tasks.append(task)
# Execute all tasks in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# Collect successful results and count failures
results = []
failed_count = 0
for node_id in node_ids:
try:
updated_node = await self.record_access(
node_id=node_id,
node_label=node_label,
group_id=group_id,
current_time=current_time
)
results.append(updated_node)
except Exception as e:
for node_id, result in zip(node_ids, task_results):
if isinstance(result, Exception):
failed_count += 1
logger.warning(
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(result)}"
)
else:
results.append(result)
batch_duration = time.time() - batch_start
logger.info(
f"批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
f"失败 {failed_count}"
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
)
return results
@@ -531,7 +546,10 @@ class AccessHistoryManager:
Dict[str, Any]: 更新数据,包含所有需要更新的字段
"""
access_history = node_data.get('access_history') or []
importance_score = node_data.get('importance_score', 0.5)
# Handle None importance_score - default to 0.5
importance_score = node_data.get('importance_score')
if importance_score is None:
importance_score = 0.5
# 追加新的访问时间
new_access_history = access_history + [current_time_iso]
@@ -620,34 +638,52 @@ class AccessHistoryManager:
new_version = current_version + 1
# 步骤2使用乐观锁更新节点
# 只有当版本号匹配时才更新
update_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
# 根据节点类型构建完整的查询语句
content_field_map = {
'Statement': 'n.statement as statement',
'MemorySummary': 'n.content as content',
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤
}
# 显式检查节点类型,不支持的类型抛出错误
if node_label not in content_field_map:
raise ValueError(
f"Unsupported node_label: {node_label}. "
f"Supported labels are: {list(content_field_map.keys())}"
)
content_field = content_field_map[node_label]
# 构建 WHERE 子句
where_conditions = []
if group_id:
update_query += " WHERE n.group_id = $group_id"
where_conditions.append("n.group_id = $group_id")
# 添加版本检查
if current_version > 0:
update_query += " AND n.version = $current_version"
where_conditions.append("n.version = $current_version")
else:
# 如果节点没有版本号,检查是否为首次更新
update_query += " AND (n.version IS NULL OR n.version = 0)"
where_conditions.append("(n.version IS NULL OR n.version = 0)")
update_query += """
where_clause = " AND ".join(where_conditions) if where_conditions else "true"
# 构建完整的更新查询
update_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
WHERE {where_clause}
SET n.activation_value = $activation_value,
n.access_history = $access_history,
n.last_access_time = $last_access_time,
n.access_count = $access_count,
n.version = $new_version
RETURN n.id as id,
n.statement as statement,
n.activation_value as activation_value,
n.access_history as access_history,
n.last_access_time as last_access_time,
n.access_count as access_count,
n.importance_score as importance_score,
n.version as version
n.version as version,
{content_field}
"""
update_params = {
@@ -671,7 +707,11 @@ class AccessHistoryManager:
f"Expected version {current_version}, but node was modified by another transaction."
)
return dict(updated_node)
# 转换为字典并移除占位符字段
result_dict = dict(updated_node)
result_dict.pop('content_placeholder', None)
return result_dict
# 执行事务
try:

View File

@@ -260,17 +260,32 @@ class ForgettingStrategy:
)
# 生成标题和类型使用LLM
from app.services.user_memory_service import UserMemoryService
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import generate_title_and_type_for_summary
# 获取 LLM 客户端
llm_client = None
if config_id is not None and db is not None:
try:
llm_client = await self._get_llm_client(db, config_id)
except Exception as e:
logger.warning(f"获取 LLM 客户端失败: {str(e)}")
# 生成标题和类型
try:
title, episodic_type = await UserMemoryService.generate_title_and_type_for_summary(
content=summary_text,
end_user_id=group_id
)
logger.info(f"成功为MemorySummary生成标题和类型: title={title}, type={episodic_type}")
if llm_client is not None:
title, episodic_type = await generate_title_and_type_for_summary(
content=summary_text,
llm_client=llm_client
)
logger.info(f"成功为MemorySummary生成标题和类型: title={title}, type={episodic_type}")
else:
logger.warning("LLM 客户端不可用,使用默认标题和类型")
title = "未命名"
episodic_type = "conversation"
except Exception as e:
logger.error(f"生成标题和类型失败,使用默认值: {str(e)}")
title = "未命名"
episodic_type = "其他"
episodic_type = "conversation"
# 计算继承的激活值和重要性(取较高值)
inherited_activation = max(statement_activation, entity_activation)

View File

@@ -3,13 +3,11 @@
基于 LangGraph 的工作流执行引擎。
"""
# import uuid
import datetime
import logging
import uuid
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.graph_builder import GraphBuilder
@@ -55,6 +53,12 @@ class WorkflowExecutor:
self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {})
self.checkpoint_config = {
"configurable": {
"thread_id": uuid.uuid4(),
}
}
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
"""准备初始状态(注入系统变量和会话变量)
@@ -95,7 +99,7 @@ class WorkflowExecutor:
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
conversation_vars[var_name] = []
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
conversation_vars = conversation_vars | input_data.get("conv", {})
# 构建分层的变量结构
variables = {
"sys": {
@@ -110,7 +114,7 @@ class WorkflowExecutor:
}
return {
"messages": [HumanMessage(content=user_message)],
"messages": [('user', user_message)],
"variables": variables,
"node_outputs": {},
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
@@ -196,6 +200,28 @@ class WorkflowExecutor:
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
def _build_final_output(self, result, elapsed_time):
node_outputs = result.get("node_outputs", {})
final_output = self._extract_final_output(node_outputs)
token_usage = self._aggregate_token_usage(node_outputs)
conversation_id = None
for node_id, node_output in node_outputs.items():
if node_output.get("node_type") == "start":
conversation_id = node_output.get("output", {}).get("conversation_id")
break
return {
"status": "completed",
"output": final_output,
"node_outputs": node_outputs,
"messages": result.get("messages", []),
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": result.get("error"),
"variables": result.get("variables", {}),
}
def build_graph(self, stream=False) -> CompiledStateGraph:
"""构建 LangGraph
@@ -236,40 +262,16 @@ class WorkflowExecutor:
# 3. 执行工作流
try:
result = await graph.ainvoke(initial_state)
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
# 提取节点输出(现在包含 start 和 end 节点)
node_outputs = result.get("node_outputs", {})
# 提取最终输出(从最后一个非 start/end 节点)
final_output = self._extract_final_output(node_outputs)
# 聚合 token 使用情况
token_usage = self._aggregate_token_usage(node_outputs)
# 提取 conversation_id从 start 节点输出)
conversation_id = None
for node_id, node_output in node_outputs.items():
if node_output.get("node_type") == "start":
conversation_id = node_output.get("output", {}).get("conversation_id")
break
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
return {
"status": "completed",
"output": final_output,
"node_outputs": node_outputs,
"messages": result.get("messages", []),
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": result.get("error")
}
return self._build_final_output(result, elapsed_time)
except Exception as e:
# 计算耗时(即使失败也记录)
@@ -331,11 +333,11 @@ class WorkflowExecutor:
# 3. Execute workflow
try:
chunk_count = 0
final_state = None
async for event in graph.astream(
initial_state,
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
config=self.checkpoint_config
):
# event should be a tuple: (mode, data)
# But let's handle both cases
@@ -411,12 +413,11 @@ class WorkflowExecutor:
elif mode == "updates":
# Handle state updates - store final state
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
final_state = data
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
result = graph.get_state(self.checkpoint_config).values
logger.info(
f"Workflow execution completed (streaming), "
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s"
@@ -425,12 +426,7 @@ class WorkflowExecutor:
# 发送 workflow_end 事件
yield {
"event": "workflow_end",
"data": {
"execution_id": self.execution_id,
"status": "completed",
"elapsed_time": elapsed_time,
"timestamp": end_time.isoformat()
}
"data": self._build_final_output(result, elapsed_time)
}
except Exception as e:

View File

@@ -4,6 +4,7 @@ from typing import Any
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.graph import START, END
from langgraph.checkpoint.memory import InMemorySaver
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
@@ -249,4 +250,5 @@ class GraphBuilder:
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges() # 添加边必须在添加节点之后
return self.graph.compile()
checkpointer = InMemorySaver()
return self.graph.compile(checkpointer=checkpointer)

View File

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = AssignerNodeConfig(**self.config)
self.typed_config: AssignerNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
"""
@@ -28,6 +28,7 @@ class AssignerNode(BaseNode):
None or the result of the assignment operation.
"""
# Initialize a variable pool for accessing conversation, node, and system variables
self.typed_config = AssignerNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} 开始执行")
pool = VariablePool(state)
for assignment in self.typed_config.assignments:

View File

@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
"""
# List of messages (append mode)
messages: Annotated[list[AnyMessage], add]
messages: Annotated[list[tuple[str, str]], add]
# Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list
@@ -203,6 +203,7 @@ class BaseNode(ABC):
# 返回包装后的输出和运行时变量
return {
**wrapped_output,
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var
},
@@ -355,6 +356,7 @@ class BaseNode(ABC):
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = {
**final_output,
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var
},

View File

@@ -30,7 +30,6 @@ class CycleGraphNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle

View File

@@ -32,7 +32,7 @@ class HttpRequestNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = HttpRequestNodeConfig(**self.config)
self.typed_config: HttpRequestNodeConfig | None = None
def _build_timeout(self) -> Timeout:
"""
@@ -181,6 +181,7 @@ class HttpRequestNode(BaseNode):
- dict: Serialized HttpRequestNodeOutput on success
- str: Branch identifier (e.g. "ERROR") when branching is enabled
"""
self.typed_config = HttpRequestNodeConfig(**self.config)
async with httpx.AsyncClient(
verify=self.typed_config.verify_ssl,
timeout=self._build_timeout(),

View File

@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = IfElseNodeConfig(**self.config)
self.typed_config: IfElseNodeConfig | None= None
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
@@ -109,6 +109,7 @@ class IfElseNode(BaseNode):
Returns:
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
"""
self.typed_config = IfElseNodeConfig(**self.config)
expressions = self.evaluate_conditional_edge_expressions(state)
# TODO: 变量类型及文本类型解析
for i in range(len(expressions)):

View File

@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = JinjaRenderNodeConfig(**self.config)
self.typed_config: JinjaRenderNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
"""
@@ -34,6 +34,7 @@ class JinjaRenderNode(BaseNode):
RuntimeError: If Jinja2 template rendering fails due to invalid template
syntax or missing variables.
"""
self.typed_config = JinjaRenderNodeConfig(**self.config)
render = TemplateRenderer(strict=False)
context = {}

View File

@@ -44,8 +44,8 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
description="Knowledge base config"
)
reranker_id: UUID = Field(
default="",
reranker_id: UUID | None = Field(
default=None,
description="Reranker top k"
)

View File

@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
@staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
@@ -171,6 +171,7 @@ class KnowledgeRetrievalNode(BaseNode):
Raises:
RuntimeError: If no valid knowledge base is found or access is denied.
"""
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
query = self._render_template(self.typed_config.query, state)
with get_db_read() as db:
knowledge_bases = self.typed_config.knowledge_bases

View File

@@ -68,7 +68,7 @@ class LLMNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = LLMNodeConfig(**self.config)
self.typed_config: LLMNodeConfig | None = None
def _render_context(self, message, state):
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
@@ -164,6 +164,7 @@ class LLMNode(BaseNode):
Returns:
LLM 响应消息
"""
self.typed_config = LLMNodeConfig(**self.config)
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")

View File

@@ -10,9 +10,10 @@ from app.services.memory_agent_service import MemoryAgentService
class MemoryReadNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = MemoryReadNodeConfig(**self.config)
self.typed_config: MemoryReadNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
self.typed_config = MemoryReadNodeConfig(**self.config)
with get_db_read() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)

View File

@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = ParameterExtractorNodeConfig(**self.config)
self.typed_config: ParameterExtractorNodeConfig | None = None
@staticmethod
def _get_prompt():
@@ -145,6 +145,7 @@ class ParameterExtractorNode(BaseNode):
Raises:
BusinessException: If LLM output cannot be parsed as valid JSON.
"""
self.typed_config = ParameterExtractorNodeConfig(**self.config)
llm = self._get_llm_instance()
system_prompt, user_prompt = self._get_prompt()

View File

@@ -21,8 +21,8 @@ class QuestionClassifierNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = QuestionClassifierNodeConfig(**self.config)
self.category_to_case_map = self._build_category_case_map()
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例"""
@@ -67,6 +67,8 @@ class QuestionClassifierNode(BaseNode):
async def execute(self, state: WorkflowState) -> dict:
"""执行问题分类"""
self.typed_config = QuestionClassifierNodeConfig(**self.config)
self.category_to_case_map = self._build_category_case_map()
question = self.typed_config.input_variable
supplement_prompt = self.typed_config.user_supplement_prompt or ""
categories = self.typed_config.categories or []

View File

@@ -7,6 +7,7 @@ Start 节点实现
import logging
from typing import Any
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.start.config import StartNodeConfig
@@ -34,7 +35,7 @@ class StartNode(BaseNode):
super().__init__(node_config, workflow_config)
# 解析并验证配置
self.typed_config = StartNodeConfig(**self.config)
self.typed_config: StartNodeConfig | None = None
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行 start 节点业务逻辑
@@ -47,6 +48,7 @@ class StartNode(BaseNode):
Returns:
包含系统参数、会话变量和自定义变量的字典
"""
self.typed_config = StartNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 创建变量池实例(在方法内复用)
@@ -113,6 +115,18 @@ class StartNode(BaseNode):
logger.debug(
f"变量 '{var_name}' 使用默认值: {var_def.default}"
)
else:
match var_def.type:
case VariableType.STRING:
processed[var_name] = ""
case VariableType.NUMBER:
processed[var_name] = 0
case VariableType.OBJECT:
processed[var_name] = {}
case VariableType.BOOLEAN:
processed[var_name] = False
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
processed[var_name] = []
return processed

View File

@@ -19,10 +19,11 @@ class ToolNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = ToolNodeConfig(**self.config)
self.typed_config: ToolNodeConfig | None = None
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行工具"""
self.typed_config = ToolNodeConfig(**self.config)
# 获取租户ID和用户ID
tenant_id = self.get_variable("sys.tenant_id", state)
user_id = self.get_variable("sys.user_id", state)

View File

@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class VariableAggregatorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = VariableAggregatorNodeConfig(**self.config)
self.typed_config: VariableAggregatorNodeConfig | None = None
@staticmethod
def _get_express(variable_string: str) -> Any:
@@ -37,6 +37,7 @@ class VariableAggregatorNode(BaseNode):
- str: In non-group mode, returns the first non-None variable value.
- dict: In group mode, returns a mapping of group_name -> first non-None variable value.
"""
self.typed_config = VariableAggregatorNodeConfig(**self.config)
if not self.typed_config.group:
# --------------------------
# Non-group mode

View File

@@ -66,24 +66,38 @@ async def _update_activation_values_batch(
max_retries=max_retries
)
# 提取节点ID列表
node_ids = [node.get('id') for node in nodes if node.get('id')]
# 提取节点ID列表并去重(保持原始顺序)
seen_ids = set()
unique_node_ids = []
for node in nodes:
node_id = node.get('id')
if node_id and node_id not in seen_ids:
seen_ids.add(node_id)
unique_node_ids.append(node_id)
if not node_ids:
if not unique_node_ids:
logger.warning(f"批量更新激活值没有有效的节点ID")
return nodes
# 记录去重信息(仅针对具有有效 ID 的节点)
id_nodes_count = sum(1 for n in nodes if n.get("id"))
if len(unique_node_ids) < id_nodes_count:
logger.info(
f"批量更新激活值检测到重复节点具有有效ID的节点数量={id_nodes_count}, "
f"去重后唯一ID数量={len(unique_node_ids)}"
)
# 批量记录访问
try:
updated_nodes = await access_manager.record_batch_access(
node_ids=node_ids,
node_ids=unique_node_ids,
node_label=node_label,
group_id=group_id
)
logger.info(
f"批量更新激活值成功: {node_label}, "
f"更新数量={len(updated_nodes)}/{len(node_ids)}"
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
)
return updated_nodes
@@ -153,19 +167,38 @@ async def _update_search_results_activation(
original_nodes = results[key]
updated_nodes = update_result
# 创建 ID 到原始节点的映射(用于快速查找 score
original_map = {node.get('id'): node for node in original_nodes if node.get('id')}
# 创建 ID 到更新节点的映射(用于快速查找激活值数据
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
# 合并数据:激活值来自更新结果score 来自原始结果
# 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充
merged_nodes = []
for updated_node in updated_nodes:
node_id = updated_node.get('id')
if node_id and node_id in original_map:
# 保留原始的 score 字段
original_score = original_map[node_id].get('score')
if original_score is not None:
updated_node['score'] = original_score
merged_nodes.append(updated_node)
for original_node in original_nodes:
node_id = original_node.get('id')
if node_id and node_id in updated_map:
# 从原始节点开始,用更新后的激活值数据覆盖
merged_node = original_node.copy()
# 更新激活值相关字段
activation_fields = {
'activation_value',
'access_history',
'last_access_time',
'access_count',
'importance_score',
'version',
'statement', # Statement 节点的内容字段
'content' # MemorySummary 节点的内容字段
}
# 只更新激活值相关字段,保留原始节点的其他字段
for field in activation_fields:
if field in updated_map[node_id]:
merged_node[field] = updated_map[node_id][field]
merged_nodes.append(merged_node)
else:
# 如果没有更新数据,保留原始节点
merged_nodes.append(original_node)
updated_results[key] = merged_nodes
else:

View File

@@ -15,6 +15,7 @@ from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
@@ -22,6 +23,7 @@ from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
@@ -101,6 +103,14 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
user_rag_memory_id=user_rag_memory_id
)
)
task = celery_app.send_task(
"app.core.memory.agent.read_message",
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
)
result = task_service.get_task_memory_read_result(task.id)
status = result.get("status")
logger.info(f"读取任务状态:{status}")
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}')

View File

@@ -456,23 +456,36 @@ class MemoryAgentService:
client = MultiServerMCPClient(mcp_config)
async with client.session('data_flow') as session:
session_start = time.time()
logger.debug("Connected to MCP Server: data_flow")
tools_start = time.time()
tools = await load_mcp_tools(session)
tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
outputs = []
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
# Pass memory_config to the graph workflow
graph_start = time.time()
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
start = time.time()
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
event_count = 0
async for event in graph.astream(
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
event_count += 1
event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
@@ -525,9 +538,15 @@ class MemoryAgentService:
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
event_time = time.time() - event_start
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
workflow_duration = time.time() - start
logger.info(f"Read graph workflow completed in {workflow_duration}s")
session_duration = time.time() - session_start
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
logger.info(f"[PERF] Total events processed: {event_count}")
# Extract final answer
final_answer = ""
for messages in outputs:
@@ -1186,8 +1205,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
ValueError: 当终端用户不存在或应用未发布时
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.data_config_model import DataConfig
from app.models.end_user_model import EndUser
from sqlalchemy import select
logger.info(f"Getting connected config for end_user: {end_user_id}")
@@ -1266,8 +1285,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
对于查询失败的用户value 包含 error 字段
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.data_config_model import DataConfig
from app.models.end_user_model import EndUser
from sqlalchemy import select
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")

View File

@@ -9,6 +9,7 @@ from typing import Optional
from app.core.logging_config import get_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.emotion_analytics_service import EmotionAnalyticsService
logger = get_logger(__name__)
@@ -109,3 +110,188 @@ class MemoryBaseService:
except Exception as e:
logger.error(f"提取情景记忆情绪时出错: {str(e)}", exc_info=True)
return None
async def get_episodic_memory_count(
self,
end_user_id: Optional[str] = None
) -> int:
"""
获取情景记忆数量
查询 MemorySummary 节点的数量。
Args:
end_user_id: 可选的终端用户ID用于过滤特定用户的节点
Returns:
情景记忆的数量
"""
try:
if end_user_id:
query = """
MATCH (n:MemorySummary)
WHERE n.group_id = $group_id
RETURN count(n) as count
"""
result = await self.neo4j_connector.execute_query(query, group_id=end_user_id)
else:
query = """
MATCH (n:MemorySummary)
RETURN count(n) as count
"""
result = await self.neo4j_connector.execute_query(query)
count = result[0]["count"] if result and len(result) > 0 else 0
logger.debug(f"情景记忆数量: {count} (end_user_id={end_user_id})")
return count
except Exception as e:
logger.error(f"获取情景记忆数量时出错: {str(e)}", exc_info=True)
return 0
async def get_explicit_memory_count(
self,
end_user_id: Optional[str] = None
) -> int:
"""
获取显性记忆数量
显性记忆 = 情景记忆MemorySummary+ 语义记忆ExtractedEntity with is_explicit_memory=true
Args:
end_user_id: 可选的终端用户ID用于过滤特定用户的节点
Returns:
显性记忆的数量
"""
try:
# 1. 获取情景记忆数量
episodic_count = await self.get_episodic_memory_count(end_user_id)
# 2. 获取语义记忆数量ExtractedEntity 且 is_explicit_memory = true
if end_user_id:
semantic_query = """
MATCH (e:ExtractedEntity)
WHERE e.group_id = $group_id AND e.is_explicit_memory = true
RETURN count(e) as count
"""
semantic_result = await self.neo4j_connector.execute_query(
semantic_query,
group_id=end_user_id
)
else:
semantic_query = """
MATCH (e:ExtractedEntity)
WHERE e.is_explicit_memory = true
RETURN count(e) as count
"""
semantic_result = await self.neo4j_connector.execute_query(semantic_query)
semantic_count = semantic_result[0]["count"] if semantic_result and len(semantic_result) > 0 else 0
# 3. 计算总数
explicit_count = episodic_count + semantic_count
logger.debug(
f"显性记忆数量: {explicit_count} "
f"(情景={episodic_count}, 语义={semantic_count}, end_user_id={end_user_id})"
)
return explicit_count
except Exception as e:
logger.error(f"获取显性记忆数量时出错: {str(e)}", exc_info=True)
return 0
async def get_emotional_memory_count(
self,
end_user_id: Optional[str] = None,
statement_count_fallback: int = 0
) -> int:
"""
获取情绪记忆数量
通过 EmotionAnalyticsService 获取情绪标签统计总数。
如果获取失败或没有指定 end_user_id使用 statement_count_fallback 作为后备。
Args:
end_user_id: 可选的终端用户ID
statement_count_fallback: 后备方案的数量(通常是 statement 节点数量)
Returns:
情绪记忆的数量
"""
try:
if end_user_id:
emotion_service = EmotionAnalyticsService()
emotion_data = await emotion_service.get_emotion_tags(
end_user_id=end_user_id,
emotion_type=None,
start_date=None,
end_date=None,
limit=10
)
emotion_count = emotion_data.get("total_count", 0)
logger.debug(f"情绪记忆数量: {emotion_count} (end_user_id={end_user_id})")
return emotion_count
else:
# 如果没有指定 end_user_id使用后备方案
logger.debug(f"情绪记忆数量: {statement_count_fallback} (使用后备方案)")
return statement_count_fallback
except Exception as e:
logger.warning(f"获取情绪记忆数量失败,使用后备方案: {str(e)}")
return statement_count_fallback
async def get_forget_memory_count(
self,
end_user_id: Optional[str] = None,
forgetting_threshold: float = 0.3
) -> int:
"""
获取遗忘记忆数量
统计激活值低于遗忘阈值的节点数量low_activation_nodes
查询范围包括Statement、ExtractedEntity、MemorySummary、Chunk 节点。
Args:
end_user_id: 可选的终端用户ID用于过滤特定用户的节点
forgetting_threshold: 遗忘阈值,默认 0.3
Returns:
遗忘记忆的数量(激活值低于阈值的节点数)
"""
try:
# 构建查询语句
query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
"""
if end_user_id:
query += " AND n.group_id = $group_id"
query += """
RETURN sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
"""
# 设置查询参数
params = {'threshold': forgetting_threshold}
if end_user_id:
params['group_id'] = end_user_id
# 执行查询
result = await self.neo4j_connector.execute_query(query, **params)
# 提取结果
forget_count = result[0]['low_activation_nodes'] if result and len(result) > 0 else 0
forget_count = forget_count or 0 # 处理 None 值
logger.debug(
f"遗忘记忆数量: {forget_count} "
f"(threshold={forgetting_threshold}, end_user_id={end_user_id})"
)
return forget_count
except Exception as e:
logger.error(f"获取遗忘记忆数量时出错: {str(e)}", exc_info=True)
return 0

View File

@@ -401,5 +401,5 @@ class MemoryEpisodicService(MemoryBaseService):
raise
# 创建全局服务实例
# 创建全局服务实例(供控制器层使用)
memory_episodic_service = MemoryEpisodicService()

View File

@@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_base_service import MemoryBaseService
from app.services.memory_config_service import MemoryConfigService
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
@@ -1195,17 +1196,18 @@ async def analytics_memory_types(
end_user_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
统计8种记忆类型的数量和百分比
统计9种记忆类型的数量和百分比
计算规则:
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity
2. 工作记忆 (WORKING_MEMORY) = chunk + entity
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk
4. 长期记忆 (LONG_TERM_MEMORY) = entity
5. 显性记忆 (EXPLICIT_MEMORY) = 1/2 * entity
5. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity
7. 情绪记忆 (EMOTIONAL_MEMORY) = statement
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary
7. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
Args:
db: 数据库会话
@@ -1230,13 +1232,16 @@ async def analytics_memory_types(
- IMPLICIT_MEMORY: 隐性记忆
- EMOTIONAL_MEMORY: 情绪记忆
- EPISODIC_MEMORY: 情景记忆
- FORGET_MEMORY: 遗忘记忆
"""
# 定义需要查询的节点类型
# 初始化基础服务
base_service = MemoryBaseService()
# 定义需要查询的基础节点类型
node_types = {
"Statement": "Statement",
"Entity": "ExtractedEntity",
"Chunk": "Chunk",
"MemorySummary": "MemorySummary"
"Chunk": "Chunk"
}
# 存储每种节点类型的计数
@@ -1266,18 +1271,45 @@ async def analytics_memory_types(
statement_count = node_counts.get("Statement", 0)
entity_count = node_counts.get("Entity", 0)
chunk_count = node_counts.get("Chunk", 0)
memory_summary_count = node_counts.get("MemorySummary", 0)
# 按规则计算8种记忆类型的数量使用英文枚举作为key
# 获取用户的遗忘阈值配置
forgetting_threshold = 0.3 # 默认值
if end_user_id:
try:
from app.services.memory_agent_service import get_end_user_connected_config
from app.core.memory.storage_services.forgetting_engine.config_utils import load_actr_config_from_db
# 获取用户关联的 config_id
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get('memory_config_id')
if config_id:
# 从数据库加载配置
config = load_actr_config_from_db(db, config_id)
forgetting_threshold = config.get('forgetting_threshold', 0.3)
logger.debug(f"使用用户配置的遗忘阈值: {forgetting_threshold} (end_user_id={end_user_id}, config_id={config_id})")
else:
logger.debug(f"用户未关联配置,使用默认遗忘阈值: {forgetting_threshold} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取用户遗忘阈值配置失败,使用默认值 {forgetting_threshold}: {str(e)}")
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
episodic_count = await base_service.get_episodic_memory_count(end_user_id)
explicit_count = await base_service.get_explicit_memory_count(end_user_id)
emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count)
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
# 按规则计算9种记忆类型的数量使用英文枚举作为key
memory_counts = {
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆
"LONG_TERM_MEMORY": entity_count, # 长期记忆
"EXPLICIT_MEMORY": entity_count // 2, # 显性记忆 (1/2 entity)
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity)
"EMOTIONAL_MEMORY": statement_count, # 情绪记忆
"EPISODIC_MEMORY": memory_summary_count # 情景记忆
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
"EPISODIC_MEMORY": episodic_count, # 情景记忆
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
}
# 计算总数

View File

@@ -491,6 +491,17 @@ class WorkflowService:
)
end_user_id = str(new_end_user.id)
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
for exec_res in executions:
if exec_res.status == "completed":
last_state = exec_res.output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
break
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
@@ -504,7 +515,7 @@ class WorkflowService:
self.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {})
output_data=result
)
else:
self.update_execution_status(
@@ -517,6 +528,7 @@ class WorkflowService:
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"variables": result.get("variables"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
@@ -617,6 +629,16 @@ class WorkflowService:
original_user_id=payload.user_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
for exec_res in executions:
if exec_res.status == "completed":
last_state = exec_res.output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
break
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件)
async for event in self._run_workflow_stream(
@@ -827,6 +849,23 @@ class WorkflowService:
user_id=user_id
):
# 直接转发事件executor 已经返回正确格式)
if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status")
if status == "completed":
self.update_execution_status(
execution_id,
"completed",
output_data=event.get("data")
)
elif status == "failed":
self.update_execution_status(
execution_id,
"failed",
output_data=event.get("data")
)
else:
logger.error(f"unexpect workflow run status, status: {status}")
yield event
except Exception as e:

View File

@@ -117,26 +117,26 @@ export const getRagContent = (end_user_id: string) => {
}
// 情感分布分析
export const getWordCloud = (group_id: string) => {
return request.post(`/memory/emotion/wordcloud`, { group_id, limit: 20 })
return request.post(`/memory/emotion-memory/wordcloud`, { group_id, limit: 20 })
}
// 高频情绪关键词
export const getEmotionTags = (group_id: string) => {
return request.post(`/memory/emotion/tags`, { group_id, limit: 20 })
return request.post(`/memory/emotion-memory/tags`, { group_id, limit: 20 })
}
// 情绪健康指数
export const getEmotionHealth = (group_id: string) => {
return request.post(`/memory/emotion/health`, { group_id, limit: 20 })
return request.post(`/memory/emotion-memory/health`, { group_id, limit: 20 })
}
// 个性化建议
export const getEmotionSuggestions = (group_id: string) => {
return request.post(`/memory/emotion/suggestions`, { group_id, limit: 20 })
return request.post(`/memory/emotion-memory/suggestions`, { group_id, limit: 20 })
}
export const analyticsRefresh = (end_user_id: string) => {
return request.post('/memory-storage/analytics/generate_cache', { end_user_id })
}
// 遗忘
export const getForgetStats = (group_id: string) => {
return request.get(`/memory/forget/stats`, { group_id })
return request.get(`/memory/forget-memory/stats`, { group_id })
}
// 隐性记忆-偏好
export const getImplicitPreferences = (end_user_id: string) => {
@@ -176,10 +176,10 @@ export const getPerceptualTimeline = (end_user: string) => {
}
// 情景记忆-总览
export const getEpisodicOverview = (data: { end_user_id: string; time_range: string; episodic_type: string; } ) => {
return request.post(`/memory-storage/classifications/episodic-memory`, data)
return request.post(`/memory/episodic-memory/overview`, data)
}
export const getEpisodicDetail = (data: { end_user_id: string; summary_id: string; } ) => {
return request.post(`/memory-storage/classifications/episodic-memory-details`, data)
return request.post(`/memory/episodic-memory/details`, data)
}
// 关系演化
export const getRelationshipEvolution = (data: { id: string; label: string; } ) => {
@@ -190,10 +190,10 @@ export const getTimelineMemories = (data: { id: string; label: string; }) => {
return request.get(`/memory-storage/memory_space/timeline_memories`, data)
}
export const getExplicitMemory = (end_user_id: string) => {
return request.post(`/memory-storage/classifications/explicit-memory`, { end_user_id })
return request.post(`/memory/explicit-memory/overview`, { end_user_id })
}
export const getExplicitMemoryDetails = (data: { end_user_id: string, memory_id: string; }) => {
return request.post(`/memory-storage/classifications/explicit-memory-details`, data)
return request.post(`/memory/explicit-memory/details`, data)
}
export const getConversations = (end_user: string) => {
return request.get(`/memory/work/${end_user}/conversations`)
@@ -205,7 +205,7 @@ export const getConversationDetail = (end_user: string, conversation_id: string)
return request.get(`/memory/work/${end_user}/detail`, { conversation_id })
}
export const forgetTrigger = (data: { max_merge_batch_size: number; min_days_since_access: number; end_user_id: string;}) => {
return request.post(`/memory/forget/trigger`, data)
return request.post(`/memory/forget-memory/trigger`, data)
}
/*************** end 用户记忆 相关接口 ******************************/
@@ -229,11 +229,11 @@ export const deleteMemoryConfig = (config_id: number) => {
}
// 遗忘引擎-获取配置
export const getMemoryForgetConfig = (config_id: number | string) => {
return request.get('/memory/forget/read_config', { config_id })
return request.get('/memory/forget-memory/read_config', { config_id })
}
// 遗忘引擎-更新配置
export const updateMemoryForgetConfig = (values: ForgetConfigForm) => {
return request.post('/memory/forget/update_config', values)
return request.post('/memory/forget-memory/update_config', values)
}
// 记忆萃取引擎-获取配置
export const getMemoryExtractionConfig = (config_id: number | string) => {

View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>模型 (1)</title>
<g id="v0.2.0" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="红熊空间-记忆管理" transform="translate(-24, -409)" stroke="#5B6167">
<g id="记忆对话备份-2" transform="translate(12, 401)">
<g id="模型-(1)" transform="translate(12, 8)">
<g id="编组-21" transform="translate(1.5, 1.5)">
<path d="M7,0.288675135 L11.6291651,2.96132487 C11.9385662,3.13995766 12.1291651,3.47008468 12.1291651,3.82735027 L12.1291651,9.17264973 C12.1291651,9.52991532 11.9385662,9.86004234 11.6291651,10.0386751 L7,12.7113249 C6.69059892,12.8899577 6.30940108,12.8899577 6,12.7113249 L1.37083488,10.0386751 C1.0614338,9.86004234 0.870834875,9.52991532 0.870834875,9.17264973 L0.870834875,3.82735027 C0.870834875,3.47008468 1.0614338,3.13995766 1.37083488,2.96132487 L6,0.288675135 C6.30940108,0.11004234 6.69059892,0.11004234 7,0.288675135 Z" id="多边形"></path>
<polyline id="路径-15" points="0.931223827 3.37218958 6.5 6.5 6.5 12.8581283"></polyline>
<line x1="6.5" y1="6.49748419" x2="12.0714286" y2="3.37218958" id="路径-16"></line>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.5 KiB

View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>模型 (1)</title>
<g id="v0.2.0" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="红熊空间-记忆管理" transform="translate(-24, -409)" stroke="#212332">
<g id="记忆对话备份-2" transform="translate(12, 401)">
<g id="模型-(1)" transform="translate(12, 8)">
<g id="编组-21" transform="translate(1.5, 1.5)">
<path d="M7,0.288675135 L11.6291651,2.96132487 C11.9385662,3.13995766 12.1291651,3.47008468 12.1291651,3.82735027 L12.1291651,9.17264973 C12.1291651,9.52991532 11.9385662,9.86004234 11.6291651,10.0386751 L7,12.7113249 C6.69059892,12.8899577 6.30940108,12.8899577 6,12.7113249 L1.37083488,10.0386751 C1.0614338,9.86004234 0.870834875,9.52991532 0.870834875,9.17264973 L0.870834875,3.82735027 C0.870834875,3.47008468 1.0614338,3.13995766 1.37083488,2.96132487 L6,0.288675135 C6.30940108,0.11004234 6.69059892,0.11004234 7,0.288675135 Z" id="多边形"></path>
<polyline id="路径-15" points="0.931223827 3.37218958 6.5 6.5 6.5 12.8581283"></polyline>
<line x1="6.5" y1="6.49748419" x2="12.0714286" y2="3.37218958" id="路径-16"></line>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.5 KiB

View File

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="28px" height="28px" viewBox="0 0 28 28" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编组 13备份</title>
<g id="V1.1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="红熊空间-记忆管理" transform="translate(-947, -144)">
<g id="1备份-2" transform="translate(651, 128)">
<g id="编组-13备份" transform="translate(296, 16)">
<rect id="矩形" stroke="#DFE4ED" x="0.5" y="0.5" width="27" height="27" rx="6"></rect>
<g id="进入@2x" transform="translate(5.8333, 5.8333)">
<g id="编组-11" transform="translate(2.0417, 2.5521)">
<path d="M5.42385066,3.34516089 L8.15899029,5.47250014 C8.23746067,5.5335329 8.25159666,5.64662254 8.1905639,5.72509292 C8.1813906,5.73688711 8.17078448,5.74749323 8.15899029,5.75666652 L5.42385066,7.88400578 C5.34538028,7.94503854 5.23229064,7.93090256 5.17125788,7.85243218 C5.14668314,7.82083621 5.13334107,7.78195037 5.13334107,7.74192259 L5.13334107,6.2384308 L5.13334107,6.2384308 L0,6.2384308 L0,4.99073587 L5.13334107,4.99073587 L5.13334107,3.48724407 C5.13334107,3.38783282 5.21392981,3.30724407 5.31334107,3.30724407 C5.35336884,3.30724407 5.39225469,3.32058615 5.42385066,3.34516089 Z" id="路径" fill="#5B6167" fill-rule="nonzero"></path>
<path d="M1.60417096,2.83745334 L1.60417096,0.9 C1.60417096,0.402943725 2.00711469,0 2.50417096,-1.11022302e-16 L10.3291667,-1.11022302e-16 C10.8262229,-2.22044605e-16 11.2291667,0.402943725 11.2291667,0.9 L11.2291667,10.3291667 C11.2291667,10.8262229 10.8262229,11.2291667 10.3291667,11.2291667 L2.50417096,11.2291667 C2.00711469,11.2291667 1.60417096,10.8262229 1.60417096,10.3291667 L1.60417096,8.46506778 L1.60417096,8.46506778" id="路径" stroke="#5B6167" stroke-width="1.1" stroke-linejoin="round"></path>
</g>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.1 KiB

View File

@@ -1,4 +1,4 @@
import { useEffect, useState, useCallback, useRef, type FC, type Key } from 'react';
import { useEffect, useState, type FC, type Key } from 'react';
import { Select } from 'antd'
import type { SelectProps, DefaultOptionType } from 'antd/es/select'
import { useTranslation } from 'react-i18next';
@@ -26,7 +26,7 @@ interface CustomSelectProps extends Omit<SelectProps, 'filterOption'> {
disabled?: boolean;
style?: React.CSSProperties;
className?: string;
filterOption?: (inputValue: string, option: DefaultOptionType) => boolean;
filterOption?: (inputValue: string, option?: DefaultOptionType) => boolean;
}
interface OptionType {
[key: string]: Key | string | number;
@@ -48,44 +48,27 @@ const CustomSelect: FC<CustomSelectProps> = ({
}) => {
const { t } = useTranslation();
const [options, setOptions] = useState<OptionType[]>([]);
// 创建防抖定时器引用
const debounceRef = useRef<number>();
// 防抖搜索函数
const handleSearch = useCallback((value?: string) => {
// 清除之前的定时器
if (debounceRef.current) {
clearTimeout(debounceRef.current);
}
// 设置新的定时器
debounceRef.current = window.setTimeout(() => {
request.get<ApiResponse<OptionType>>(url, {...params, [optionFilterProp]: value}).then((res) => {
const data = res;
setOptions(Array.isArray(data) ? data || [] : Array.isArray(data?.items) ? data.items || [] : []);
});
}, 300); // 300毫秒防抖延迟
}, [url, params, optionFilterProp]);
// 默认模糊搜索函数
const defaultFilterOption = (inputValue: string, option?: DefaultOptionType) => {
if (!option || !inputValue) return true;
const label = String(option.children || option.label || '');
return label.toLowerCase().includes(inputValue.toLowerCase());
};
// 组件挂载时获取初始数据
useEffect(() => {
handleSearch();
// 组件卸载时清除定时器
return () => {
if (debounceRef.current) {
clearTimeout(debounceRef.current);
}
};
}, [url, handleSearch]);
request.get<ApiResponse<OptionType>>(url, params).then((res) => {
const data = res;
setOptions(Array.isArray(data) ? data || [] : Array.isArray(data?.items) ? data.items || [] : []);
});
}, []);
return (
<Select
placeholder={placeholder ? placeholder : t('common.select')}
onChange={onChange}
defaultValue={hasAll ? null : undefined}
showSearch={showSearch}
onSearch={handleSearch}
filterOption={filterOption || false} // 禁用本地过滤,使用服务器端过滤
filterOption={filterOption || defaultFilterOption}
{...props}
>
{hasAll && (<Select.Option>{allTitle || t('common.all')}</Select.Option>)}

View File

@@ -40,6 +40,8 @@ import apiKeyIcon from '@/assets/images/menu/apiKey.png';
import apiKeyActiveIcon from '@/assets/images/menu/apiKey_active.png';
import pricingIcon from '@/assets/images/menu/pricing.svg'
import pricingActiveIcon from '@/assets/images/menu/pricing_active.svg'
import spaceConfigIcon from '@/assets/images/menu/spaceConfig.svg'
import spaceConfigActiveIcon from '@/assets/images/menu/spaceConfig_active.svg'
// 图标路径映射表
const iconPathMap: Record<string, string> = {
@@ -68,7 +70,9 @@ const iconPathMap: Record<string, string> = {
'apiKey': apiKeyIcon,
'apiKeyActive': apiKeyActiveIcon,
'pricing': pricingIcon,
'pricingActive': pricingActiveIcon
'pricingActive': pricingActiveIcon,
'spaceConfig': spaceConfigIcon,
'spaceConfigActive': spaceConfigActiveIcon,
};
const { Sider } = Layout;

View File

@@ -87,7 +87,7 @@ export const en = {
modelManagement: 'Model Management',
memoryStore: 'Memory Store',
apiParameters: 'API Parameters',
userMemory: 'User Memory',
userMemory: 'Memory Store',
memberManagement: 'Member Management',
memorySummary: 'Memory Summary',
memoryConversation: 'Memory Validation',
@@ -110,6 +110,7 @@ export const en = {
pricing: 'Pricing Management',
orderPayment: 'Order Payment',
orderHistory: 'Order History',
spaceConfig: 'Space Configuration'
},
dashboard: {
total_models: 'Total number of available models',
@@ -1232,6 +1233,8 @@ export const en = {
hire_date: 'Hire Date',
memoryContent: 'Memory Content',
created_at: 'Created At',
updated_at: 'Updated At',
fullScreen: 'Full Screen',
memoryWindow: "{{name}}'s Window of Memory",
memory_insight: 'Overall Overview',
@@ -1258,7 +1261,7 @@ export const en = {
unix: 'items',
completeMemory: 'Complete Memory',
relationshipEvolution: 'Relationship Evolution',
timelineMemories: 'Shared Memory Timeline',
timelineMemories: 'Long-term Memory',
emotionLine: 'Emotion Changes Over Time',
interaction: 'Interaction Frequency & Relationship Stages',
timelines_memory: 'All',
@@ -1269,6 +1272,12 @@ export const en = {
negative: 'Negative Emotion',
neutral: 'Neutral Emotion',
interactionCountData: 'Interaction Count',
capacity: 'Capacity',
type: 'Type',
person: 'Personal',
memoryNum: 'memories',
memory_config_name: 'Memory Engine',
searchPlaceholder: 'Search memory store name',
},
space: {
createSpace: 'Create Space',
@@ -1284,7 +1293,8 @@ export const en = {
neo4jDesc: 'Based on knowledge graph, suitable for relational reasoning and path query',
llmModel: 'LLM Model',
embeddingModel: 'Embedding Model',
rerankModel: 'Rerank Model'
rerankModel: 'Rerank Model',
configAlert: 'Space model configuration ensures that the space can correctly call the corresponding models to process business data during runtime.',
},
memoryExtractionEngine: {
title: 'Memory Engine Module Configuration Center',
@@ -1459,6 +1469,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
quickReply: 'Quick Reply',
web_search: 'Online search',
memory: 'Memory',
memoryConversationAnalysisEmpty: 'There is currently no dialogue analysis content available',
memoryConversationAnalysisEmptySubTitle: 'After entering your user ID, click on "Test Memory" to view the conversation memory',
},
login: {
title: 'Red Bear Memory Science',
@@ -1613,19 +1625,17 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
JsonTool_desc: 'Data Format Conversion',
JsonTool_features: 'JSON formatting, compression, validation and conversion functions',
jsonFormat: 'JSON Formatting',
jsonGzip: 'JSON Compression',
jsonCheck: 'JSON Validation',
jsonConversion: 'Format Conversion',
jsonParse: 'JSON Parse',
jsonInsert: 'JSON Insert',
jsonReplace: 'JSON Validation',
jsonDelete: 'JSON Delete',
jsonEg: 'Example JSON',
enterJson: 'Enter JSON',
jsonPlaceholder: 'Enter JSON data, e.g.: {"name": "test", "value": 123}',
clear: 'Clear',
parse: 'Paste',
format: 'Format',
minify: 'Minify',
validate: 'Validate',
convert: 'Escape',
paste: 'Paste',
parse: 'Parse',
json_path: 'JSON Path Parameters',
outputResult: 'Output Result',
validJosn: 'JSON format is correct, validation passed!',
@@ -1944,7 +1954,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
variableConfig: 'Variable Configuration',
variableRequired: 'Required',
addMessage: 'Add Message',
answerDesc: 'Reply'
answerDesc: 'Reply',
addNode: 'Add Node',
},
emotionEngine: {
emotionEngineConfig: 'Emotion Engine Configuration',

View File

@@ -87,7 +87,7 @@ export const zh = {
modelManagement: '模型管理',
memoryStore: '记忆存储',
apiParameters: 'API参数',
userMemory: '用户记忆',
userMemory: '记忆',
memberManagement: '成员管理',
memorySummary: '记忆摘要',
memoryConversation: '记忆验证',
@@ -102,7 +102,7 @@ export const zh = {
knowledgeShare: '详情',
knowledgeCreateDataset: '新建数据集',
knowledgeDocumentDetails: '详情',
userMemoryDetail: '用户记忆详情',
userMemoryDetail: '记忆详情',
toolManagement: '工具管理',
emotionEngine: '情感引擎',
statementDetail: '情绪记忆',
@@ -110,6 +110,7 @@ export const zh = {
pricing: '收费管理',
orderPayment: '订单支付',
orderHistory: '订单记录',
spaceConfig: '空间配置'
},
knowledgeBase: {
home: '首页',
@@ -1313,7 +1314,7 @@ export const zh = {
updated_at: '最后更新时间',
fullScreen: '全屏',
memoryWindow: "{{name}}的记忆之窗",
memoryWindow: "{{name}} 的记忆之窗",
memory_insight: '总体概述',
key_findings: '关键发现',
behavior_pattern: '行为模式',
@@ -1338,7 +1339,7 @@ export const zh = {
unix: '个',
completeMemory: '完整记忆',
relationshipEvolution: '关系演化',
timelineMemories: '共同记忆时间线',
timelineMemories: '长期记忆',
emotionLine: '情绪随时间变化',
interaction: '互动频率 & 关系阶段',
timelines_memory: '全部',
@@ -1349,6 +1350,12 @@ export const zh = {
negative: '负向情绪',
neutral: '中性情绪',
interactionCountData: '互动次数',
capacity: '容量',
type: '类型',
person: '个人',
memoryNum: '条记忆',
memory_config_name: '记忆引擎',
searchPlaceholder: '搜索记忆库名称',
},
space: {
createSpace: '创建空间',
@@ -1364,7 +1371,8 @@ export const zh = {
neo4jDesc: '基于知识图谱,适合关系推理和路径查询',
llmModel: 'LLM 模型',
embeddingModel: 'Embedding 模型',
rerankModel: 'Rerank 模型'
rerankModel: 'Rerank 模型',
configAlert: '空间模型配置为空间的模型模型,保障空间运行时能正确的调用到相应的模型来处理业务数据。',
},
memoryExtractionEngine: {
title: '记忆引擎模块配置中心',
@@ -1537,6 +1545,8 @@ export const zh = {
quickReply: '快速回复',
web_search: '联网搜索',
memory: '记忆',
memoryConversationAnalysisEmpty: '目前没有可用的对话分析内容',
memoryConversationAnalysisEmptySubTitle: '输入您的用户ID后点击"测试记忆"查看对话记忆',
},
login: {
title: '红熊记忆科学',
@@ -1711,19 +1721,17 @@ export const zh = {
JsonTool_desc: '数据格式转换',
JsonTool_features: 'JSON格式化、压缩、验证和转换功能',
jsonFormat: 'JSON格式化',
jsonGzip: 'JSON压缩',
jsonCheck: 'JSON验证',
jsonConversion: '格式转换',
jsonParse: 'JSON解析',
jsonInsert: 'JSON插入',
jsonReplace: 'JSON验证',
jsonDelete: 'JSON删除',
jsonEg: '示例JSON',
enterJson: '输入JSON',
jsonPlaceholder: '输入JSON数据例如{"name": "测试", "value": 123}',
clear: '清空',
parse: '粘贴',
format: '格式化',
minify: '压缩',
validate: '验证',
convert: '转义',
paste: '粘贴',
parse: '解析',
json_path: 'JSON 路径参数',
outputResult: '输出结果',
validJosn: 'JSON格式正确验证通过',
@@ -2043,7 +2051,8 @@ export const zh = {
variableConfig: '变量配置',
variableRequired: '必填',
addMessage: '添加消息',
answerDesc: '回复'
answerDesc: '回复',
addNode: '添加节点',
},
emotionEngine: {
emotionEngineConfig: '情感引擎配置',

View File

@@ -66,6 +66,7 @@ const componentMap: Record<string, LazyExoticComponent<ComponentType<object>>> =
OrderHistory: lazy(() => import('@/views/OrderHistory')),
Pricing: lazy(() => import('@/views/Pricing')),
ToolManagement: lazy(() => import('@/views/ToolManagement')),
SpaceConfig: lazy(() => import('@/views/SpaceConfig')),
Login: lazy(() => import('@/views/Login')),
InviteRegister: lazy(() => import('@/views/InviteRegister')),
NoPermission: lazy(() => import('@/views/NoPermission')),

View File

@@ -33,6 +33,7 @@
{ "path": "/api-key", "element": "ApiKeyManagement" },
{ "path": "/emotion-engine/:id", "element": "EmotionEngine" },
{ "path": "/reflection-engine/:id", "element": "SelfReflectionEngine" },
{ "path": "/space-config", "element": "SpaceConfig" },
{ "path": "/no-permission", "element": "NoPermission" },
{ "path": "/*", "element": "NotFound" }
]

View File

@@ -376,6 +376,21 @@
"icon": null,
"iconActive": null,
"subs": null
},
{
"id": 12,
"parent": 0,
"code": "spaceConfig",
"label": "空间配置",
"i18nKey": "menu.spaceConfig",
"path": "/space-config",
"enable": true,
"display": true,
"level": 1,
"sort": 0,
"icon": null,
"iconActive": null,
"subs": null
}
]
}

View File

@@ -0,0 +1,118 @@
import { type FC, useEffect, useState } from 'react';
import { Form, App, Button, Skeleton } from 'antd';
import { useTranslation } from 'react-i18next';
import type { SpaceConfigData } from './types'
import { getWorkspaceModels, updateWorkspaceModels } from '@/api/workspaces'
import { getModelListUrl } from '@/api/models'
import CustomSelect from '@/components/CustomSelect'
import RbAlert from '@/components/RbAlert';
const SpaceConfig: FC = () => {
const { t } = useTranslation();
const { message } = App.useApp();
const [pageLoading, setPageLoding] = useState(false)
const [form] = Form.useForm<SpaceConfigData>();
const [loading, setLoading] = useState(false)
const values = Form.useWatch([], form);
useEffect(() => {
setPageLoding(true)
getWorkspaceModels().then((res) => {
const { llm, embedding, rerank } = res as SpaceConfigData
form.setFieldsValue({
llm,
embedding,
rerank
})
})
.finally(() => {
setPageLoding(false)
})
}, [])
// 封装保存方法,添加提交逻辑
const handleSave = () => {
form
.validateFields()
.then(() => {
setLoading(true)
updateWorkspaceModels(values)
.then(() => {
setLoading(false)
message.success(t('common.updateSuccess'))
})
.catch(() => {
setLoading(false)
});
})
.catch((err) => {
console.log('err', err)
});
}
return (
<div className="rb:h-full rb:max-w-140 rb:mx-auto">
{pageLoading
? <Skeleton active />
: <Form
form={form}
layout="vertical"
>
<Form.Item
label={t('space.llmModel')}
name="llm"
rules={[{ required: true, message: t('common.pleaseSelect') }]}
>
<CustomSelect
url={getModelListUrl}
params={{ type: 'llm', pagesize: 100 }}
valueKey="id"
labelKey="name"
hasAll={false}
style={{width: '100%'}}
/>
</Form.Item>
<Form.Item
label={t('space.embeddingModel')}
name="embedding"
rules={[{ required: true, message: t('common.pleaseSelect') }]}
>
<CustomSelect
url={getModelListUrl}
params={{ type: 'embedding', pagesize: 100 }}
valueKey="id"
labelKey="name"
hasAll={false}
style={{width: '100%'}}
/>
</Form.Item>
<Form.Item
label={t('space.rerankModel')}
name="rerank"
rules={[{ required: true, message: t('common.pleaseSelect') }]}
>
<CustomSelect
url={getModelListUrl}
params={{ type: 'rerank', pagesize: 100 }}
valueKey="id"
labelKey="name"
hasAll={false}
style={{width: '100%'}}
/>
</Form.Item>
<RbAlert>{t('space.configAlert')}</RbAlert>
<Form.Item className="rb:text-right">
<Button type="primary" className="rb:mt-6" onClick={handleSave} loading={loading}>
{t('common.save')}
</Button>
</Form.Item>
</Form>
}
</div>
);
};
export default SpaceConfig;

View File

@@ -0,0 +1,8 @@
export interface SpaceConfigData {
llm: string;
embedding: string;
rerank: string;
}
export interface SpaceConfigRef {
handleOpen: () => void;
}

View File

@@ -1,5 +1,5 @@
import { forwardRef, useImperativeHandle, useState } from 'react';
import { Form, Input, Button, Space, Tree } from 'antd';
import { Form, Input, Button, Space } from 'antd';
import { useTranslation } from 'react-i18next';
import type { TreeDataNode } from 'antd';
@@ -12,7 +12,7 @@ import { execute } from '@/api/tools';
const JsonToolModal = forwardRef<JsonToolModalRef>((_props, ref) => {
const { t } = useTranslation();
const [visible, setVisible] = useState(false);
const [form] = Form.useForm<{ json: string; }>();
const [form] = Form.useForm<{ json: string; json_path: string; }>();
const [data, setData] = useState<ToolItem>({} as ToolItem)
const [formatValue, setFormatValue] = useState<string | Record<string, any> | null>(null)
@@ -60,44 +60,29 @@ const JsonToolModal = forwardRef<JsonToolModalRef>((_props, ref) => {
}
const handleOperate = (type: string) => {
const json = form.getFieldValue('json')
const json_path = form.getFieldValue('json_path')
if (!json || !data.id) return
let params: ExecuteData = {
tool_id: data.id,
parameters: {
operation: type,
input_data: json
input_data: json,
json_path
}
}
if (type === 'format') {
if (type === 'parse') {
params = {
...params,
parameters: {
...params.parameters,
indent: 2,
ensure_ascii: false,
sort_keys: false
}
}
}
execute(params)
.then(res => {
const { data } = res as {data: {
formatted_json: string;
minified_json: string;
is_valid: boolean;
converted_json: string;
error: string;
structure: Record<string, string | number>
}}
switch (type) {
case 'format':
setFormatValue(data.formatted_json);
break
case 'minify':
setFormatValue(data.minified_json)
break
}
const { data } = res as { data: string; }
setFormatValue(data);
})
}
const clear = () => {
@@ -126,15 +111,20 @@ const JsonToolModal = forwardRef<JsonToolModalRef>((_props, ref) => {
label={<Space size={8}>
{t('tool.enterJson')}
<Button onClick={clear}>{t('tool.clear')}</Button>
<Button onClick={handleParse}>{t('tool.parse')}</Button>
<Button onClick={handleParse}>{t('tool.paste')}</Button>
</Space>}
>
<Input.TextArea rows={10} placeholder={t('tool.jsonPlaceholder')} />
</FormItem>
<FormItem
name="json_path"
label={t('tool.json_path')}
>
<Input placeholder={t('common.pleaseEnter')} />
</FormItem>
<Space size={8} className="rb:mb-3">
<Button onClick={() => handleOperate('format')}>{t('tool.format')}</Button>
<Button onClick={() => handleOperate('minify')}>{t('tool.minify')}</Button>
<Button onClick={() => handleOperate('parse')}>{t('tool.parse')}</Button>
</Space>
<FormItem
label={t('tool.outputResult')}

View File

@@ -23,6 +23,7 @@ interface CurrentTimeObj {
iso_format: string;
timestamp: string;
timestamp_ms: string;
utc_datetime: string;
}
const TimeToolModal = forwardRef<TimeToolModalRef>((_props, ref) => {
const { t } = useTranslation();
@@ -88,8 +89,8 @@ const TimeToolModal = forwardRef<TimeToolModalRef>((_props, ref) => {
}
})
.then(res => {
const response = res as { data: CurrentTimeObj }
setTimestampFormat(response.data.datetime)
const response = res as { data: string }
setTimestampFormat(response.data)
})
}
const handleChangeFormatType = () => {
@@ -149,7 +150,7 @@ const TimeToolModal = forwardRef<TimeToolModalRef>((_props, ref) => {
<Input disabled value={currentTime?.datetime} />
</FormItem>
<FormItem label={t('tool.utcTime')} >
<Input disabled value={currentTime?.iso_format} />
<Input disabled value={currentTime?.utc_datetime} />
</FormItem>
<FormItem label={t('tool.secondsTimestamp')} >
<Input disabled value={currentTime?.timestamp} />

View File

@@ -10,10 +10,10 @@ export const InnerConfigData: Record<string, InnerConfigItem> = {
},
JsonTool: {
features: [
'jsonFormat',
'jsonGzip',
'jsonCheck',
'jsonConversion'
'jsonParse',
'jsonInsert',
'jsonReplace',
'jsonDelete'
],
eg: '{"name":"工具","tool_class":"内置"}'
},

View File

@@ -130,6 +130,7 @@ export interface ExecuteData {
ensure_ascii?: boolean;
sort_keys?: boolean;
input_data?: string;
json_path?: string;
}
}
export interface CustomToolModalRef {

View File

@@ -1,127 +0,0 @@
import { forwardRef, useImperativeHandle, useState } from 'react';
import { Form, App } from 'antd';
import { useTranslation } from 'react-i18next';
import type { ConfigModalData, ConfigModalRef } from '../types'
import { getWorkspaceModels, updateWorkspaceModels } from '@/api/workspaces'
import { getModelListUrl } from '@/api/models'
import CustomSelect from '@/components/CustomSelect'
import RbModal from '@/components/RbModal'
const ConfigModal = forwardRef<ConfigModalRef>((_props, ref) => {
const { t } = useTranslation();
const { message } = App.useApp();
const [visible, setVisible] = useState(false);
const [form] = Form.useForm<ConfigModalData>();
const [loading, setLoading] = useState(false)
const values = Form.useWatch([], form);
// 封装取消方法,添加关闭弹窗逻辑
const handleClose = () => {
setVisible(false);
form.resetFields();
setLoading(false)
};
const handleOpen = () => {
getWorkspaceModels().then((res) => {
const { llm, embedding, rerank } = res as ConfigModalData
form.setFieldsValue({
llm,
embedding,
rerank
})
})
setVisible(true);
};
// 封装保存方法,添加提交逻辑
const handleSave = () => {
form
.validateFields()
.then(() => {
setLoading(true)
updateWorkspaceModels(values)
.then(() => {
setLoading(false)
handleClose()
message.success(t('common.updateSuccess'))
})
.catch(() => {
setLoading(false)
});
handleClose()
})
.catch((err) => {
console.log('err', err)
});
}
// 暴露给父组件的方法
useImperativeHandle(ref, () => ({
handleOpen,
handleClose
}));
return (
<RbModal
title={t(`userMemory.editConfig`)}
open={visible}
onCancel={handleClose}
okText={t('common.save')}
onOk={handleSave}
confirmLoading={loading}
>
<Form
form={form}
layout="vertical"
>
<Form.Item
label={t('space.llmModel')}
name="llm"
rules={[{ required: true, message: t('common.pleaseSelect') }]}
>
<CustomSelect
url={getModelListUrl}
params={{ type: 'llm', pagesize: 100 }}
valueKey="id"
labelKey="name"
hasAll={false}
style={{width: '100%'}}
/>
</Form.Item>
<Form.Item
label={t('space.embeddingModel')}
name="embedding"
rules={[{ required: true, message: t('common.pleaseSelect') }]}
>
<CustomSelect
url={getModelListUrl}
params={{ type: 'embedding', pagesize: 100 }}
valueKey="id"
labelKey="name"
hasAll={false}
style={{width: '100%'}}
/>
</Form.Item>
<Form.Item
label={t('space.rerankModel')}
name="rerank"
rules={[{ required: true, message: t('common.pleaseSelect') }]}
>
<CustomSelect
url={getModelListUrl}
params={{ type: 'rerank', pagesize: 100 }}
valueKey="id"
labelKey="name"
hasAll={false}
style={{width: '100%'}}
/>
</Form.Item>
</Form>
</RbModal>
);
});
export default ConfigModal;

View File

@@ -1,56 +1,28 @@
import { useEffect, useState, useRef } from 'react';
import { useEffect, useState, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useNavigate } from 'react-router-dom'
import { Row, Col, Radio, Button, List, Skeleton, Space } from 'antd';
import type { ColumnsType } from 'antd/es/table';
import type { RadioChangeEvent } from 'antd';
import { AppstoreOutlined, MenuOutlined } from '@ant-design/icons';
import { Row, Col, List, Skeleton } from 'antd';
import Empty from '@/components/Empty'
import type { Data, ConfigModalRef } from './types'
import totalNum from '@/assets/images/memory/totalNum.svg'
import onlineNum from '@/assets/images/memory/onlineNum.svg'
import Table from '@/components/Table'
import { getTotalEndUsers, userMemoryListUrl, getUserMemoryList } from '@/api/memory';
import ConfigModal from './components/ConfigModal';
import type { Data } from './types'
import { getUserMemoryList } from '@/api/memory';
import { useUser } from '@/store/user'
import RbCard from '@/components/RbCard/Card'
import SearchInput from '@/components/SearchInput';
const bgList = [
'linear-gradient( 180deg, #F1F6FE 0%, #FBFDFF 100%)',
'linear-gradient( 180deg, #F1F9FE 0%, #FBFDFF 100%)',
'linear-gradient( 180deg, #FEFBF7 0%, #FBFDFF 100%)',
'linear-gradient( 180deg, #F1F9FE 0%, #FBFDFF 100%)',
]
const countList = [
'total_num', 'online_num',
]
const IconList: Record<string, string> = {
total_num: totalNum,
online_num: onlineNum,
}
export default function UserMemory() {
const { t } = useTranslation();
const navigate = useNavigate()
const { storageType } = useUser()
const configModalRef = useRef<ConfigModalRef>(null)
const [loading, setLoading] = useState<boolean>(false);
const [data, setData] = useState<Data[]>([]);
const [countData, setCountData] = useState<Record<string, number>>({});
const [layout, setLayout] = useState<'card' | 'list'>('card');
const [search, setSearch] = useState<string | undefined>(undefined);
// 获取数据
useEffect(() => {
getCountData()
getData()
}, []);
// 用户记忆统计
const getCountData = () => {
getTotalEndUsers().then((res) => {
setCountData(res as Record<string, number> || {})
})
}
const getData = () => {
setLoading(true)
getUserMemoryList().then((res) => {
@@ -60,7 +32,6 @@ export default function UserMemory() {
setLoading(false)
})
}
console.log('storageType', storageType)
const handleViewDetail = (id: string | number) => {
switch (storageType) {
case 'neo4j':
@@ -70,112 +41,77 @@ export default function UserMemory() {
navigate(`/user-memory/${id}`)
}
}
const handleChangeLayout = (e: RadioChangeEvent) => {
const type = e.target.value
setLayout(type)
const handleViewMemoryConfig = () => {
navigate(`/memory`)
}
// 表格列配置
const columns: ColumnsType = [
{
title: t('userMemory.user'),
dataIndex: 'end_user',
key: 'end_user',
render: (value) => value?.other_name && value?.other_name !== '' ? value?.other_name : value?.id || '-'
},
{
title: t('userMemory.knowledgeEntryCount'),
dataIndex: 'memory_num',
key: 'memory_num',
render: (value) => value?.total || 0
},
{
title: t('common.operation'),
key: 'action',
render: (_, record) => (
<Button
type="link"
onClick={() => handleViewDetail(record.end_user?.id)}
>
{t('common.viewDetail')}
</Button>
),
},
];
const filterData = useMemo(() => {
if (search && search.trim() !== '') {
return data.filter((item) => {
const { end_user } = item as Data;
const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id
return name?.includes(search)
})
}
return data
}, [search, data])
return (
<div>
<Row gutter={16} className="rb:mb-4">
{countList.map(key => (
<Col key={key} span={6}>
<div className="rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-xl rb:p-[18px_20px_20px_20px]">
<div className="rb:text-[28px] rb:font-extrabold rb:leading-8.75 rb:flex rb:items-center rb:justify-between rb:mb-3">
{countData[key] || 0}{key === 'avgInteractionTime' ? 's' : ''}
<img className="rb:w-6 rb:h-6" src={IconList[key]} />
</div>
<div className="rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4">{t(`userMemory.${key}`)}</div>
</div>
</Col>
))}
<Col span={12} className="rb:text-right">
<Space>
<Button type="primary" onClick={() => configModalRef?.current?.handleOpen()}>{t('userMemory.chooseModel')}</Button>
<Radio.Group value={layout} onChange={handleChangeLayout}>
<Radio.Button value="card" disabled={layout === 'card'}><AppstoreOutlined /></Radio.Button>
<Radio.Button value="list" disabled={layout === 'list'}><MenuOutlined /></Radio.Button>
</Radio.Group>
</Space>
<Col span={8}>
<SearchInput
placeholder={t('userMemory.searchPlaceholder')}
onSearch={(value) => setSearch(value)}
style={{ width: '100%' }}
/>
</Col>
</Row>
{layout === 'card' &&
<>
{loading ?
<Skeleton active />
: data.length > 0 ? (
<List
grid={{ gutter: 16, column: 4 }}
dataSource={data}
renderItem={(item, index) => {
const { end_user, memory_num } = item as Data;
const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id
return (
<List.Item key={index}>
<div
className="rb:p-5 rb:rounded-xl rb:border rb:border-[#DFE4ED] rb:cursor-pointer"
style={{
background: bgList[index % bgList.length],
}}
{loading ?
<Skeleton active />
: filterData.length > 0 ? (
<List
grid={{ gutter: 16, column: 3 }}
dataSource={filterData}
renderItem={(item, index) => {
const { end_user, memory_num, memory_config } = item as Data;
const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id
return (
<List.Item key={index}>
<RbCard
avatar={<div className="rb:w-12 rb:h-12 rb:text-center rb:font-semibold rb:text-[28px] rb:leading-12 rb:rounded-lg rb:text-[#FBFDFF] rb:bg-[#155EEF] rb:mr-2">{name[0]}</div>}
title={name || '-'}
extra={<div
className="rb:w-7 rb:h-7 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/userMemory/goto.svg')]"
onClick={() => handleViewDetail(end_user.id)}
>
<div className="rb:flex rb:items-center">
<div className="rb:w-12 rb:h-12 rb:text-center rb:font-semibold rb:text-[28px] rb:leading-12 rb:rounded-lg rb:text-[#FBFDFF] rb:bg-[#155EEF]">{name[0]}</div>
<div className="rb:max-w-[calc(100%-60px)] rb:text-base rb:font-medium rb:leading-6 rb:ml-3 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">
{name || '-'}<br/>
</div>
</div>
<div className="rb:grid rb:grid-cols-1 rb:gap-3 rb:mt-7 rb:mb-7">
<div className="rb:text-center">
<div className="rb:text-[24px] rb:leading-7.5 rb:font-extrabold">{memory_num.total || 0}</div>
<div className="rb:wrap-break-word">{t(`userMemory.knowledgeEntryCount`)}</div>
</div>
</div>
></div>}
>
<div className="rb:flex rb:justify-between rb:items-center">
<div>{t('userMemory.capacity')}</div>
<div>{memory_num?.total || 0} {t('userMemory.memoryNum')}</div>
</div>
<div className="rb:flex rb:justify-between rb:items-center rb:mt-2.5">
<div>{t('userMemory.type')}</div>
<div>{t(`userMemory.${item.type || 'person'}`)}</div>
</div>
</List.Item>
)
}}
/>
) : <Empty />}
</>
}
{layout === 'list' &&
<Table
apiUrl={userMemoryListUrl}
columns={columns}
rowKey="end_user.id"
pagination={false}
/>
<div className="rb:mt-3 rb:bg-[#F6F8FC] rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:py-2 rb:px-3" onClick={handleViewMemoryConfig}>
<div className="rb:text-[#5B6167] rb:leading-5 rb:flex rb:justify-between rb:items-center">
{t('userMemory.memory_config_name')}
<div
className="rb:w-7 rb:h-7 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/userMemory/arrow_right.svg')]"
></div>
</div>
<div className="rb:font-medium rb:leading-5 rb:mt-1">{memory_config?.memory_config_name || '-'}</div>
</div>
</RbCard>
</List.Item>
)
}}
/>
) : <Empty />
}
<ConfigModal ref={configModalRef} />
</div>
);
}

View File

@@ -17,13 +17,10 @@ export interface Data {
entity: number;
}
},
memory_config: {
memory_config_id: string;
memory_config_name: string;
},
type: string;
name?: string;
}
export interface ConfigModalData {
llm: string;
embedding: string;
rerank: string;
}
export interface ConfigModalRef {
handleOpen: () => void;
}

View File

@@ -3,8 +3,7 @@ import { useTranslation } from 'react-i18next'
import ReactEcharts from 'echarts-for-react';
import Empty from '@/components/Empty'
import Loading from '@/components/Empty/Loading'
import type { Emotion } from './GraphDetail'
import { format } from 'echarts';
import type { Emotion } from '../pages/GraphDetail'
interface EmotionLineProps {
chartData: Emotion[];

View File

@@ -3,7 +3,7 @@ import { useTranslation } from 'react-i18next'
import ReactEcharts from 'echarts-for-react'
import Empty from '@/components/Empty'
import Loading from '@/components/Empty/Loading'
import type { Interaction } from './GraphDetail'
import type { Interaction } from '../pages/GraphDetail'
interface InteractionBarProps {
chartData: Interaction[];

View File

@@ -1,19 +1,18 @@
import React, { type FC, useEffect, useState, useRef, useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { useParams } from 'react-router-dom'
import { useParams, useNavigate } from 'react-router-dom'
import { Col, Row, Space, Button } from 'antd'
import dayjs from 'dayjs'
import RbCard from '@/components/RbCard/Card'
import ReactEcharts from 'echarts-for-react'
import detailEmpty from '@/assets/images/userMemory/detail_empty.png'
import type { Node, Edge, GraphData, StatementNodeProperties, ExtractedEntityNodeProperties, GraphDetailRef } from '../types'
import type { Node, Edge, GraphData, StatementNodeProperties, ExtractedEntityNodeProperties } from '../types'
import {
getMemorySearchEdges,
} from '@/api/memory'
import Empty from '@/components/Empty'
import Tag from '@/components/Tag'
import GraphDetail from '../components/GraphDetail'
const colors = ['#155EEF', '#369F21', '#4DA8FF', '#FF5D34', '#9C6FFF', '#FF8A4C', '#8BAEF7', '#FFB048']
const RelationshipNetwork:FC = () => {
@@ -26,7 +25,7 @@ const RelationshipNetwork:FC = () => {
const [categories, setCategories] = useState<{ name: string }[]>([])
const [selectedNode, setSelectedNode] = useState<Node | null>(null)
// const [fullScreen, setFullScreen] = useState<boolean>(false)
const graphDetailRef = useRef<GraphDetailRef>(null)
const navigate = useNavigate()
console.log('categories', categories)
// 关系网络
@@ -133,15 +132,14 @@ const RelationshipNetwork:FC = () => {
}
}, [nodes])
// const handleFullScreen = () => {
// setFullScreen(prev => !prev)
// }
console.log('selectedNode', selectedNode)
const handleViewAll = () => {
if (!selectedNode) return
graphDetailRef.current?.handleOpen(selectedNode)
const params = new URLSearchParams({
nodeId: selectedNode.id,
nodeLabel: selectedNode.label,
nodeName: selectedNode.name || ''
})
navigate(`/user-memory/detail/${id}/GRAPH?${params.toString()}`)
}
return (
@@ -336,8 +334,6 @@ const RelationshipNetwork:FC = () => {
</div>
</RbCard>
</Col>
<GraphDetail ref={graphDetailRef} />
</Row>
)
}

View File

@@ -1,16 +1,17 @@
import { useState, forwardRef, useImperativeHandle, useMemo } from 'react'
import { useState, forwardRef, useImperativeHandle, useMemo, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import { useSearchParams } from 'react-router-dom'
import { Row, Col, Tabs, Space, Skeleton } from 'antd'
import { getRelationshipEvolution, getTimelineMemories } from '@/api/memory'
import type { Node, GraphDetailRef } from '../types'
import RbDrawer from '@/components/RbDrawer'
import RbCard from '@/components/RbCard/Card'
import EmotionLine from './EmotionLine'
import EmotionLine from '../components/EmotionLine'
import { formatDateTime } from '@/utils/format'
import Tag from '@/components/Tag'
import InteractionBar from './InteractionBar'
import InteractionBar from '../components/InteractionBar'
import Empty from '@/components/Empty'
import PageHeader from '../components/PageHeader'
export interface Emotion {
emotion_intensity: number;
@@ -35,7 +36,7 @@ interface Timeline {
const GraphDetail = forwardRef<GraphDetailRef>((_props, ref) => {
const { t } = useTranslation()
const [open, setOpen] = useState(false);
const [searchParams] = useSearchParams()
const [vo, setVo] = useState<Node | null>(null)
const [loading, setLoading] = useState(false)
const [emotionData, setEmotionData] = useState<Emotion[]>([])
@@ -43,14 +44,23 @@ const GraphDetail = forwardRef<GraphDetailRef>((_props, ref) => {
const [activeTab, setActiveTab] = useState('timelines_memory')
const [timelineLoading, setTimelineLoading] = useState(false)
const [timelineMemories, setTimelineMemories] = useState<Timeline>({ timelines_memory: [], MemorySummary: [], Statement: [], ExtractedEntity: []})
useEffect(() => {
const nodeId = searchParams.get('nodeId')
const nodeLabel = searchParams.get('nodeLabel')
const nodeName = searchParams.get('nodeName')
if (nodeId && nodeLabel) {
const nodeFromUrl = {
id: nodeId,
label: nodeLabel,
name: nodeName || nodeLabel
}
handleOpen(nodeFromUrl as Node)
}
}, [searchParams])
const handleCancel = () => {
setVo(null)
setOpen(false)
}
const handleOpen = (vo: Node) => {
setActiveTab('timelines_memory')
setOpen(true)
setVo(vo)
getRelationshipEvolutionData(vo)
getTimelineMemoriesData(vo)
@@ -85,56 +95,57 @@ const GraphDetail = forwardRef<GraphDetailRef>((_props, ref) => {
}, [activeTab, timelineMemories])
return (
<RbDrawer
title={vo?.name}
open={open}
onClose={handleCancel}
width={1000}
>
<div className="rb:text-[16px] rb:font-medium rb:leading-5.5 rb:mb-3">{t('userMemory.relationshipEvolution')}</div>
<RbCard>
<Row gutter={16}>
<Col span={12}>
<EmotionLine chartData={emotionData} loading={loading} />
</Col>
<Col span={12}>
<InteractionBar chartData={interactionData} loading={loading} />
</Col>
</Row>
</RbCard>
<>
<PageHeader
name={vo?.name}
source="node"
/>
<div className="rb:h-full rb:max-w-266 rb:mx-auto">
<div className="rb:text-[16px] rb:font-medium rb:leading-5.5 rb:mb-3">{t('userMemory.relationshipEvolution')}</div>
<RbCard>
<Row gutter={16}>
<Col span={12}>
<EmotionLine chartData={emotionData} loading={loading} />
</Col>
<Col span={12}>
<InteractionBar chartData={interactionData} loading={loading} />
</Col>
</Row>
</RbCard>
<div className="rb:text-[16px] rb:font-medium rb:leading-5.5 rb:mb-3 rb:mt-6">{t('userMemory.timelineMemories')}</div>
<RbCard>
<Tabs
activeKey={activeTab}
items={['timelines_memory', 'ExtractedEntity', 'Statement', 'MemorySummary'].map(key => ({
label: t(`userMemory.${key}`),
key
}))}
onChange={(key: string) => setActiveTab(key)}
/>
{timelineLoading
? <Skeleton active />
: !activeContent || activeContent.length === 0
? <Empty size={120} className="rb:mt-12 rb:mb-20.25" />
: <Space size={16} direction="vertical" className="rb:w-full">
{activeContent.map((vo, index) => (
<RbCard
key={index}
headerType="borderL"
headerClassName="rb:before:bg-[#155EEF]!"
title={vo.text}
>
<div className="rb:text-[#A8A9AA] rb:text-[12px] rb:leading-4">{formatDateTime(vo.created_at)}</div>
<Tag className="rb:mt-2">{vo.type}</Tag>
</RbCard>
))}
</Space>
}
<div className="rb:text-[16px] rb:font-medium rb:leading-5.5 rb:mb-3 rb:mt-6">{t('userMemory.timelineMemories')}</div>
<RbCard>
<Tabs
activeKey={activeTab}
items={['timelines_memory', 'Statement', 'MemorySummary'].map(key => ({
label: t(`userMemory.${key}`),
key
}))}
onChange={(key: string) => setActiveTab(key)}
/>
{timelineLoading
? <Skeleton active />
: !activeContent || activeContent.length === 0
? <Empty size={120} className="rb:mt-12 rb:mb-20.25" />
: <Space size={16} direction="vertical" className="rb:w-full">
{activeContent.map((vo, index) => (
<RbCard
key={index}
headerType="borderL"
headerClassName="rb:before:bg-[#155EEF]!"
title={vo.text}
>
<div className="rb:text-[#A8A9AA] rb:text-[12px] rb:leading-4">{formatDateTime(vo.created_at)}</div>
<Tag className="rb:mt-2">{vo.type}</Tag>
</RbCard>
))}
</Space>
}
</RbCard>
</RbDrawer>
</RbCard>
</div>
</>
)
})
export default GraphDetail

View File

@@ -1,7 +1,7 @@
import { type FC, useEffect, useState, useMemo, useRef } from 'react'
import { useParams, useNavigate } from 'react-router-dom'
import { useTranslation } from 'react-i18next'
import { Dropdown, Space, Button } from 'antd'
import { Dropdown, Button } from 'antd'
import PageHeader from '../components/PageHeader'
import StatementDetail from './StatementDetail'
@@ -16,6 +16,7 @@ import {
getEndUserProfile,
} from '@/api/memory'
import refreshIcon from '@/assets/images/refresh_hover.svg'
import GraphDetail from './GraphDetail'
const Detail: FC = () => {
const { t } = useTranslation()
@@ -47,6 +48,10 @@ const Detail: FC = () => {
forgetDetailRef.current?.handleRefresh()
}
if (type === 'GRAPH') {
return <GraphDetail />
}
return (
<div className="rb:h-full rb:w-full">
<PageHeader

View File

@@ -107,7 +107,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
<div style={{ maxHeight: '300px', overflowY: 'auto', minWidth: '240px' }}>
{nodeLibrary.map((category, categoryIndex) => {
const filteredNodes = category.nodes.filter(nodeType =>
nodeType.type !== 'start' && nodeType.type !== 'end' && nodeType.type !== 'loop' && nodeType.type !== 'cycle-start'
nodeType.type !== 'start' && nodeType.type !== 'end' && nodeType.type !== 'iteration' && nodeType.type !== 'loop' && nodeType.type !== 'cycle-start'
);
if (filteredNodes.length === 0) return null;

View File

@@ -33,7 +33,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
y: cycleStartBBox.y,
data: {
type: 'add-node',
label: '添加节点',
label: t('workflow.addNode'),
icon: '+',
parentId: node.id,
cycle: data.id,
@@ -61,7 +61,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
},
},
},
zIndex: 3
zIndex: 10
});
}
}
@@ -97,7 +97,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
y: centerY,
data: {
type: 'add-node',
label: '添加节点',
label: t('workflow.addNode'),
icon: '+',
parentId: node.id,
cycle: data.id,
@@ -128,7 +128,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
},
},
},
zIndex: 3
zIndex: 10
}
graph.addEdge(edgeConfig)

View File

@@ -151,11 +151,11 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
let filteredNodes;
if (isChildOfLoop) {
// Use same filtering as AddNode for child nodes of loop
// Use same filtering as AddNode for child nodes of loop, but allow break
filteredNodes = category.nodes.filter(nodeType => !['start', 'end', 'loop', 'cycle-start', 'iteration'].includes(nodeType.type));
} else if (isChildOfIteration) {
// Filter out loop and iteration nodes for children of iteration nodes
filteredNodes = category.nodes.filter(nodeType => !['start', 'end', 'loop', 'break', 'cycle-start', 'iteration'].includes(nodeType.type));
// Filter out loop and iteration nodes for children of iteration nodes, but allow break
filteredNodes = category.nodes.filter(nodeType => !['start', 'end', 'loop', 'cycle-start', 'iteration'].includes(nodeType.type));
} else {
// Original filtering for non-loop child nodes
filteredNodes = category.nodes.filter(nodeType => !['start', 'end', 'break', 'cycle-start'].includes(nodeType.type));

View File

@@ -60,7 +60,7 @@ const AssignmentList: FC<AssignmentListProps> = ({
>
<VariableSelect
placeholder={t('common.pleaseSelect')}
options={options}
options={options.filter(vo => vo.nodeData.type === 'loop' || vo.value.includes('conv.'))}
popupMatchSelectWidth={false}
onChange={() => {
form.setFieldValue([parentName, name, 'operation'], undefined);

View File

@@ -1,17 +1,19 @@
import { type FC } from 'react';
import { useTranslation } from 'react-i18next';
import { Input, Button, Form, Space } from 'antd';
import { PlusOutlined, CopyOutlined, DeleteOutlined, ExpandOutlined } from '@ant-design/icons';
import { Button, Form, Space } from 'antd';
import { DeleteOutlined } from '@ant-design/icons';
import { Graph, Node } from '@antv/x6';
import type { PortMetadata } from '@antv/x6/lib/model/port';
import Editor from '../../Editor';
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
interface CategoryListProps {
parentName: string;
options: Suggestion[];
selectedNode?: Node | null;
graphRef?: React.MutableRefObject<Graph | undefined>;
}
const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRef }) => {
const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRef, options }) => {
const { t } = useTranslation();
const form = Form.useFormInstance();
const formValues = Form.useWatch([parentName], form);
@@ -167,9 +169,9 @@ const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRe
name={[name, 'class_name']}
noStyle
>
<Input.TextArea
<Editor
placeholder={t('common.pleaseEnter')}
rows={2}
options={options}
/>
</Form.Item>
</div>

View File

@@ -1,6 +1,6 @@
import { type FC } from 'react'
import { useTranslation } from 'react-i18next';
import { Form, Button, Select, Row, Col, InputNumber, Radio, type SelectProps } from 'antd'
import { Form, Button, Select, Row, Col, InputNumber, Radio, Input, type SelectProps } from 'antd'
import { DeleteOutlined } from '@ant-design/icons';
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
@@ -114,7 +114,7 @@ const ConditionList: FC<CaseListProps> = ({
<Col span={14}>
<Form.Item name={[field.name, 'left']} noStyle>
<VariableSelect
options={options}
options={options.filter(vo => vo.value.includes('sys.') || vo.value.includes('conv.') || vo.nodeData.type === 'loop')}
size="small"
allowClear={false}
popupMatchSelectWidth={false}
@@ -186,7 +186,7 @@ const ConditionList: FC<CaseListProps> = ({
<Radio.Button value={true}>True</Radio.Button>
<Radio.Button value={false}>False</Radio.Button>
</Radio.Group>
: <Editor options={options} />
: <Input placeholder={t('common.pleaseEnter')} />
}
</Form.Item>
</Col>

View File

@@ -1,6 +1,6 @@
import { type FC } from 'react'
import { useTranslation } from 'react-i18next';
import { Form, Button, Select, Row, Col, Input } from 'antd'
import { Form, Select, Row, Col, Input } from 'antd'
import { DeleteOutlined, PlusOutlined } from '@ant-design/icons';
import VariableSelect from '../VariableSelect'
@@ -36,7 +36,6 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
value = [],
options,
parentName,
onChange,
selectedNode,
graphRef
}) => {
@@ -139,12 +138,17 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
<Form.Item name={[name, 'value']} noStyle>
{currentInputType === 'variable' ? (
<VariableSelect
placeholder="选择变量"
options={availableOptions}
placeholder={t('common.pleaseSelect')}
options={availableOptions.filter(option => {
const currentType = value?.[index]?.type;
if (!currentType) return true;
return option.dataType === currentType
})}
/>
) : (
<Input.TextArea
placeholder="输入值"
placeholder={t('common.pleaseEnter')}
rows={3}
className="rb:w-full"
/>

View File

@@ -18,8 +18,22 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
isCanAdd = false
}) => {
const { t } = useTranslation();
const form = Form.useFormInstance();
const value = form.getFieldValue(name) || [];
console.log('GroupVariableList', value)
if (!isCanAdd) {
// Filter options based on first variable's dataType if value exists
let filteredOptions = options;
if (value.length > 0) {
const firstVariableValue = value[0];
const firstVariable = options.find(opt => `{{${opt.value}}}` === firstVariableValue);
if (firstVariable) {
filteredOptions = options.filter(opt => opt.dataType === firstVariable.dataType);
}
}
return (
<div className="rb:mb-4">
<Row gutter={12} className="rb:mb-2!">
@@ -38,7 +52,7 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
>
<VariableSelect
placeholder={t('common.pleaseSelect')}
options={options}
options={filteredOptions}
mode="multiple"
/>
</Form.Item>
@@ -77,7 +91,18 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
>
<VariableSelect
placeholder={t('common.pleaseSelect')}
options={options}
options={(() => {
const currentGroupValue = value[name]?.value || [];
if (currentGroupValue.length > 0) {
const firstVariableValue = currentGroupValue[0];
const firstVariable = options.find(opt => `{{${opt.value}}}` === firstVariableValue);
if (firstVariable) {
return options.filter(opt => opt.dataType === firstVariable.dataType);
}
}
return options;
})()
}
mode="multiple"
/>
</Form.Item>

View File

@@ -90,7 +90,7 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an
</Col>
<Col span={16}>
<Form.Item name="url">
<Editor options={options} variant="outlined" />
<Editor options={options.filter(vo => vo.dataType === 'string' || vo.dataType === 'number')} variant="outlined" />
</Form.Item>
</Col>
</Row>
@@ -144,7 +144,7 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an
<Form.Item name={['body', 'data']} noStyle>
<EditableTable
parentName={['body', 'data']}
options={options}
options={options.filter(vo => vo.dataType === 'string' || vo.dataType === 'number')}
filterBooleanType={true}
/>
</Form.Item>
@@ -154,7 +154,7 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an
<MessageEditor
key="json"
parentName={['body', 'data']}
options={options}
options={options.filter(vo => vo.dataType === 'string' || vo.dataType === 'number')}
isArray={false}
title="JSON"
/>

View File

@@ -91,6 +91,7 @@ const VariableSelect: FC<VariableSelectProps> = ({
showSearch
allowClear={allowClear}
filterOption={(input, option) => {
if (input === '/') return true;
if (option?.options) {
return option.label?.toLowerCase().includes(input.toLowerCase()) ||
option.options.some((opt: any) =>

View File

@@ -22,6 +22,7 @@ import ConditionList from './ConditionList'
import CycleVarsList from './CycleVarsList'
import AssignmentList from './AssignmentList'
import ToolConfig from './ToolConfig'
// import { calculateVariableList } from './utils/variableListCalculator'
interface PropertiesProps {
selectedNode?: Node | null;
@@ -338,112 +339,35 @@ const Properties: FC<PropertiesProps> = ({
const parentLoopNode = getParentLoopNode(selectedNode.id);
console.log('childNodeIds', selectedNode, childNodeIds)
const allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds];
let allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds];
// Add parent loop/iteration node variables if current node is a child
// Add variables from nodes preceding the parent loop/iteration node if current node is a child
if (parentLoopNode) {
const parentData = parentLoopNode.getData();
const parentNodeId = parentLoopNode.getData().id;
if (parentData.type === 'loop') {
const cycleVars = parentData.cycle_vars || [];
cycleVars.forEach((cycleVar: any) => {
const key = `${parentNodeId}_cycle_${cycleVar.name}`;
if (!addedKeys.has(key)) {
addedKeys.add(key);
variableList.push({
key,
label: cycleVar.name,
type: 'variable',
dataType: cycleVar.type || 'String',
value: `${parentNodeId}.${cycleVar.name}`,
nodeData: parentData,
});
}
});
} else if (parentData.type === 'iteration') {
// Add item and index variables for iteration parent
const itemKey = `${parentNodeId}_item`;
const indexKey = `${parentNodeId}_index`;
if (!addedKeys.has(itemKey)) {
addedKeys.add(itemKey);
variableList.push({
key: itemKey,
label: 'item',
type: 'variable',
dataType: 'Object',
value: `${parentNodeId}.item`,
nodeData: parentData,
});
}
if (!addedKeys.has(indexKey)) {
addedKeys.add(indexKey);
variableList.push({
key: indexKey,
label: 'index',
type: 'variable',
dataType: 'Number',
value: `${parentNodeId}.index`,
nodeData: parentData,
});
}
}
// Check if parent loop/iteration is connected to http-request via ERROR connection
if (parentData.type === 'loop' || parentData.type === 'iteration') {
const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id);
parentPreviousNodeIds.forEach(prevNodeId => {
const prevNode = nodes.find(n => n.id === prevNodeId);
if (!prevNode) return;
const prevNodeData = prevNode.getData();
if (prevNodeData.type === 'http-request') {
// Check if connected via ERROR connection point
const errorEdges = edges.filter(edge => {
return edge.getTargetCellId() === parentLoopNode.id &&
edge.getSourceCellId() === prevNodeId &&
edge.getSourcePortId() === 'ERROR'
});
if (errorEdges.length > 0) {
const errorMessageKey = `${prevNodeData.id}_error_message`;
const errorTypeKey = `${prevNodeData.id}_error_type`;
if (!addedKeys.has(errorMessageKey)) {
addedKeys.add(errorMessageKey);
variableList.push({
key: errorMessageKey,
label: 'error_message',
type: 'variable',
dataType: 'string',
value: `${prevNodeData.id}.error_message`,
nodeData: prevNodeData,
});
}
if (!addedKeys.has(errorTypeKey)) {
addedKeys.add(errorTypeKey);
variableList.push({
key: errorTypeKey,
label: 'error_type',
type: 'variable',
dataType: 'string',
value: `${prevNodeData.id}.error_type`,
nodeData: prevNodeData,
});
}
}
}
});
}
// Add variables from nodes preceding the parent loop/iteration node
const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id);
allRelevantNodeIds.push(...parentPreviousNodeIds);
}
// Add conversation variables from global config
const conversationVariables = workflowConfig?.variables || [];
conversationVariables.forEach((variable: any) => {
const key = `CONVERSATION_${variable.name}`;
if (!addedKeys.has(key)) {
addedKeys.add(key);
variableList.push({
key,
label: variable.name,
type: 'variable',
dataType: variable.type,
value: `conv.${variable.name}`,
nodeData: { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' },
group: 'CONVERSATION'
});
}
});
allRelevantNodeIds.forEach(nodeId => {
const node = nodes.find(n => n.id === nodeId);
if (!node) return;
@@ -496,7 +420,7 @@ const Properties: FC<PropertiesProps> = ({
key: llmKey,
label: 'output',
type: 'variable',
dataType: 'String',
dataType: 'string',
value: `${dataNodeId}.output`,
nodeData: nodeData,
});
@@ -565,6 +489,17 @@ const Properties: FC<PropertiesProps> = ({
const groupVariables = nodeData.config.group_variables.defaultValue || [];
groupVariables?.forEach((groupVar: any) => {
if (!groupVar || !groupVar.key) return;
// Determine dataType from first variable in the group
let groupDataType = 'string';
if (groupVar.value && Array.isArray(groupVar.value) && groupVar.value.length > 0) {
const firstVariableValue = groupVar.value[0];
const firstVariable = variableList.find(v => `{{${v.value}}}` === firstVariableValue);
if (firstVariable) {
groupDataType = firstVariable.dataType;
}
}
const groupVarKey = `${dataNodeId}_${groupVar.key}`;
if (!addedKeys.has(groupVarKey)) {
addedKeys.add(groupVarKey);
@@ -572,14 +507,26 @@ const Properties: FC<PropertiesProps> = ({
key: groupVarKey,
label: groupVar.key,
type: 'variable',
dataType: 'string',
dataType: groupDataType,
value: `${dataNodeId}.${groupVar.key}`,
nodeData: nodeData,
});
}
});
} else {
// If group=false, add output variable
// If group=false, add output variable with type from first group_variable
const groupVariables = nodeData.config.group_variables.defaultValue || [];
const firstVariable = groupVariables[0];
let outputDataType: string = 'any';
if (firstVariable) {
const filterVo = [...variableList].find(v => {
return `{{${v.value}}}` === firstVariable
})
if (filterVo) {
outputDataType = filterVo?.dataType
}
}
const varAggregatorKey = `${dataNodeId}_output`;
if (!addedKeys.has(varAggregatorKey)) {
addedKeys.add(varAggregatorKey);
@@ -587,7 +534,7 @@ const Properties: FC<PropertiesProps> = ({
key: varAggregatorKey,
label: 'output',
type: 'variable',
dataType: 'string',
dataType: outputDataType,
value: `${dataNodeId}.output`,
nodeData: nodeData,
});
@@ -684,21 +631,20 @@ const Properties: FC<PropertiesProps> = ({
nodeData: nodeData,
});
}
if (!addedKeys.has(outputKey)) {
addedKeys.add(outputKey);
variableList.push({
key: outputKey,
label: 'output',
type: 'variable',
dataType: 'string',
value: `${dataNodeId}.output`,
nodeData: nodeData,
});
}
// if (!addedKeys.has(outputKey)) {
// addedKeys.add(outputKey);
// variableList.push({
// key: outputKey,
// label: 'output',
// type: 'variable',
// dataType: 'string',
// value: `${dataNodeId}.output`,
// nodeData: nodeData,
// });
// }
break
case 'iteration':
const iterationOutputKey = `${dataNodeId}_output`;
const iterationItemKey = `${dataNodeId}_item`;
if (!addedKeys.has(iterationOutputKey)) {
addedKeys.add(iterationOutputKey);
// Get the data type from the output configuration, default to string
@@ -715,22 +661,11 @@ const Properties: FC<PropertiesProps> = ({
key: iterationOutputKey,
label: 'output',
type: 'variable',
dataType: outputDataType,
dataType: `array[${outputDataType}]`,
value: `${dataNodeId}.output`,
nodeData: nodeData,
});
}
if (!addedKeys.has(iterationItemKey)) {
addedKeys.add(iterationItemKey);
variableList.push({
key: iterationItemKey,
label: 'item',
type: 'variable',
dataType: 'string',
value: `${dataNodeId}.item`,
nodeData: nodeData,
});
}
break
case 'loop':
const cycleVars = nodeData.config.cycle_vars.defaultValue || [];
@@ -760,47 +695,337 @@ const Properties: FC<PropertiesProps> = ({
key: toolDataKey,
label: 'data',
type: 'variable',
dataType: 'object',
dataType: 'string',
value: `${dataNodeId}.data`,
nodeData: nodeData,
});
}
break
case 'memory-read':
const memoryReadAnswerKey = `${dataNodeId}_answer`;
const memoryReadIntermediateOutputs = `${dataNodeId}_intermediate_outputs`;
if (!addedKeys.has(memoryReadAnswerKey)) {
addedKeys.add(memoryReadAnswerKey);
variableList.push({
key: memoryReadAnswerKey,
label: 'answer',
type: 'variable',
dataType: 'string',
value: `${dataNodeId}.answer`,
nodeData: nodeData,
});
}
if (!addedKeys.has(memoryReadIntermediateOutputs)) {
addedKeys.add(memoryReadIntermediateOutputs);
variableList.push({
key: memoryReadIntermediateOutputs,
label: 'intermediate_outputs',
type: 'variable',
dataType: 'array[object]',
value: `${dataNodeId}.intermediate_outputs`,
nodeData: nodeData,
});
}
break
}
});
// Add conversation variables from global config
const conversationVariables = workflowConfig?.variables || [];
conversationVariables.forEach((variable: any) => {
const key = `CONVERSATION_${variable.name}`;
if (!addedKeys.has(key)) {
addedKeys.add(key);
variableList.push({
key,
label: variable.name,
type: 'variable',
dataType: variable.type,
value: `conv.${variable.name}`,
nodeData: { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' },
group: 'CONVERSATION'
// Add parent loop/iteration node variables if current node is a child
if (parentLoopNode) {
const parentData = parentLoopNode.getData();
const parentNodeId = parentLoopNode.getData().id;
if (parentData.type === 'loop') {
const cycleVars = parentData.cycle_vars || [];
cycleVars.forEach((cycleVar: any) => {
const key = `${parentNodeId}_cycle_${cycleVar.name}`;
if (!addedKeys.has(key)) {
addedKeys.add(key);
variableList.push({
key,
label: cycleVar.name,
type: 'variable',
dataType: cycleVar.type || 'String',
value: `${parentNodeId}.${cycleVar.name}`,
nodeData: parentData,
});
}
});
} else if (parentData.type === 'iteration') {
// Add item and index variables for iteration parent only if input has value
if (parentData.config.input.defaultValue) {
const itemKey = `${parentNodeId}_item`;
const indexKey = `${parentNodeId}_index`;
// Determine item dataType from input variable
let itemDataType = 'object';
const inputVariable = variableList.find(v => `{{${v.value}}}` === parentData.config.input.defaultValue);
console.log('itemDataType defaultValue', parentData.config.input.defaultValue, variableList, inputVariable)
if (inputVariable && inputVariable.dataType.startsWith('array[')) {
itemDataType = inputVariable.dataType.replace(/^array\[(.+)\]$/, '$1');
console.log('itemDataType', itemDataType)
}
if (!addedKeys.has(itemKey)) {
addedKeys.add(itemKey);
variableList.push({
key: itemKey,
label: 'item',
type: 'variable',
dataType: itemDataType,
value: `${parentNodeId}.item`,
nodeData: parentData,
});
}
if (!addedKeys.has(indexKey)) {
addedKeys.add(indexKey);
variableList.push({
key: indexKey,
label: 'index',
type: 'variable',
dataType: 'number',
value: `${parentNodeId}.index`,
nodeData: parentData,
});
}
}
}
// Check if parent loop/iteration is connected to http-request via ERROR connection
if (parentData.type === 'loop' || parentData.type === 'iteration') {
const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id);
parentPreviousNodeIds.forEach(prevNodeId => {
const prevNode = nodes.find(n => n.id === prevNodeId);
if (!prevNode) return;
const prevNodeData = prevNode.getData();
if (prevNodeData.type === 'http-request') {
// Check if connected via ERROR connection point
const errorEdges = edges.filter(edge => {
return edge.getTargetCellId() === parentLoopNode.id &&
edge.getSourceCellId() === prevNodeId &&
edge.getSourcePortId() === 'ERROR'
});
if (errorEdges.length > 0) {
const errorMessageKey = `${prevNodeData.id}_error_message`;
const errorTypeKey = `${prevNodeData.id}_error_type`;
if (!addedKeys.has(errorMessageKey)) {
addedKeys.add(errorMessageKey);
variableList.push({
key: errorMessageKey,
label: 'error_message',
type: 'variable',
dataType: 'string',
value: `${prevNodeData.id}.error_message`,
nodeData: prevNodeData,
});
}
if (!addedKeys.has(errorTypeKey)) {
addedKeys.add(errorTypeKey);
variableList.push({
key: errorTypeKey,
label: 'error_type',
type: 'variable',
dataType: 'string',
value: `${prevNodeData.id}.error_type`,
nodeData: prevNodeData,
});
}
}
}
});
}
});
}
return variableList;
}, [selectedNode, graphRef, workflowConfig?.variables]);
// Filter out boolean type variables for loop and llm nodes
const getFilteredVariableList = (nodeType?: string) => {
if (nodeType === 'loop' || nodeType === 'llm') {
return variableList.filter(variable => variable.dataType !== 'boolean');
const getFilteredVariableList = (nodeType?: string, key?: string) => {
// Check if current node is a child of iteration node
const parentIterationNode = selectedNode ? (() => {
const nodes = graphRef.current?.getNodes() || [];
const nodeData = selectedNode.getData();
const cycle = nodeData?.cycle;
if (cycle) {
const parentNode = nodes.find(n => n.getData().id === cycle);
if (parentNode) {
const parentData = parentNode.getData();
if (parentData?.type === 'iteration') {
return parentNode;
}
}
}
return null;
})() : null;
// Helper function to add parent iteration variables
const addParentIterationVars = (filteredList: any[]) => {
if (parentIterationNode) {
const parentData = parentIterationNode.getData();
const parentNodeId = parentData.id;
if (parentData.config?.input?.defaultValue) {
const itemKey = `${parentNodeId}_item`;
const indexKey = `${parentNodeId}_index`;
const existingItemVar = filteredList.find(v => v.key === itemKey);
const existingIndexVar = filteredList.find(v => v.key === indexKey);
if (!existingItemVar) {
// Determine item dataType from input variable
let itemDataType = 'object';
const inputVariable = variableList.find(v => `{{${v.value}}}` === parentData.config.input.defaultValue);
if (inputVariable && inputVariable.dataType.startsWith('array[')) {
itemDataType = inputVariable.dataType.replace(/^array\[(.+)\]$/, '$1');
}
filteredList.push({
key: itemKey,
label: 'item',
type: 'variable',
dataType: itemDataType,
value: `${parentNodeId}.item`,
nodeData: parentData,
});
}
if (!existingIndexVar) {
filteredList.push({
key: indexKey,
label: 'index',
type: 'variable',
dataType: 'number',
value: `${parentNodeId}.index`,
nodeData: parentData,
});
}
}
}
return filteredList;
};
if (nodeType === 'llm') {
// For LLM nodes that are children of iteration or loop nodes, include parent variables
const parentLoopNode = selectedNode ? (() => {
const nodes = graphRef.current?.getNodes() || [];
const nodeData = selectedNode.getData();
const cycle = nodeData?.cycle;
if (cycle) {
const parentNode = nodes.find(n => n.getData().id === cycle);
if (parentNode) {
const parentData = parentNode.getData();
if (parentData?.type === 'loop' || parentData?.type === 'iteration') {
return parentNode;
}
}
}
return null;
})() : null;
let filteredList = variableList.filter(variable => variable.dataType !== 'boolean');
// If this LLM node is a child of iteration/loop, ensure parent variables are included
if (parentLoopNode) {
const parentData = parentLoopNode.getData();
const parentNodeId = parentData.id;
// Ensure parent loop/iteration variables are included
if (parentData.type === 'loop') {
const cycleVars = parentData.cycle_vars || [];
cycleVars.forEach((cycleVar: any) => {
const key = `${parentNodeId}_cycle_${cycleVar.name}`;
const existingVar = filteredList.find(v => v.key === key);
if (!existingVar && cycleVar.name && cycleVar.type !== 'boolean') {
filteredList.push({
key,
label: cycleVar.name,
type: 'variable',
dataType: cycleVar.type || 'String',
value: `${parentNodeId}.${cycleVar.name}`,
nodeData: parentData,
});
}
});
} else if (parentData.type === 'iteration') {
// Add item and index variables for iteration parent
if (parentData.config?.input?.defaultValue) {
const itemKey = `${parentNodeId}_item`;
const indexKey = `${parentNodeId}_index`;
const existingItemVar = filteredList.find(v => v.key === itemKey);
const existingIndexVar = filteredList.find(v => v.key === indexKey);
if (!existingItemVar) {
// Determine item dataType from input variable
let itemDataType = 'object';
const inputVariable = variableList.find(v => `{{${v.value}}}` === parentData.config.input.defaultValue);
if (inputVariable && inputVariable.dataType.startsWith('array[')) {
itemDataType = inputVariable.dataType.replace(/^array\[(.+)\]$/, '$1');
}
filteredList.push({
key: itemKey,
label: 'item',
type: 'variable',
dataType: itemDataType,
value: `${parentNodeId}.item`,
nodeData: parentData,
});
}
if (!existingIndexVar) {
filteredList.push({
key: indexKey,
label: 'index',
type: 'variable',
dataType: 'Number',
value: `${parentNodeId}.index`,
nodeData: parentData,
});
}
}
}
}
return filteredList;
}
return variableList;
if (nodeType === 'knowledge-retrieval' || nodeType === 'parameter-extractor' && key !== 'prompt' || nodeType === 'memory-read' || nodeType === 'memory-write' || nodeType === 'question-classifier') {
let filteredList = variableList.filter(variable => variable.dataType === 'string');
return addParentIterationVars(filteredList);
}
if (nodeType === 'parameter-extractor' && key === 'prompt') {
let filteredList = variableList.filter(variable => variable.dataType === 'string' || variable.dataType === 'number');
return addParentIterationVars(filteredList);
}
if (nodeType === 'iteration' && key === 'output') {
return variableList.filter(variable => variable.value.includes('sys.'));
}
if (nodeType === 'iteration') {
return variableList.filter(variable => variable.dataType.includes('array'));
}
if (nodeType === 'loop' && key === 'condition') {
let filteredList = variableList.filter(variable => variable.nodeData.type !== 'loop');
return addParentIterationVars(filteredList);
}
// For all other node types, add parent iteration variables if applicable
let baseList = variableList;
return addParentIterationVars(baseList);
};
// const defaultVariableList = calculateVariableList(selectedNode as Node, graphRef, workflowConfig )
console.log('values', values)
console.log('variableList', variableList, selectedNode?.data)
// console.log('variableList', variableList, defaultVariableList)
return (
<div className="rb:w-75 rb:fixed rb:right-0 rb:top-16 rb:bottom-0 rb:p-3">
@@ -901,11 +1126,10 @@ const Properties: FC<PropertiesProps> = ({
});
}
}
return (
<Form.Item key={key} name={key}>
<MessageEditor
key={key}
key={key}
options={contextVariableList.filter(variable => variable.nodeData?.type !== 'knowledge-retrieval')}
parentName={key}
/>
@@ -915,7 +1139,12 @@ const Properties: FC<PropertiesProps> = ({
if (selectedNode?.data?.type === 'end' && key === 'output') {
return (
<Form.Item key={key} name={key}>
<MessageEditor key={key} isArray={false} parentName={key} options={variableList} />
<MessageEditor
key={key}
isArray={false}
parentName={key}
options={variableList.filter(variable => variable.nodeData?.type !== 'knowledge-retrieval')}
/>
</Form.Item>
)
}
@@ -943,7 +1172,7 @@ const Properties: FC<PropertiesProps> = ({
isArray={!!config.isArray}
parentName={key}
enableJinja2={config.enableJinja2 as boolean}
options={getFilteredVariableList(selectedNode?.data?.type)}
options={getFilteredVariableList(selectedNode?.data?.type, key)}
/>
</Form.Item>
)
@@ -964,7 +1193,7 @@ const Properties: FC<PropertiesProps> = ({
<Form.Item key={key} name={key}>
<GroupVariableList
name={key}
options={getFilteredVariableList(selectedNode?.data?.type)}
options={getFilteredVariableList(selectedNode?.data?.type, key)}
isCanAdd={!!(values as any)?.group}
/>
</Form.Item>
@@ -976,7 +1205,7 @@ const Properties: FC<PropertiesProps> = ({
<Form.Item key={key} name={key}>
<CaseList
name={key}
options={getFilteredVariableList(selectedNode?.data?.type)}
options={getFilteredVariableList(selectedNode?.data?.type, key)}
selectedNode={selectedNode}
graphRef={graphRef}
/>
@@ -989,7 +1218,7 @@ const Properties: FC<PropertiesProps> = ({
<Form.Item key={key} name={key}
label={t(`workflow.config.${selectedNode?.data?.type}.${key}`)}
>
<MappingList name={key} options={getFilteredVariableList(selectedNode?.data?.type)} />
<MappingList name={key} options={getFilteredVariableList(selectedNode?.data?.type, key)} />
</Form.Item>
)
@@ -999,7 +1228,7 @@ const Properties: FC<PropertiesProps> = ({
<Form.Item key={key} name={key}>
<CycleVarsList
parentName={key}
options={getFilteredVariableList(selectedNode?.data?.type)}
options={getFilteredVariableList(selectedNode?.data?.type, key)}
/>
</Form.Item>
)
@@ -1013,9 +1242,9 @@ const Properties: FC<PropertiesProps> = ({
if (config.filterLoopIterationVars) {
const loopIterationVars: Suggestion[] = [];
return [...getFilteredVariableList(selectedNode?.data?.type), ...loopIterationVars];
return [...getFilteredVariableList(selectedNode?.data?.type, key), ...loopIterationVars];
}
return getFilteredVariableList(selectedNode?.data?.type);
return getFilteredVariableList(selectedNode?.data?.type, key);
})()
}
/>
@@ -1060,7 +1289,7 @@ const Properties: FC<PropertiesProps> = ({
? <VariableSelect
placeholder={t('common.pleaseSelect')}
options={(() => {
const baseVariableList = getFilteredVariableList(selectedNode?.data?.type);
const baseVariableList = getFilteredVariableList(selectedNode?.data?.type, key);
// Apply filtering if specified in config
if (config.filterNodeTypes || config.filterVariableNames) {
return baseVariableList.filter(variable => {
@@ -1068,7 +1297,7 @@ const Properties: FC<PropertiesProps> = ({
(Array.isArray(config.filterNodeTypes) && config.filterNodeTypes.includes(variable.nodeData?.type));
const variableNameMatch = !config.filterVariableNames ||
(Array.isArray(config.filterVariableNames) && config.filterVariableNames.includes(variable.label));
return nodeTypeMatch && variableNameMatch;
return nodeTypeMatch || variableNameMatch;
});
}
// Filter child nodes for iteration output
@@ -1085,7 +1314,7 @@ const Properties: FC<PropertiesProps> = ({
});
return baseVariableList.filter(variable =>
childNodes.some(node => node.id === variable.nodeData?.id)
childNodes.some(node => node.id === variable.nodeData?.id) || selectedNode?.data?.type === 'iteration' && key === 'output' && variable.value.includes('sys.')
);
}
return baseVariableList;
@@ -1095,7 +1324,12 @@ const Properties: FC<PropertiesProps> = ({
: config.type === 'switch'
? <Switch onChange={key === 'group' ? () => { form.setFieldValue('group_variables', []) } : undefined} />
: config.type === 'categoryList'
? <CategoryList parentName={key} selectedNode={selectedNode} graphRef={graphRef} />
? <CategoryList
parentName={key}
selectedNode={selectedNode}
graphRef={graphRef}
options={getFilteredVariableList(selectedNode?.data?.type, key)}
/>
: config.type === 'conditionList'
? <ConditionList
parentName={key}
@@ -1109,18 +1343,9 @@ const Properties: FC<PropertiesProps> = ({
value: `${selectedNode.getData().id}.${cycleVar.name}`,
nodeData: selectedNode.getData(),
}));
return [...variableList.filter(variable => {
// Keep conversation variables
if (variable.group === 'CONVERSATION') return true;
// Keep sys variables from start nodes
if (variable.nodeData?.type === 'start' && variable.value?.startsWith('sys.')) return true;
// Keep variables from non-start nodes
if (variable.nodeData?.type !== 'start' && variable.nodeData?.type !== 'http-request' && variable.dataType !== 'boolean') return true;
// Filter out custom variables from start nodes
return false;
}), ...cycleVarSuggestions];
})()
}
return [...getFilteredVariableList(selectedNode?.data?.type, key), ...cycleVarSuggestions];
})()}
selectedNode={selectedNode}
graphRef={graphRef}
addBtnText={t('workflow.config.addCase')}

View File

@@ -270,7 +270,7 @@ export const nodeLibrary: NodeLibrary[] = [
config: {
input: {
type: 'variableList',
filterNodeTypes: ['knowledge-retrieval'],
filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop'],
filterVariableNames: ['message']
},
parallel: {
@@ -334,8 +334,7 @@ export const nodeLibrary: NodeLibrary[] = [
}
}
},
{
type: "assigner", icon: assignerIcon,
{ type: "assigner", icon: assignerIcon,
config: {
assignments: {
type: 'assignmentList',
@@ -656,4 +655,114 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
items: [{ group: 'left' }],
},
}
}
export interface OutputVariable {
default?: Array<{
name: string;
type: string;
}>;
define?: string[];
sys?: Array<{
name: string;
type: string;
}>;
error?: Array<{
name: string;
type: string;
}>;
}
export const outputVariable: { [key: string]: OutputVariable } = {
start: {
sys: [
{ name: "message", type: "string" },
{ name: "conversation_id", type: "string" },
{ name: "execution_id", type: "string", },
{ name: "workspace_id", type: "string" },
{ name: "user_id", type: "string" },
],
define: ['variables']
},
end: {
},
llm: {
default: [
{ name: "output", type: "string" },
]
},
'knowledge-retrieval': {
default: [
{ name: "output", type: "array[object]" },
]
},
'parameter-extractor': {
default: [
{ name: "__is_success", type: "number" },
{ name: "__reason", type: "string" },
],
define: ['params']
},
'memory-read': {
default: [
{ name: "answer", type: "string" },
{ name: "intermediate_outputs", type: "array[object]" },
],
},
'memory-write': {
},
'if-else': {
},
'question-classifier': {
default: [
{ name: "class_name", type: "string" },
// { name: "output", type: "string" },
],
},
'iteration': {
default: [
// { name: "item", type: "string" }, // 仅内部使用
{ name: "output", type: "array[string]" },
],
},
'loop': {
define: ['cycle_vars']
},
'cycle-start': {
},
'break': {
},
'var-aggregator': {
// default: [
// { name: "output", type: "string" },
// ],
define: ['group_variables']
},
'assigner': {
},
'http-request': {
default: [
{ name: "body", type: "string" },
{ name: "status_code", type: "number" },
],
error: [
{ name: "error_message", type: "string" },
{ name: "error_type", type: "string" },
]
},
'tool': {
default: [
{ name: "data", type: "string" },
],
},
'jinja-render': {
default: [
{ name: "output", type: "string" },
],
},
}