Merge branch 'develop' into feature/knowledgeBase_yjp

This commit is contained in:
yujiangping
2026-01-14 16:43:29 +08:00
17 changed files with 268 additions and 203 deletions

View File

@@ -583,7 +583,7 @@ async def chat(
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data, default=str, ensure_ascii=False)}\n\n"
yield sse_message
return StreamingResponse(

View File

@@ -425,15 +425,9 @@ async def Input_Summary(
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
search_service = get_context_resource(ctx, "search_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
sessionid = sessionid.replace('call_id_', '')
@@ -539,31 +533,11 @@ async def Input_Summary(
)
retrieve_info, question, raw_results = "", query, []
# Return retrieved information directly without LLM processing
# Use the raw retrieved info as the answer
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='input_summary',
query=query,
history=history,
retrieve_info=retrieve_info
)
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
aimessages = structured.data.query_answer or "信息不足,无法回答"
except Exception as e:
logger.error(
f"Input_Summary: response_structured failed, using default answer: {e}",
exc_info=True
)
aimessages = "信息不足,无法回答"
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
# Emit intermediate output for frontend
return {

View File

@@ -10,9 +10,6 @@ from app.core.logging_config import get_business_logger
logger = get_business_logger()
# 为了兼容性,创建别名
# SchemaParser = OpenAPISchemaParser = None
class OpenAPISchemaParser:
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
@@ -213,7 +210,9 @@ class OpenAPISchemaParser:
if not isinstance(operation, dict):
continue
summary = operation.get("summary", "")
# 生成操作ID
operation_id = operation.get("operationId")
if not operation_id:
@@ -223,7 +222,7 @@ class OpenAPISchemaParser:
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": operation.get("summary", ""),
"summary": summary if summary else operation_id,
"description": operation.get("description", ""),
"parameters": self._extract_parameters(operation),
"request_body": self._extract_request_body(operation),

View File

@@ -226,6 +226,7 @@ class LLMNode(BaseNode):
Yields:
文本片段chunk或完成标记
"""
self.typed_config = LLMNodeConfig(**self.config)
from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True)

View File

@@ -1,9 +1,51 @@
"""
情景记忆的请求和响应模型
"""
from abc import ABC
from pydantic import BaseModel, Field
from typing import Optional
type_mapping = {
"Person": "人物实体节点",
"Organization": "组织实体节点",
"ORG": "组织实体节点",
"Location": "地点实体节点",
"LOC": "地点实体节点",
"Event": "事件实体节点",
"Concept": "概念实体节点",
"Time": "时间实体节点",
"Position": "职位实体节点",
"WorkRole": "职业实体节点",
"System": "系统实体节点",
"Policy": "政策实体节点",
"HistoricalPeriod": "历史时期实体节点",
"HistoricalState": "历史国家实体节点",
"HistoricalEvent": "历史事件实体节点",
"EconomicFactor": "经济因素实体节点",
"Condition": "条件实体节点",
"Numeric": "数值实体节点"
}
class EmotionType(ABC):
JOY_TYPE = "joy"
SURPRISE_TYPE = "surprise"
SANDROWNESS_TYPE = "sadness"
FEAR_TYPE = "fear"
ANGET_TYPE="anger"
NEUTRAL_TYPE="neutral"
EMOTION_MAPPING={
"joy":"愉快",
"surprise":"惊喜",
"sadness":"悲伤",
"fear":"恐惧",
"anger":"生气",
"neutral":"中性"
}
class EmotionSubject(ABC):
SUBJECT_MAPPING={
"self":"自己",
"other":"别人",
"object":"事物对象"
}
class EpisodicMemoryOverviewRequest(BaseModel):
"""情景记忆总览查询请求"""

View File

@@ -15,6 +15,8 @@ from neo4j.time import DateTime as Neo4jDateTime
import json
from datetime import datetime
from app.schemas.memory_episodic_schema import EmotionType
logger = logging.getLogger(__name__)
class MemoryEntityService:
@@ -123,7 +125,7 @@ class MemoryEntityService:
extracted_entity_list = self._deduplicate_dict_list(extracted_entity_list)
# 合并所有数据并处理相同text的合并
all_timeline_data = memory_summary_list + statement_list + extracted_entity_list
all_timeline_data = memory_summary_list + statement_list
all_timeline_data = self._merge_same_text_items(all_timeline_data)
result = {
@@ -496,11 +498,11 @@ class MemoryEmotion:
length_data.append(emotion_intensity)
if emotion_type is not None and emotion_intensity is not None and formatted_created_at is not None:
# 使用(emotion_type, created_at)作为分组键
if emotion_type in {"joy", "surprise"}:
if emotion_type in {EmotionType.JOY_TYPE, EmotionType.SURPRISE_TYPE}:
emotion_type='positive'
elif emotion_type in {"sadness", "fear", "anger"}:
elif emotion_type in {EmotionType.SANDROWNESS_TYPE, EmotionType.FEAR_TYPE, EmotionType.ANGET_TYPE}:
emotion_type='negative'
elif emotion_type=='neutral':
elif emotion_type==EmotionType.NEUTRAL_TYPE:
emotion_type='neutral'
group_key = (emotion_type, formatted_created_at)
# 累加emotion_intensity

View File

@@ -13,10 +13,14 @@ from typing import Any, Dict, List, Optional, Tuple
from app.core.logging_config import get_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.conversation_repository import ConversationRepository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.memory_base_service import MemoryBaseService
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
@@ -1196,18 +1200,17 @@ async def analytics_memory_types(
end_user_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
统计9种记忆类型的数量和百分比
统计8种记忆类型的数量和百分比
计算规则:
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity
2. 工作记忆 (WORKING_MEMORY) = chunk + entity
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk
4. 长期记忆 (LONG_TERM_MEMORY) = entity
5. 性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity
7. 情记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
1. 感知记忆 (PERCEPTUAL_MEMORY) = 通过 MemoryPerceptualService.get_memory_count 获取的 total_count
2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取)
3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量
4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
5. 性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一
6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
7. 情记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
Args:
db: 数据库会话
@@ -1227,7 +1230,6 @@ async def analytics_memory_types(
- PERCEPTUAL_MEMORY: 感知记忆
- WORKING_MEMORY: 工作记忆
- SHORT_TERM_MEMORY: 短期记忆
- LONG_TERM_MEMORY: 长期记忆
- EXPLICIT_MEMORY: 显性记忆
- IMPLICIT_MEMORY: 隐性记忆
- EMOTIONAL_MEMORY: 情绪记忆
@@ -1237,40 +1239,78 @@ async def analytics_memory_types(
# 初始化基础服务
base_service = MemoryBaseService()
# 定义需要查询的基础节点类型
node_types = {
"Statement": "Statement",
"Entity": "ExtractedEntity",
"Chunk": "Chunk"
}
# 初始化感知记忆服务
perceptual_service = MemoryPerceptualService(db)
# 存储每种节点类型的计数
node_counts = {}
# 获取感知记忆数量
if end_user_id:
perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id))
perceptual_count = perceptual_stats.get("total", 0)
else:
perceptual_count = 0
# 查询每种节点类型的数量
for key, node_type in node_types.items():
if end_user_id:
query = f"""
MATCH (n:{node_type})
# 获取工作记忆数量(基于会话数量
work_count = 0
if end_user_id:
try:
conversation_repo = ConversationRepository(db)
conversations = conversation_repo.get_conversation_by_user_id(
user_id=uuid.UUID(end_user_id),
limit=100, # 获取更多会话以准确统计
is_activate=True
)
work_count = len(conversations)
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取会话数量失败工作记忆数量设为0: {str(e)}")
work_count = 0
# 获取隐性记忆数量(基于 Statement 节点数量的三分之一)
implicit_count = 0
if end_user_id:
try:
# 查询 Statement 节点数量
query = """
MATCH (n:Statement)
WHERE n.group_id = $group_id
RETURN count(n) as count
"""
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
else:
query = f"""
MATCH (n:{node_type})
RETURN count(n) as count
"""
result = await _neo4j_connector.execute_query(query)
# 提取计数结果
count = result[0]["count"] if result and len(result) > 0 else 0
node_counts[key] = count
statement_count = result[0]["count"] if result and len(result) > 0 else 0
# 取三分之一作为隐性记忆数量
implicit_count = round(statement_count / 3)
logger.debug(f"隐性记忆数量Statement数量的1/3: {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取Statement数量失败隐性记忆数量设为0: {str(e)}")
implicit_count = 0
# 获取各节点类型的数量
statement_count = node_counts.get("Statement", 0)
entity_count = node_counts.get("Entity", 0)
chunk_count = node_counts.get("Chunk", 0)
# 原有的基于行为习惯的统计方式(已注释)
# implicit_count = 0
# if end_user_id:
# try:
# implicit_service = ImplicitMemoryService(db, end_user_id)
# behavior_habits = await implicit_service.get_behavior_habits(
# user_id=end_user_id
# )
# implicit_count = len(behavior_habits)
# logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})")
# except Exception as e:
# logger.warning(f"获取行为习惯数量失败隐性记忆数量设为0: {str(e)}")
# implicit_count = 0
# 获取短期记忆数量(基于 /short_term 接口返回的问答对数量)
short_term_count = 0
if end_user_id:
try:
short_term_service = ShortService(end_user_id)
short_term_data = short_term_service.get_short_databasets()
# 统计 short_term 数组的长度
if short_term_data:
short_term_count = len(short_term_data)
logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取短期记忆数量失败短期记忆数量设为0: {str(e)}")
short_term_count = 0
# 获取用户的遗忘阈值配置
forgetting_threshold = 0.3 # 默认值
@@ -1296,17 +1336,16 @@ async def analytics_memory_types(
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
episodic_count = await base_service.get_episodic_memory_count(end_user_id)
explicit_count = await base_service.get_explicit_memory_count(end_user_id)
emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count)
emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count)
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
# 按规则计算9种记忆类型的数量使用英文枚举作为key
# 按规则计算8种记忆类型的数量使用英文枚举作为key
memory_counts = {
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆
"LONG_TERM_MEMORY": entity_count, # 长期记忆
"PERCEPTUAL_MEMORY": perceptual_count, # 感知记忆
"WORKING_MEMORY": work_count, # 工作记忆(基于会话数量)
"SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量)
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity)
"IMPLICIT_MEMORY": implicit_count, # 隐性记忆Statement数量的1/3
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
"EPISODIC_MEMORY": episodic_count, # 情景记忆
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
@@ -1332,7 +1371,7 @@ async def analytics_graph_data(
db: Session,
end_user_id: str,
node_types: Optional[List[str]] = None,
limit: int = 100,
limit: int = 130,
depth: int = 1,
center_node_id: Optional[str] = None
) -> Dict[str, Any]:
@@ -1423,6 +1462,7 @@ async def analytics_graph_data(
"limit": limit
}
# 执行节点查询
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
@@ -1567,9 +1607,9 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
allowed_fields = field_whitelist.get(label, [])
# 如果没有定义白名单,返回空字典(或者可以返回所有字段)
if not allowed_fields:
# 对于未定义的节点类型,只返回基本字段
allowed_fields = ["name", "created_at", "caption"]
# if not allowed_fields:
# # 对于未定义的节点类型,只返回基本字段
# allowed_fields = ["name", "created_at", "caption"]
count_neo4j=f"""MATCH (n)-[r]-(m) WHERE elementId(n) ="{node_id}" RETURN count(r) AS rel_count;"""
node_results = await (_neo4j_connector.execute_query(count_neo4j))
# 提取白名单中的字段
@@ -1577,10 +1617,15 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
for field in allowed_fields:
if field in properties:
value = properties[field]
if str(field) == 'entity_type':
value=type_mapping.get(value,'')
if str(field)=="emotion_type":
value=EmotionType.EMOTION_MAPPING.get(value)
if str(field)=="emotion_subject":
value=EmotionSubject.SUBJECT_MAPPING.get(value)
# 清理 Neo4j 特殊类型
filtered_props[field] = _clean_neo4j_value(value)
filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0]
print(filtered_props)
return filtered_props
@@ -1621,6 +1666,5 @@ def _clean_neo4j_value(value: Any) -> Any:
return str(value)
except Exception:
return None
# 返回原始值
return value
return value

View File

@@ -181,10 +181,15 @@ const Conversation: FC = () => {
currentConversationId = newId
break
case 'message':
const { content } = item.data as { content: string }
updateAssistantMessage(content)
const { content, chunk, conversation_id: curId } = item.data as { content: string; chunk: string; conversation_id: string; }
updateAssistantMessage(content ?? chunk)
if (curId) {
currentConversationId = curId;
}
break
case 'end':
case 'workflow_end':
setLoading(false)
if (currentConversationId && currentConversationId !== conversation_id) {
setConversationId(currentConversationId)

View File

@@ -103,7 +103,11 @@ const ChatVariableModal = forwardRef<ChatVariableModalRef, ChatVariableModalProp
label={t('workflow.config.parameter-extractor.default')}
>
{type === 'number'
? <InputNumber placeholder={t('common.enter')} style={{ width: '100%' }} />
? <InputNumber
placeholder={t('common.enter')}
style={{ width: '100%' }}
onChange={(value) => form.setFieldValue('defaultValue', value)}
/>
: type === 'boolean'
? <Select
placeholder={t('common.pleaseSelect')}

View File

@@ -54,6 +54,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
const handleClose = () => {
setOpen(false)
setChatList([])
setVariables([])
}
const handleEditVariables = () => {
variableConfigModalRef.current?.handleOpen(variables)

View File

@@ -80,7 +80,7 @@ const VariableConfigModal = forwardRef<VariableConfigModalRef, VariableEditModal
field.type === 'string' && <Input placeholder={t('common.pleaseEnter')} />
}
{
field.type === 'number' && <InputNumber placeholder={t('common.pleaseEnter')} style={{ width: '100%' }} />
field.type === 'number' && <InputNumber placeholder={t('common.pleaseEnter')} style={{ width: '100%' }} onChange={(value) => form.setFieldValue(['variables', name, 'value'], value)} />
}
{
field.type === 'boolean' && <Checkbox>{`${field.name}·${field.description}`}</Checkbox>

View File

@@ -46,7 +46,8 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
graph.addEdge({
source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() },
target: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'left')?.id || 'left' },
attrs: edge.getAttrs()
attrs: edge.getAttrs(),
zIndex: 3
});
});

