Merge branch 'develop' into feature/knowledgeBase_yjp
This commit is contained in:
@@ -583,7 +583,7 @@ async def chat(
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data, default=str, ensure_ascii=False)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
return StreamingResponse(
|
||||
|
||||
@@ -425,15 +425,9 @@ async def Input_Summary(
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
search_service = get_context_resource(ctx, "search_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
sessionid = sessionid.replace('call_id_', '')
|
||||
@@ -539,31 +533,11 @@ async def Input_Summary(
|
||||
)
|
||||
retrieve_info, question, raw_results = "", query, []
|
||||
|
||||
# Return retrieved information directly without LLM processing
|
||||
# Use the raw retrieved info as the answer
|
||||
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
|
||||
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='input_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
aimessages = structured.data.query_answer or "信息不足,无法回答"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: response_structured failed, using default answer: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
|
||||
@@ -10,9 +10,6 @@ from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
# 为了兼容性,创建别名
|
||||
# SchemaParser = OpenAPISchemaParser = None
|
||||
|
||||
|
||||
class OpenAPISchemaParser:
|
||||
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
||||
@@ -213,7 +210,9 @@ class OpenAPISchemaParser:
|
||||
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
|
||||
summary = operation.get("summary", "")
|
||||
|
||||
# 生成操作ID
|
||||
operation_id = operation.get("operationId")
|
||||
if not operation_id:
|
||||
@@ -223,7 +222,7 @@ class OpenAPISchemaParser:
|
||||
operations[operation_id] = {
|
||||
"method": method.upper(),
|
||||
"path": path,
|
||||
"summary": operation.get("summary", ""),
|
||||
"summary": summary if summary else operation_id,
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": self._extract_parameters(operation),
|
||||
"request_body": self._extract_request_body(operation),
|
||||
|
||||
@@ -226,6 +226,7 @@ class LLMNode(BaseNode):
|
||||
Yields:
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
@@ -1,9 +1,51 @@
|
||||
"""
|
||||
情景记忆的请求和响应模型
|
||||
"""
|
||||
from abc import ABC
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
type_mapping = {
|
||||
"Person": "人物实体节点",
|
||||
"Organization": "组织实体节点",
|
||||
"ORG": "组织实体节点",
|
||||
"Location": "地点实体节点",
|
||||
"LOC": "地点实体节点",
|
||||
"Event": "事件实体节点",
|
||||
"Concept": "概念实体节点",
|
||||
"Time": "时间实体节点",
|
||||
"Position": "职位实体节点",
|
||||
"WorkRole": "职业实体节点",
|
||||
"System": "系统实体节点",
|
||||
"Policy": "政策实体节点",
|
||||
"HistoricalPeriod": "历史时期实体节点",
|
||||
"HistoricalState": "历史国家实体节点",
|
||||
"HistoricalEvent": "历史事件实体节点",
|
||||
"EconomicFactor": "经济因素实体节点",
|
||||
"Condition": "条件实体节点",
|
||||
"Numeric": "数值实体节点"
|
||||
}
|
||||
class EmotionType(ABC):
|
||||
JOY_TYPE = "joy"
|
||||
SURPRISE_TYPE = "surprise"
|
||||
SANDROWNESS_TYPE = "sadness"
|
||||
FEAR_TYPE = "fear"
|
||||
ANGET_TYPE="anger"
|
||||
NEUTRAL_TYPE="neutral"
|
||||
EMOTION_MAPPING={
|
||||
"joy":"愉快",
|
||||
"surprise":"惊喜",
|
||||
"sadness":"悲伤",
|
||||
"fear":"恐惧",
|
||||
"anger":"生气",
|
||||
"neutral":"中性"
|
||||
}
|
||||
class EmotionSubject(ABC):
|
||||
SUBJECT_MAPPING={
|
||||
"self":"自己",
|
||||
"other":"别人",
|
||||
"object":"事物对象"
|
||||
}
|
||||
|
||||
class EpisodicMemoryOverviewRequest(BaseModel):
|
||||
"""情景记忆总览查询请求"""
|
||||
|
||||
@@ -15,6 +15,8 @@ from neo4j.time import DateTime as Neo4jDateTime
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from app.schemas.memory_episodic_schema import EmotionType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MemoryEntityService:
|
||||
@@ -123,7 +125,7 @@ class MemoryEntityService:
|
||||
extracted_entity_list = self._deduplicate_dict_list(extracted_entity_list)
|
||||
|
||||
# 合并所有数据并处理相同text的合并
|
||||
all_timeline_data = memory_summary_list + statement_list + extracted_entity_list
|
||||
all_timeline_data = memory_summary_list + statement_list
|
||||
all_timeline_data = self._merge_same_text_items(all_timeline_data)
|
||||
|
||||
result = {
|
||||
@@ -496,11 +498,11 @@ class MemoryEmotion:
|
||||
length_data.append(emotion_intensity)
|
||||
if emotion_type is not None and emotion_intensity is not None and formatted_created_at is not None:
|
||||
# 使用(emotion_type, created_at)作为分组键
|
||||
if emotion_type in {"joy", "surprise"}:
|
||||
if emotion_type in {EmotionType.JOY_TYPE, EmotionType.SURPRISE_TYPE}:
|
||||
emotion_type='positive'
|
||||
elif emotion_type in {"sadness", "fear", "anger"}:
|
||||
elif emotion_type in {EmotionType.SANDROWNESS_TYPE, EmotionType.FEAR_TYPE, EmotionType.ANGET_TYPE}:
|
||||
emotion_type='negative'
|
||||
elif emotion_type=='neutral':
|
||||
elif emotion_type==EmotionType.NEUTRAL_TYPE:
|
||||
emotion_type='neutral'
|
||||
group_key = (emotion_type, formatted_created_at)
|
||||
# 累加emotion_intensity
|
||||
|
||||
@@ -13,10 +13,14 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.conversation_repository import ConversationRepository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from app.services.memory_base_service import MemoryBaseService
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.services.memory_short_service import ShortService
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -1196,18 +1200,17 @@ async def analytics_memory_types(
|
||||
end_user_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
统计9种记忆类型的数量和百分比
|
||||
统计8种记忆类型的数量和百分比
|
||||
|
||||
计算规则:
|
||||
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity
|
||||
2. 工作记忆 (WORKING_MEMORY) = chunk + entity
|
||||
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk
|
||||
4. 长期记忆 (LONG_TERM_MEMORY) = entity
|
||||
5. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
||||
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity
|
||||
7. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
||||
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
||||
9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
||||
1. 感知记忆 (PERCEPTUAL_MEMORY) = 通过 MemoryPerceptualService.get_memory_count 获取的 total_count
|
||||
2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取)
|
||||
3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量
|
||||
4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
||||
5. 隐性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一
|
||||
6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
||||
7. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
||||
8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
@@ -1227,7 +1230,6 @@ async def analytics_memory_types(
|
||||
- PERCEPTUAL_MEMORY: 感知记忆
|
||||
- WORKING_MEMORY: 工作记忆
|
||||
- SHORT_TERM_MEMORY: 短期记忆
|
||||
- LONG_TERM_MEMORY: 长期记忆
|
||||
- EXPLICIT_MEMORY: 显性记忆
|
||||
- IMPLICIT_MEMORY: 隐性记忆
|
||||
- EMOTIONAL_MEMORY: 情绪记忆
|
||||
@@ -1237,40 +1239,78 @@ async def analytics_memory_types(
|
||||
# 初始化基础服务
|
||||
base_service = MemoryBaseService()
|
||||
|
||||
# 定义需要查询的基础节点类型
|
||||
node_types = {
|
||||
"Statement": "Statement",
|
||||
"Entity": "ExtractedEntity",
|
||||
"Chunk": "Chunk"
|
||||
}
|
||||
# 初始化感知记忆服务
|
||||
perceptual_service = MemoryPerceptualService(db)
|
||||
|
||||
# 存储每种节点类型的计数
|
||||
node_counts = {}
|
||||
# 获取感知记忆数量
|
||||
if end_user_id:
|
||||
perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id))
|
||||
perceptual_count = perceptual_stats.get("total", 0)
|
||||
else:
|
||||
perceptual_count = 0
|
||||
|
||||
# 查询每种节点类型的数量
|
||||
for key, node_type in node_types.items():
|
||||
if end_user_id:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
# 获取工作记忆数量(基于会话数量)
|
||||
work_count = 0
|
||||
if end_user_id:
|
||||
try:
|
||||
conversation_repo = ConversationRepository(db)
|
||||
conversations = conversation_repo.get_conversation_by_user_id(
|
||||
user_id=uuid.UUID(end_user_id),
|
||||
limit=100, # 获取更多会话以准确统计
|
||||
is_activate=True
|
||||
)
|
||||
work_count = len(conversations)
|
||||
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}")
|
||||
work_count = 0
|
||||
|
||||
# 获取隐性记忆数量(基于 Statement 节点数量的三分之一)
|
||||
implicit_count = 0
|
||||
if end_user_id:
|
||||
try:
|
||||
# 查询 Statement 节点数量
|
||||
query = """
|
||||
MATCH (n:Statement)
|
||||
WHERE n.group_id = $group_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
||||
else:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query)
|
||||
|
||||
# 提取计数结果
|
||||
count = result[0]["count"] if result and len(result) > 0 else 0
|
||||
node_counts[key] = count
|
||||
statement_count = result[0]["count"] if result and len(result) > 0 else 0
|
||||
# 取三分之一作为隐性记忆数量
|
||||
implicit_count = round(statement_count / 3)
|
||||
logger.debug(f"隐性记忆数量(Statement数量的1/3): {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取Statement数量失败,隐性记忆数量设为0: {str(e)}")
|
||||
implicit_count = 0
|
||||
|
||||
# 获取各节点类型的数量
|
||||
statement_count = node_counts.get("Statement", 0)
|
||||
entity_count = node_counts.get("Entity", 0)
|
||||
chunk_count = node_counts.get("Chunk", 0)
|
||||
# 原有的基于行为习惯的统计方式(已注释)
|
||||
# implicit_count = 0
|
||||
# if end_user_id:
|
||||
# try:
|
||||
# implicit_service = ImplicitMemoryService(db, end_user_id)
|
||||
# behavior_habits = await implicit_service.get_behavior_habits(
|
||||
# user_id=end_user_id
|
||||
# )
|
||||
# implicit_count = len(behavior_habits)
|
||||
# logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"获取行为习惯数量失败,隐性记忆数量设为0: {str(e)}")
|
||||
# implicit_count = 0
|
||||
|
||||
# 获取短期记忆数量(基于 /short_term 接口返回的问答对数量)
|
||||
short_term_count = 0
|
||||
if end_user_id:
|
||||
try:
|
||||
short_term_service = ShortService(end_user_id)
|
||||
short_term_data = short_term_service.get_short_databasets()
|
||||
# 统计 short_term 数组的长度
|
||||
if short_term_data:
|
||||
short_term_count = len(short_term_data)
|
||||
logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取短期记忆数量失败,短期记忆数量设为0: {str(e)}")
|
||||
short_term_count = 0
|
||||
|
||||
# 获取用户的遗忘阈值配置
|
||||
forgetting_threshold = 0.3 # 默认值
|
||||
@@ -1296,17 +1336,16 @@ async def analytics_memory_types(
|
||||
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
|
||||
episodic_count = await base_service.get_episodic_memory_count(end_user_id)
|
||||
explicit_count = await base_service.get_explicit_memory_count(end_user_id)
|
||||
emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count)
|
||||
emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count)
|
||||
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
|
||||
|
||||
# 按规则计算9种记忆类型的数量(使用英文枚举作为key)
|
||||
# 按规则计算8种记忆类型的数量(使用英文枚举作为key)
|
||||
memory_counts = {
|
||||
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆
|
||||
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆
|
||||
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆
|
||||
"LONG_TERM_MEMORY": entity_count, # 长期记忆
|
||||
"PERCEPTUAL_MEMORY": perceptual_count, # 感知记忆
|
||||
"WORKING_MEMORY": work_count, # 工作记忆(基于会话数量)
|
||||
"SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量)
|
||||
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
|
||||
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity)
|
||||
"IMPLICIT_MEMORY": implicit_count, # 隐性记忆(Statement数量的1/3)
|
||||
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
|
||||
"EPISODIC_MEMORY": episodic_count, # 情景记忆
|
||||
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
|
||||
@@ -1332,7 +1371,7 @@ async def analytics_graph_data(
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
node_types: Optional[List[str]] = None,
|
||||
limit: int = 100,
|
||||
limit: int = 130,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
@@ -1423,6 +1462,7 @@ async def analytics_graph_data(
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
|
||||
# 执行节点查询
|
||||
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
|
||||
|
||||
@@ -1567,9 +1607,9 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
||||
allowed_fields = field_whitelist.get(label, [])
|
||||
|
||||
# 如果没有定义白名单,返回空字典(或者可以返回所有字段)
|
||||
if not allowed_fields:
|
||||
# 对于未定义的节点类型,只返回基本字段
|
||||
allowed_fields = ["name", "created_at", "caption"]
|
||||
# if not allowed_fields:
|
||||
# # 对于未定义的节点类型,只返回基本字段
|
||||
# allowed_fields = ["name", "created_at", "caption"]
|
||||
count_neo4j=f"""MATCH (n)-[r]-(m) WHERE elementId(n) ="{node_id}" RETURN count(r) AS rel_count;"""
|
||||
node_results = await (_neo4j_connector.execute_query(count_neo4j))
|
||||
# 提取白名单中的字段
|
||||
@@ -1577,10 +1617,15 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
||||
for field in allowed_fields:
|
||||
if field in properties:
|
||||
value = properties[field]
|
||||
if str(field) == 'entity_type':
|
||||
value=type_mapping.get(value,'')
|
||||
if str(field)=="emotion_type":
|
||||
value=EmotionType.EMOTION_MAPPING.get(value)
|
||||
if str(field)=="emotion_subject":
|
||||
value=EmotionSubject.SUBJECT_MAPPING.get(value)
|
||||
# 清理 Neo4j 特殊类型
|
||||
filtered_props[field] = _clean_neo4j_value(value)
|
||||
filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0]
|
||||
print(filtered_props)
|
||||
return filtered_props
|
||||
|
||||
|
||||
@@ -1621,6 +1666,5 @@ def _clean_neo4j_value(value: Any) -> Any:
|
||||
return str(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# 返回原始值
|
||||
return value
|
||||
return value
|
||||
@@ -181,10 +181,15 @@ const Conversation: FC = () => {
|
||||
currentConversationId = newId
|
||||
break
|
||||
case 'message':
|
||||
const { content } = item.data as { content: string }
|
||||
updateAssistantMessage(content)
|
||||
const { content, chunk, conversation_id: curId } = item.data as { content: string; chunk: string; conversation_id: string; }
|
||||
updateAssistantMessage(content ?? chunk)
|
||||
|
||||
if (curId) {
|
||||
currentConversationId = curId;
|
||||
}
|
||||
break
|
||||
case 'end':
|
||||
case 'workflow_end':
|
||||
setLoading(false)
|
||||
if (currentConversationId && currentConversationId !== conversation_id) {
|
||||
setConversationId(currentConversationId)
|
||||
|
||||
@@ -103,7 +103,11 @@ const ChatVariableModal = forwardRef<ChatVariableModalRef, ChatVariableModalProp
|
||||
label={t('workflow.config.parameter-extractor.default')}
|
||||
>
|
||||
{type === 'number'
|
||||
? <InputNumber placeholder={t('common.enter')} style={{ width: '100%' }} />
|
||||
? <InputNumber
|
||||
placeholder={t('common.enter')}
|
||||
style={{ width: '100%' }}
|
||||
onChange={(value) => form.setFieldValue('defaultValue', value)}
|
||||
/>
|
||||
: type === 'boolean'
|
||||
? <Select
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
|
||||
@@ -54,6 +54,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
const handleClose = () => {
|
||||
setOpen(false)
|
||||
setChatList([])
|
||||
setVariables([])
|
||||
}
|
||||
const handleEditVariables = () => {
|
||||
variableConfigModalRef.current?.handleOpen(variables)
|
||||
|
||||
@@ -80,7 +80,7 @@ const VariableConfigModal = forwardRef<VariableConfigModalRef, VariableEditModal
|
||||
field.type === 'string' && <Input placeholder={t('common.pleaseEnter')} />
|
||||
}
|
||||
{
|
||||
field.type === 'number' && <InputNumber placeholder={t('common.pleaseEnter')} style={{ width: '100%' }} />
|
||||
field.type === 'number' && <InputNumber placeholder={t('common.pleaseEnter')} style={{ width: '100%' }} onChange={(value) => form.setFieldValue(['variables', name, 'value'], value)} />
|
||||
}
|
||||
{
|
||||
field.type === 'boolean' && <Checkbox>{`${field.name}·${field.description}`}</Checkbox>
|
||||
|
||||
@@ -46,7 +46,8 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
graph.addEdge({
|
||||
source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() },
|
||||
target: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'left')?.id || 'left' },
|
||||
attrs: edge.getAttrs()
|
||||
attrs: edge.getAttrs(),
|
||||
zIndex: 3
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -85,6 +85,7 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
||||
},
|
||||
},
|
||||
},
|
||||
zIndex: 0
|
||||
});
|
||||
|
||||
// 循环节点内子节点通过连接桩添加时,调整循环节点大小
|
||||
|
||||
@@ -60,7 +60,7 @@ const AssignmentList: FC<AssignmentListProps> = ({
|
||||
>
|
||||
<VariableSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
options={options.filter(vo => vo.nodeData.type === 'loop' || vo.value.includes('conv.'))}
|
||||
options={options.filter(vo => vo.nodeData.type === 'loop' || vo.value.includes('conv.') || (vo.nodeData.type === 'iteration' && (vo.label === 'item' || vo.label === 'index')))}
|
||||
popupMatchSelectWidth={false}
|
||||
onChange={() => {
|
||||
form.setFieldValue([parentName, name, 'operation'], undefined);
|
||||
|
||||
@@ -40,8 +40,6 @@ const operatorsObj: { [key: string]: SelectProps['options'] } = {
|
||||
boolean: [
|
||||
{ value: 'eq', label: 'workflow.config.if-else.boolean.eq' },
|
||||
{ value: 'ne', label: 'workflow.config.if-else.boolean.ne' },
|
||||
{ value: 'empty', label: 'workflow.config.if-else.empty' },
|
||||
{ value: 'not_empty', label: 'workflow.config.if-else.not_empty' },
|
||||
]
|
||||
}
|
||||
|
||||
@@ -85,11 +83,14 @@ const CaseList: FC<CaseListProps> = ({
|
||||
|
||||
selectedNode.prop('size', { width: 240, height: newHeight })
|
||||
|
||||
// 计算端口间距
|
||||
const dy = totalPorts;
|
||||
|
||||
// 添加 IF 端口
|
||||
selectedNode.addPort({
|
||||
id: 'CASE1',
|
||||
group: 'right',
|
||||
args: { dy: 24 },
|
||||
// args: { dy },
|
||||
attrs: { text: { text: 'IF', fontSize: 12, fill: '#5B6167' }}
|
||||
});
|
||||
|
||||
@@ -98,6 +99,7 @@ const CaseList: FC<CaseListProps> = ({
|
||||
selectedNode.addPort({
|
||||
id: `CASE${i + 1}`,
|
||||
group: 'right',
|
||||
// args: { dy },
|
||||
attrs: { text: { text: 'ELIF', fontSize: 12, fill: '#5B6167' }}
|
||||
});
|
||||
}
|
||||
@@ -106,11 +108,22 @@ const CaseList: FC<CaseListProps> = ({
|
||||
selectedNode.addPort({
|
||||
id: `CASE${caseCount + 1}`,
|
||||
group: 'right',
|
||||
// args: { dy },
|
||||
attrs: { text: { text: 'ELSE', fontSize: 12, fill: '#5B6167' }}
|
||||
});
|
||||
|
||||
// 恢复仍然存在的端口连线
|
||||
setTimeout(() => {
|
||||
// 计算删除前的总端口数来确定原ELSE端口编号
|
||||
const originalCaseCount = removedCaseIndex !== undefined ? caseCount + 1 : caseCount;
|
||||
const originalElsePortNumber = originalCaseCount + 1;
|
||||
|
||||
// 检查ELSE端口是否有连线
|
||||
const elseHasConnection = edgeConnections.some(({ sourcePortId, isIncoming }: any) => {
|
||||
const caseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
|
||||
return !isIncoming && caseNumber === originalElsePortNumber;
|
||||
});
|
||||
|
||||
edgeConnections.forEach(({ edge, sourcePortId, targetCellId, targetPortId, sourceCellId, isIncoming }: any) => {
|
||||
// 如果是进入连线(左侧端口),直接恢复
|
||||
if (isIncoming) {
|
||||
@@ -138,7 +151,7 @@ const CaseList: FC<CaseListProps> = ({
|
||||
// 处理右侧端口连线
|
||||
const originalCaseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
|
||||
|
||||
// 如果是被删除的端口,不重新创建连线
|
||||
// 如果是被删除的端口,只删除该端口的连线
|
||||
if (removedCaseIndex !== undefined && originalCaseNumber === removedCaseIndex + 1) {
|
||||
graphRef.current?.removeCell(edge);
|
||||
return;
|
||||
@@ -146,15 +159,8 @@ const CaseList: FC<CaseListProps> = ({
|
||||
|
||||
let newPortId = sourcePortId;
|
||||
|
||||
// 如果是原来的ELSE端口,重新映射到新的ELSE端口
|
||||
const maxOriginalCaseNumber = Math.max(...edgeConnections
|
||||
.filter(({ isIncoming }: any) => !isIncoming)
|
||||
.map(({ sourcePortId }: any) => {
|
||||
const match = sourcePortId.match(/CASE(\d+)/);
|
||||
return match ? parseInt(match[1]) : 0;
|
||||
}));
|
||||
|
||||
if (originalCaseNumber === maxOriginalCaseNumber) {
|
||||
// 如果是原来的ELSE端口且有连线,重新映射到新的ELSE端口
|
||||
if (originalCaseNumber === originalElsePortNumber && elseHasConnection) {
|
||||
newPortId = `CASE${caseCount + 1}`; // 新的ELSE端口
|
||||
} else if (removedCaseIndex !== undefined && originalCaseNumber > removedCaseIndex + 1) {
|
||||
// 如果是被删除端口之后的端口,编号向前移动
|
||||
|
||||
@@ -53,6 +53,7 @@ const operatorsObj: { [key: string]: SelectProps['options'] } = {
|
||||
const ConditionList: FC<CaseListProps> = ({
|
||||
options,
|
||||
parentName,
|
||||
selectedNode,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const form = Form.useFormInstance();
|
||||
@@ -114,7 +115,12 @@ const ConditionList: FC<CaseListProps> = ({
|
||||
<Col span={14}>
|
||||
<Form.Item name={[field.name, 'left']} noStyle>
|
||||
<VariableSelect
|
||||
options={options.filter(vo => vo.value.includes('sys.') || vo.value.includes('conv.') || vo.nodeData.type === 'loop')}
|
||||
options={options.filter(vo =>
|
||||
vo.value.includes('sys.') ||
|
||||
vo.value.includes('conv.') ||
|
||||
vo.nodeData.type === 'loop' ||
|
||||
(vo.nodeData.cycle && vo.nodeData.cycle === selectedNode?.id)
|
||||
)}
|
||||
size="small"
|
||||
allowClear={false}
|
||||
popupMatchSelectWidth={false}
|
||||
|
||||
@@ -566,42 +566,6 @@ const Properties: FC<PropertiesProps> = ({
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
|
||||
// Check if connected via ERROR connection point
|
||||
const errorEdges = edges.filter(edge =>
|
||||
edge.getTargetCellId() === selectedNode.id &&
|
||||
edge.getSourceCellId() === nodeId &&
|
||||
edge.getSourcePortId() === 'ERROR'
|
||||
);
|
||||
|
||||
if (errorEdges.length > 0) {
|
||||
const errorMessageKey = `${dataNodeId}_error_message`;
|
||||
const errorTypeKey = `${dataNodeId}_error_type`;
|
||||
|
||||
if (!addedKeys.has(errorMessageKey)) {
|
||||
addedKeys.add(errorMessageKey);
|
||||
variableList.push({
|
||||
key: errorMessageKey,
|
||||
label: 'error_message',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${dataNodeId}.error_message`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
|
||||
if (!addedKeys.has(errorTypeKey)) {
|
||||
addedKeys.add(errorTypeKey);
|
||||
variableList.push({
|
||||
key: errorTypeKey,
|
||||
label: 'error_type',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${dataNodeId}.error_type`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
}
|
||||
break
|
||||
case 'jinja-render':
|
||||
const jinjaOutputKey = `${dataNodeId}_output`;
|
||||
@@ -793,54 +757,6 @@ const Properties: FC<PropertiesProps> = ({
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if parent loop/iteration is connected to http-request via ERROR connection
|
||||
if (parentData.type === 'loop' || parentData.type === 'iteration') {
|
||||
const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id);
|
||||
parentPreviousNodeIds.forEach(prevNodeId => {
|
||||
const prevNode = nodes.find(n => n.id === prevNodeId);
|
||||
if (!prevNode) return;
|
||||
|
||||
const prevNodeData = prevNode.getData();
|
||||
if (prevNodeData.type === 'http-request') {
|
||||
// Check if connected via ERROR connection point
|
||||
const errorEdges = edges.filter(edge => {
|
||||
return edge.getTargetCellId() === parentLoopNode.id &&
|
||||
edge.getSourceCellId() === prevNodeId &&
|
||||
edge.getSourcePortId() === 'ERROR'
|
||||
});
|
||||
|
||||
if (errorEdges.length > 0) {
|
||||
const errorMessageKey = `${prevNodeData.id}_error_message`;
|
||||
const errorTypeKey = `${prevNodeData.id}_error_type`;
|
||||
|
||||
if (!addedKeys.has(errorMessageKey)) {
|
||||
addedKeys.add(errorMessageKey);
|
||||
variableList.push({
|
||||
key: errorMessageKey,
|
||||
label: 'error_message',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${prevNodeData.id}.error_message`,
|
||||
nodeData: prevNodeData,
|
||||
});
|
||||
}
|
||||
|
||||
if (!addedKeys.has(errorTypeKey)) {
|
||||
addedKeys.add(errorTypeKey);
|
||||
variableList.push({
|
||||
key: errorTypeKey,
|
||||
label: 'error_type',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${prevNodeData.id}.error_type`,
|
||||
nodeData: prevNodeData,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return variableList;
|
||||
@@ -999,12 +915,12 @@ const Properties: FC<PropertiesProps> = ({
|
||||
return filteredList;
|
||||
}
|
||||
if (nodeType === 'knowledge-retrieval' || nodeType === 'parameter-extractor' && key !== 'prompt' || nodeType === 'memory-read' || nodeType === 'memory-write' || nodeType === 'question-classifier') {
|
||||
let filteredList = variableList.filter(variable => variable.dataType === 'string');
|
||||
return addParentIterationVars(filteredList);
|
||||
let filteredList = addParentIterationVars(variableList).filter(variable => variable.dataType === 'string');
|
||||
return filteredList;
|
||||
}
|
||||
if (nodeType === 'parameter-extractor' && key === 'prompt') {
|
||||
let filteredList = variableList.filter(variable => variable.dataType === 'string' || variable.dataType === 'number');
|
||||
return addParentIterationVars(filteredList);
|
||||
let filteredList = addParentIterationVars(variableList).filter(variable => variable.dataType === 'string' || variable.dataType === 'number');
|
||||
return filteredList;
|
||||
}
|
||||
if (nodeType === 'iteration' && key === 'output') {
|
||||
return variableList.filter(variable => variable.value.includes('sys.'));
|
||||
@@ -1013,8 +929,71 @@ const Properties: FC<PropertiesProps> = ({
|
||||
return variableList.filter(variable => variable.dataType.includes('array'));
|
||||
}
|
||||
if (nodeType === 'loop' && key === 'condition') {
|
||||
let filteredList = variableList.filter(variable => variable.nodeData.type !== 'loop');
|
||||
return addParentIterationVars(filteredList);
|
||||
let filteredList = addParentIterationVars(variableList).filter(variable => variable.nodeData.type !== 'loop');
|
||||
|
||||
// Add child node output variables for loop nodes
|
||||
if (selectedNode) {
|
||||
const graph = graphRef.current;
|
||||
if (graph) {
|
||||
const nodes = graph.getNodes();
|
||||
const childNodes = nodes.filter(node => {
|
||||
const nodeData = node.getData();
|
||||
return nodeData?.cycle === selectedNode.id;
|
||||
});
|
||||
|
||||
// Add output variables from child nodes
|
||||
childNodes.forEach(childNode => {
|
||||
const childData = childNode.getData();
|
||||
const childNodeId = childData.id;
|
||||
|
||||
// Add child node output variables based on their type
|
||||
switch(childData.type) {
|
||||
case 'llm':
|
||||
case 'jinja-render':
|
||||
case 'tool':
|
||||
const outputKey = `${childNodeId}_output`;
|
||||
const existingOutput = filteredList.find(v => v.key === outputKey);
|
||||
if (!existingOutput) {
|
||||
filteredList.push({
|
||||
key: outputKey,
|
||||
label: 'output',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${childNodeId}.output`,
|
||||
nodeData: childData,
|
||||
});
|
||||
}
|
||||
break;
|
||||
case 'http-request':
|
||||
const bodyKey = `${childNodeId}_body`;
|
||||
const statusKey = `${childNodeId}_status_code`;
|
||||
if (!filteredList.find(v => v.key === bodyKey)) {
|
||||
filteredList.push({
|
||||
key: bodyKey,
|
||||
label: 'body',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${childNodeId}.body`,
|
||||
nodeData: childData,
|
||||
});
|
||||
}
|
||||
if (!filteredList.find(v => v.key === statusKey)) {
|
||||
filteredList.push({
|
||||
key: statusKey,
|
||||
label: 'status_code',
|
||||
type: 'variable',
|
||||
dataType: 'number',
|
||||
value: `${childNodeId}.status_code`,
|
||||
nodeData: childData,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return filteredList;
|
||||
}
|
||||
|
||||
// For all other node types, add parent iteration variables if applicable
|
||||
@@ -1025,7 +1004,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
// const defaultVariableList = calculateVariableList(selectedNode as Node, graphRef, workflowConfig )
|
||||
|
||||
console.log('values', values)
|
||||
// console.log('variableList', variableList, defaultVariableList)
|
||||
console.log('variableList', variableList)
|
||||
|
||||
return (
|
||||
<div className="rb:w-75 rb:fixed rb:right-0 rb:top-16 rb:bottom-0 rb:p-3">
|
||||
|
||||
Reference in New Issue
Block a user