View File

@@ -85,6 +85,7 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
},
},
},
zIndex: 0
});
// 循环节点内子节点通过连接桩添加时,调整循环节点大小

View File

@@ -60,7 +60,7 @@ const AssignmentList: FC<AssignmentListProps> = ({
>
<VariableSelect
placeholder={t('common.pleaseSelect')}
options={options.filter(vo => vo.nodeData.type === 'loop' || vo.value.includes('conv.'))}
options={options.filter(vo => vo.nodeData.type === 'loop' || vo.value.includes('conv.') || (vo.nodeData.type === 'iteration' && (vo.label === 'item' || vo.label === 'index')))}
popupMatchSelectWidth={false}
onChange={() => {
form.setFieldValue([parentName, name, 'operation'], undefined);

View File

@@ -40,8 +40,6 @@ const operatorsObj: { [key: string]: SelectProps['options'] } = {
boolean: [
{ value: 'eq', label: 'workflow.config.if-else.boolean.eq' },
{ value: 'ne', label: 'workflow.config.if-else.boolean.ne' },
{ value: 'empty', label: 'workflow.config.if-else.empty' },
{ value: 'not_empty', label: 'workflow.config.if-else.not_empty' },
]
}
@@ -85,11 +83,14 @@ const CaseList: FC<CaseListProps> = ({
selectedNode.prop('size', { width: 240, height: newHeight })
// 计算端口间距
const dy = totalPorts;
// 添加 IF 端口
selectedNode.addPort({
id: 'CASE1',
group: 'right',
args: { dy: 24 },
// args: { dy },
attrs: { text: { text: 'IF', fontSize: 12, fill: '#5B6167' }}
});
@@ -98,6 +99,7 @@ const CaseList: FC<CaseListProps> = ({
selectedNode.addPort({
id: `CASE${i + 1}`,
group: 'right',
// args: { dy },
attrs: { text: { text: 'ELIF', fontSize: 12, fill: '#5B6167' }}
});
}
@@ -106,11 +108,22 @@ const CaseList: FC<CaseListProps> = ({
selectedNode.addPort({
id: `CASE${caseCount + 1}`,
group: 'right',
// args: { dy },
attrs: { text: { text: 'ELSE', fontSize: 12, fill: '#5B6167' }}
});
// 恢复仍然存在的端口连线
setTimeout(() => {
// 计算删除前的总端口数来确定原ELSE端口编号
const originalCaseCount = removedCaseIndex !== undefined ? caseCount + 1 : caseCount;
const originalElsePortNumber = originalCaseCount + 1;
// 检查ELSE端口是否有连线
const elseHasConnection = edgeConnections.some(({ sourcePortId, isIncoming }: any) => {
const caseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
return !isIncoming && caseNumber === originalElsePortNumber;
});
edgeConnections.forEach(({ edge, sourcePortId, targetCellId, targetPortId, sourceCellId, isIncoming }: any) => {
// 如果是进入连线(左侧端口),直接恢复
if (isIncoming) {
@@ -138,7 +151,7 @@ const CaseList: FC<CaseListProps> = ({
// 处理右侧端口连线
const originalCaseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
// 如果是被删除的端口,不重新创建连线
// 如果是被删除的端口,只删除该端口的连线
if (removedCaseIndex !== undefined && originalCaseNumber === removedCaseIndex + 1) {
graphRef.current?.removeCell(edge);
return;
@@ -146,15 +159,8 @@ const CaseList: FC<CaseListProps> = ({
let newPortId = sourcePortId;
// 如果是原来的ELSE端口重新映射到新的ELSE端口
const maxOriginalCaseNumber = Math.max(...edgeConnections
.filter(({ isIncoming }: any) => !isIncoming)
.map(({ sourcePortId }: any) => {
const match = sourcePortId.match(/CASE(\d+)/);
return match ? parseInt(match[1]) : 0;
}));
if (originalCaseNumber === maxOriginalCaseNumber) {
// 如果是原来的ELSE端口且有连线重新映射到新的ELSE端口
if (originalCaseNumber === originalElsePortNumber && elseHasConnection) {
newPortId = `CASE${caseCount + 1}`; // 新的ELSE端口
} else if (removedCaseIndex !== undefined && originalCaseNumber > removedCaseIndex + 1) {
// 如果是被删除端口之后的端口,编号向前移动

View File

@@ -53,6 +53,7 @@ const operatorsObj: { [key: string]: SelectProps['options'] } = {
const ConditionList: FC<CaseListProps> = ({
options,
parentName,
selectedNode,
}) => {
const { t } = useTranslation();
const form = Form.useFormInstance();
@@ -114,7 +115,12 @@ const ConditionList: FC<CaseListProps> = ({
<Col span={14}>
<Form.Item name={[field.name, 'left']} noStyle>
<VariableSelect
options={options.filter(vo => vo.value.includes('sys.') || vo.value.includes('conv.') || vo.nodeData.type === 'loop')}
options={options.filter(vo =>
vo.value.includes('sys.') ||
vo.value.includes('conv.') ||
vo.nodeData.type === 'loop' ||
(vo.nodeData.cycle && vo.nodeData.cycle === selectedNode?.id)
)}
size="small"
allowClear={false}
popupMatchSelectWidth={false}

View File

@@ -566,42 +566,6 @@ const Properties: FC<PropertiesProps> = ({
nodeData: nodeData,
});
}
// Check if connected via ERROR connection point
const errorEdges = edges.filter(edge =>
edge.getTargetCellId() === selectedNode.id &&
edge.getSourceCellId() === nodeId &&
edge.getSourcePortId() === 'ERROR'
);
if (errorEdges.length > 0) {
const errorMessageKey = `${dataNodeId}_error_message`;
const errorTypeKey = `${dataNodeId}_error_type`;
if (!addedKeys.has(errorMessageKey)) {
addedKeys.add(errorMessageKey);
variableList.push({
key: errorMessageKey,
label: 'error_message',
type: 'variable',
dataType: 'string',
value: `${dataNodeId}.error_message`,
nodeData: nodeData,
});
}
if (!addedKeys.has(errorTypeKey)) {
addedKeys.add(errorTypeKey);
variableList.push({
key: errorTypeKey,
label: 'error_type',
type: 'variable',
dataType: 'string',
value: `${dataNodeId}.error_type`,
nodeData: nodeData,
});
}
}
break
case 'jinja-render':
const jinjaOutputKey = `${dataNodeId}_output`;
@@ -793,54 +757,6 @@ const Properties: FC<PropertiesProps> = ({
}
}
}
// Check if parent loop/iteration is connected to http-request via ERROR connection
if (parentData.type === 'loop' || parentData.type === 'iteration') {
const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id);
parentPreviousNodeIds.forEach(prevNodeId => {
const prevNode = nodes.find(n => n.id === prevNodeId);
if (!prevNode) return;
const prevNodeData = prevNode.getData();
if (prevNodeData.type === 'http-request') {
// Check if connected via ERROR connection point
const errorEdges = edges.filter(edge => {
return edge.getTargetCellId() === parentLoopNode.id &&
edge.getSourceCellId() === prevNodeId &&
edge.getSourcePortId() === 'ERROR'
});
if (errorEdges.length > 0) {
const errorMessageKey = `${prevNodeData.id}_error_message`;
const errorTypeKey = `${prevNodeData.id}_error_type`;
if (!addedKeys.has(errorMessageKey)) {
addedKeys.add(errorMessageKey);
variableList.push({
key: errorMessageKey,
label: 'error_message',
type: 'variable',
dataType: 'string',
value: `${prevNodeData.id}.error_message`,
nodeData: prevNodeData,
});
}
if (!addedKeys.has(errorTypeKey)) {
addedKeys.add(errorTypeKey);
variableList.push({
key: errorTypeKey,
label: 'error_type',
type: 'variable',
dataType: 'string',
value: `${prevNodeData.id}.error_type`,
nodeData: prevNodeData,
});
}
}
}
});
}
}
return variableList;
@@ -999,12 +915,12 @@ const Properties: FC<PropertiesProps> = ({
return filteredList;
}
if (nodeType === 'knowledge-retrieval' || nodeType === 'parameter-extractor' && key !== 'prompt' || nodeType === 'memory-read' || nodeType === 'memory-write' || nodeType === 'question-classifier') {
let filteredList = variableList.filter(variable => variable.dataType === 'string');
return addParentIterationVars(filteredList);
let filteredList = addParentIterationVars(variableList).filter(variable => variable.dataType === 'string');
return filteredList;
}
if (nodeType === 'parameter-extractor' && key === 'prompt') {
let filteredList = variableList.filter(variable => variable.dataType === 'string' || variable.dataType === 'number');
return addParentIterationVars(filteredList);
let filteredList = addParentIterationVars(variableList).filter(variable => variable.dataType === 'string' || variable.dataType === 'number');
return filteredList;
}
if (nodeType === 'iteration' && key === 'output') {
return variableList.filter(variable => variable.value.includes('sys.'));
@@ -1013,8 +929,71 @@ const Properties: FC<PropertiesProps> = ({
return variableList.filter(variable => variable.dataType.includes('array'));
}
if (nodeType === 'loop' && key === 'condition') {
let filteredList = variableList.filter(variable => variable.nodeData.type !== 'loop');
return addParentIterationVars(filteredList);
let filteredList = addParentIterationVars(variableList).filter(variable => variable.nodeData.type !== 'loop');
// Add child node output variables for loop nodes
if (selectedNode) {
const graph = graphRef.current;
if (graph) {
const nodes = graph.getNodes();
const childNodes = nodes.filter(node => {
const nodeData = node.getData();
return nodeData?.cycle === selectedNode.id;
});
// Add output variables from child nodes
childNodes.forEach(childNode => {
const childData = childNode.getData();
const childNodeId = childData.id;
// Add child node output variables based on their type
switch(childData.type) {
case 'llm':
case 'jinja-render':
case 'tool':
const outputKey = `${childNodeId}_output`;
const existingOutput = filteredList.find(v => v.key === outputKey);
if (!existingOutput) {
filteredList.push({
key: outputKey,
label: 'output',
type: 'variable',
dataType: 'string',
value: `${childNodeId}.output`,
nodeData: childData,
});
}
break;
case 'http-request':
const bodyKey = `${childNodeId}_body`;
const statusKey = `${childNodeId}_status_code`;
if (!filteredList.find(v => v.key === bodyKey)) {
filteredList.push({
key: bodyKey,
label: 'body',
type: 'variable',
dataType: 'string',
value: `${childNodeId}.body`,
nodeData: childData,
});
}
if (!filteredList.find(v => v.key === statusKey)) {
filteredList.push({
key: statusKey,
label: 'status_code',
type: 'variable',
dataType: 'number',
value: `${childNodeId}.status_code`,
nodeData: childData,
});
}
break;
}
});
}
}
return filteredList;
}
// For all other node types, add parent iteration variables if applicable
@@ -1025,7 +1004,7 @@ const Properties: FC<PropertiesProps> = ({
// const defaultVariableList = calculateVariableList(selectedNode as Node, graphRef, workflowConfig )
console.log('values', values)
// console.log('variableList', variableList, defaultVariableList)
console.log('variableList', variableList)
return (
<div className="rb:w-75 rb:fixed rb:right-0 rb:top-16 rb:bottom-0 rb:p-3">