Compare commits

..

2 Commits

Author SHA1 Message Date
wxy
cef33fce0d fix(workflow): sanitize condition expression building and cache assigner node inputs
- Sanitize condition expression construction in graph_builder.py using json.dumps to prevent potential injection vulnerabilities.
- Cache input data prior to assigner node execution to ensure variable values are correctly captured before processing.
2026-05-07 16:26:47 +08:00
wxy
d9f08860bc feat(LLM node): integrate exception handling and enable branch routing
- Integrate exception handling configuration into LLM nodes, supporting three strategies: throw exception, return default value, or trigger exception branch.
- Modify execution logic to return a result structure containing a branch signal, enabling routing to designated branches upon failure.
- Update graph_builder to support LLM node branch routing logic using the branch_signal field for conditional judgment.
- Implement backward compatibility to support both legacy and new result formats.
2026-05-07 11:43:24 +08:00
80 changed files with 1182 additions and 1891 deletions

View File

@@ -3,9 +3,12 @@ name: Sync to Gitee
on:
push:
branches:
- '**' # All branchs
- main # Production
- develop # Integration
- 'release/*' # Release preparation
- 'hotfix/*' # Urgent fixes
tags:
- '**' # All version tags (v1.0.0, etc.)
- '*' # All version tags (v1.0.0, etc.)
jobs:
sync:

View File

@@ -158,19 +158,12 @@ class RedisTaskScheduler:
return {"status": status, "task_id": task_id, "result": result_content}
def _cleanup_finished(self):
cursor = 0
all_pending = {}
while True:
cursor, batch = self.redis.hscan(PENDING_HASH, cursor=cursor, count=100)
all_pending.update(batch)
if cursor == 0:
break
if not all_pending:
pending = self.redis.hgetall(PENDING_HASH)
if not pending:
return
now = time.time()
task_ids = list(all_pending.keys())
task_ids = list(pending.keys())
pipe = self.redis.pipeline()
for task_id in task_ids:
@@ -183,7 +176,7 @@ class RedisTaskScheduler:
for task_id, raw_result in zip(task_ids, results):
try:
meta = json.loads(all_pending[task_id])
meta = json.loads(pending[task_id])
lock_key = meta["lock_key"]
dispatched_at = meta.get("dispatched_at", 0)
age = now - dispatched_at
@@ -283,22 +276,6 @@ class RedisTaskScheduler:
return True
return stable_hash(user_id) % self._shard_count == self._shard_index
def _commit_post_dispatch(self, lock_key, task, msg_id, dispatch_lock):
pipe = self.redis.pipeline()
pipe.set(lock_key, task.id, ex=3600)
pipe.hset(PENDING_HASH, task.id, json.dumps({
"lock_key": lock_key,
"dispatched_at": time.time(),
"msg_id": msg_id,
}))
pipe.delete(dispatch_lock)
pipe.set(
f"task_tracker:{msg_id}",
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
ex=86400,
)
pipe.execute()
def _dispatch(self, msg_id, msg_data) -> bool:
user_id = msg_data["user_id"]
task_name = msg_data["task_name"]
@@ -331,17 +308,28 @@ class RedisTaskScheduler:
task_name, user_id, msg_id, e, exc_info=True,
)
return False
for attempt in range(2):
try:
self._commit_post_dispatch(lock_key, task, msg_id, dispatch_lock)
break
except Exception as e:
logger.error(
"Post-dispatch state update failed for %s: %s",
task.id, e, exc_info=True,
)
time.sleep(0.1)
self.errors += 1
try:
pipe = self.redis.pipeline()
pipe.set(lock_key, task.id, ex=3600)
pipe.hset(PENDING_HASH, task.id, json.dumps({
"lock_key": lock_key,
"dispatched_at": time.time(),
"msg_id": msg_id,
}))
pipe.delete(dispatch_lock)
pipe.set(
f"task_tracker:{msg_id}",
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
ex=86400,
)
pipe.execute()
except Exception as e:
logger.error(
"Post-dispatch state update failed for %s: %s",
task.id, e, exc_info=True,
)
self.errors += 1
self.dispatched += 1
logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
@@ -379,21 +367,22 @@ class RedisTaskScheduler:
return
for uid, msg in candidates:
queue_key = f"{USER_QUEUE_PREFIX}{uid}"
if self._dispatch(msg["msg_id"], msg):
self.redis.lpop(queue_key)
if self.redis.llen(queue_key) > 0:
self.redis.sadd(READY_SET, uid)
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
def schedule_loop(self):
self._heartbeat()
self._cleanup_finished()
ready_users = self.redis.smembers(READY_SET) or set()
pipe = self.redis.pipeline()
pipe.smembers(READY_SET)
pipe.delete(READY_SET)
results = pipe.execute()
ready_users = results[0] or set()
my_users = [uid for uid in ready_users if self._is_mine(uid)]
if my_users:
self.redis.srem(READY_SET, *my_users)
else:
if not my_users:
time.sleep(0.5)
return
@@ -456,7 +445,7 @@ class RedisTaskScheduler:
"Scheduler started: instance=%s", self.instance_id,
)
while self.running:
while True:
try:
self.schedule_loop()
@@ -491,7 +480,9 @@ class RedisTaskScheduler:
logger.error("Shutdown cleanup error: %s", e)
scheduler = RedisTaskScheduler()
scheduler: RedisTaskScheduler | None = None
if scheduler is None:
scheduler = RedisTaskScheduler()
if __name__ == "__main__":
import signal

View File

@@ -27,7 +27,6 @@ from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_agent_service import get_end_user_connected_config as get_config
from app.services.model_service import ModelConfigService
from app.utils.tmp_session import ChatSessionCache
load_dotenv()
api_logger = get_api_logger()
@@ -301,39 +300,60 @@ async def read_server(
if knowledge:
user_rag_memory_id = str(knowledge.id)
session_id = user_input.session_id.hex
api_logger.info(
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}, session_id={session_id}")
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
# result = await memory_agent_service.read_memory(
# user_input.end_user_id,
# user_input.message,
# user_input.history,
# user_input.search_switch,
# config_id,
# db,
# storage_type,
# user_rag_memory_id
# )
# if str(user_input.search_switch) == "2":
# retrieve_info = result['answer']
# history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
# user_input.end_user_id)
# query = user_input.message
#
# # 调用 memory_agent_service 的方法生成最终答案
# result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
# end_user_id=user_input.end_user_id,
# retrieve_info=retrieve_info,
# history=history,
# query=query,
# config_id=config_id,
# db=db
# )
# if "信息不足,无法回答" in result['answer']:
# result['answer'] = retrieve_info
memory_config = get_config(user_input.end_user_id, db)
service = MemoryService(
db,
memory_config["memory_config_id"],
end_user_id=user_input.end_user_id
)
session_cache = ChatSessionCache(session_id)
search_result = await service.read(
user_input.message,
SearchStrategy(user_input.search_switch),
history=await session_cache.get_history(),
SearchStrategy(user_input.search_switch)
)
intermediate_outputs = []
sub_queries = set()
for memory in search_result.memories:
sub_queries.add(str(memory.query))
idx = 0
if user_input.search_switch in [SearchStrategy.DEEP, SearchStrategy.NORMAL]:
intermediate_outputs.append({
"type": "problem_split",
"title": "问题拆分",
"data": [
{
"id": f"Q{(idx := idx + 1)}",
"id": f"Q{idx+1}",
"question": question
}
for question in sub_queries
if question
for idx, question in enumerate(sub_queries)
]
})
perceptual_data = [
@@ -355,24 +375,16 @@ async def read_server(
"raw_result": search_result.memories,
"total": len(search_result.memories),
})
answer = await memory_agent_service.generate_summary_from_retrieve(
end_user_id=user_input.end_user_id,
retrieve_info=search_result.content,
history=[],
query=user_input.message,
config_id=config_id,
db=db
)
await session_cache.append_many(
[
{"role": "user", "content": user_input.message},
{"role": "assistant", "content": answer}
]
)
result = {
'answer': answer,
"intermediate_outputs": intermediate_outputs,
"session_id": session_id,
'answer': await memory_agent_service.generate_summary_from_retrieve(
end_user_id=user_input.end_user_id,
retrieve_info=search_result.content,
history=[],
query=user_input.message,
config_id=config_id,
db=db
),
"intermediate_outputs": intermediate_outputs
}
return success(data=result, msg="回复对话消息成功")
@@ -468,11 +480,9 @@ async def read_server_async(
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async read: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
session_id = user_input.session_id.hex
session_cache = ChatSessionCache(session_id)
task = celery_app.send_task(
"app.core.memory.agent.read_message",
args=[user_input.end_user_id, user_input.message, await session_cache.get_history(), user_input.search_switch,
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Read task queued: {task.id}")

View File

@@ -1,4 +1,4 @@
import asyncio
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
@@ -10,7 +10,7 @@ from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.services import memory_dashboard_service, workspace_service
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
from app.services.memory_agent_service import get_end_users_connected_configs_batch
from app.services.app_statistics_service import AppStatisticsService
from app.core.logging_config import get_api_logger
@@ -48,7 +48,7 @@ def get_workspace_total_end_users(
@router.get("/end_users", response_model=ApiResponse)
def get_workspace_end_users(
async def get_workspace_end_users(
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID可选默认当前用户工作空间"),
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id"),
page: int = Query(1, ge=1, description="页码从1开始"),
@@ -58,15 +58,6 @@ def get_workspace_end_users(
):
"""
获取工作空间的宿主列表(分页查询,支持模糊搜索)
新增:记忆数量过滤:
Neo4j 模式:
- 使用 end_users.memory_count 过滤 memory_count > 0 的宿主
- memory_num.total 直接取 end_user.memory_count
RAG 模式:
- 使用 documents.chunk_num 聚合过滤 chunk 总数 > 0 的宿主
- memory_num.total 取聚合后的 chunk 总数
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
@@ -89,29 +80,17 @@ def get_workspace_end_users(
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
if current_workspace_type == "rag":
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated_rag(
db=db,
workspace_id=workspace_id,
current_user=current_user,
page=page,
pagesize=pagesize,
keyword=keyword,
)
raw_items = end_users_result.get("items", [])
end_users = [item["end_user"] for item in raw_items]
else:
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
db=db,
workspace_id=workspace_id,
current_user=current_user,
page=page,
pagesize=pagesize,
keyword=keyword,
)
raw_items = end_users_result.get("items", [])
end_users = raw_items
# 获取分页的 end_users
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
db=db,
workspace_id=workspace_id,
current_user=current_user,
page=page,
pagesize=pagesize,
keyword=keyword
)
end_users = end_users_result.get("items", [])
total = end_users_result.get("total", 0)
if not end_users:
@@ -122,19 +101,50 @@ def get_workspace_end_users(
"page": page,
"pagesize": pagesize,
"total": total,
"hasnext": (page * pagesize) < total,
},
"hasnext": (page * pagesize) < total
}
}, msg="宿主列表获取成功")
end_user_ids = [str(user.id) for user in end_users]
try:
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
memory_configs_map = {}
# 并发执行两个独立的查询任务
async def get_memory_configs():
"""获取记忆配置(在线程池中执行同步查询)"""
try:
return await asyncio.to_thread(
get_end_users_connected_configs_batch,
end_user_ids, db
)
except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
return {}
# 触发按需初始化:为 implicit_emotions_storage / interest_distribution 中没有记录的用户异步生成数据
async def get_memory_nums():
"""获取记忆数量"""
if current_workspace_type == "rag":
# RAG 模式:批量查询
try:
chunk_map = await asyncio.to_thread(
memory_dashboard_service.get_users_total_chunk_batch,
end_user_ids, db, current_user
)
return {uid: {"total": count} for uid, count in chunk_map.items()}
except Exception as e:
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j":
# Neo4j 模式批量查询简化版本只返回total
try:
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
return {uid: {"total": count} for uid, count in batch_result.items()}
except Exception as e:
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
return {uid: {"total": 0} for uid in end_user_ids}
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
try:
from app.celery_app import celery_app as _celery_app
_celery_app.send_task(
@@ -149,26 +159,27 @@ def get_workspace_end_users(
except Exception as e:
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
# 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(),
get_memory_nums()
)
# 构建结果列表
items = []
for index, end_user in enumerate(end_users):
for end_user in end_users:
user_id = str(end_user.id)
config_info = memory_configs_map.get(user_id, {})
if current_workspace_type == "rag":
memory_total = int(raw_items[index].get("memory_count", 0) or 0)
else:
memory_total = int(getattr(end_user, "memory_count", 0) or 0)
items.append({
"end_user": {
"id": user_id,
"other_name": end_user.other_name,
'end_user': {
'id': user_id,
'other_name': end_user.other_name
},
"memory_num": {"total": memory_total},
"memory_config": {
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
'memory_config': {
"memory_config_id": config_info.get("memory_config_id"),
"memory_config_name": config_info.get("memory_config_name"),
},
"memory_config_name": config_info.get("memory_config_name")
}
})
# 触发社区聚类补全任务(异步,不阻塞接口响应)
@@ -396,7 +407,6 @@ def get_current_user_rag_total_num(
total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user)
return success(data=total_chunk, msg="宿主RAG知识数据获取成功")
@router.get("/rag_content", response_model=ApiResponse)
def get_rag_content(
end_user_id: str = Query(..., description="宿主ID"),

View File

@@ -296,7 +296,7 @@ async def chat(
}
)
# workflow 非流式返回
# 多 Agent 非流式返回
result = await app_chat_service.workflow_chat(
message=payload.message,

View File

@@ -221,7 +221,7 @@ def update_workspace_members(
@router.delete("/members/{member_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def delete_workspace_member(
def delete_workspace_member(
member_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
@@ -230,7 +230,7 @@ async def delete_workspace_member(
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
await workspace_service.delete_workspace_member(
workspace_service.delete_workspace_member(
db=db,
workspace_id=workspace_id,
member_id=member_id,

View File

@@ -241,8 +241,6 @@ class Settings:
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER: str = os.getenv("SMTP_USER", "")
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
SANDBOX_URL: str = os.getenv("SANDBOX_URL", "")
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))

View File

@@ -20,7 +20,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem
memory_summary_generation
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.logging_utils import log_time
from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j
from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
@@ -314,28 +313,6 @@ async def write(
except Exception as cache_err:
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
# 同步 Neo4j 记忆节点总数到 PostgreSQL end_users.memory_count
if end_user_id:
try:
memory_count_connector = Neo4jConnector()
try:
node_count = await sync_end_user_memory_count_from_neo4j(
end_user_id,
memory_count_connector,
)
finally:
await memory_count_connector.close()
logger.info(
f"[MemoryCount] 写入后同步 memory_count: "
f"end_user_id={end_user_id}, count={node_count}"
)
except Exception as e:
logger.warning(
f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}",
exc_info=True,
)
# Close LLM/Embedder underlying httpx clients to prevent
# 'RuntimeError: Event loop is closed' during garbage collection
for client_obj in (llm_client, embedder_client):
@@ -354,4 +331,3 @@ async def write(
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")

View File

@@ -43,13 +43,10 @@ class MemoryService:
self,
query: str,
search_switch: SearchStrategy,
history: list | None = None,
limit: int = 10,
) -> MemorySearchResult:
if history is None:
history = []
with get_db_context() as db:
return await ReadPipeLine(self.ctx, db).run(query, search_switch, history, limit)
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit)
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
raise NotImplementedError

View File

@@ -32,12 +32,10 @@ class Memory(BaseModel):
class MemorySearchResult(BaseModel):
memories: list[Memory]
content_str: str = Field(default="")
@computed_field
@property
def content(self) -> str:
if self.content_str:
return self.content_str
return "\n".join([memory.content for memory in self.memories])
@computed_field

View File

@@ -1,9 +1,8 @@
from app.core.memory.enums import SearchStrategy, StorageType
from app.core.memory.models.service_models import MemorySearchResult
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
from app.core.memory.read_services.generate_engine.retrieval_summary import RetrievalSummaryProcessor
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
@@ -11,30 +10,20 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
self,
query: str,
search_switch: SearchStrategy,
history: list,
limit: int = 10,
includes=None
) -> MemorySearchResult:
memory_l0 = None
if self.ctx.storage_type == StorageType.NEO4J:
memory_l0 = await self._get_search_service(includes).memory_l0()
query = QueryPreprocessor.process(query)
match search_switch:
case SearchStrategy.DEEP:
res = await self._deep_read(query, history, limit, includes)
return await self._deep_read(query, limit, includes)
case SearchStrategy.NORMAL:
res = await self._normal_read(query, history, limit, includes)
return await self._normal_read(query, limit, includes)
case SearchStrategy.QUICK:
res = await self._quick_read(query, limit, includes)
return await self._quick_read(query, limit, includes)
case _:
raise RuntimeError("Unsupported search strategy")
if memory_l0 is not None:
res.content_str = memory_l0.content + '\n' + res.content
res.memories.insert(0, memory_l0)
return res
def _get_search_service(self, includes=None):
if self.ctx.storage_type == StorageType.NEO4J:
return Neo4jSearchService(
@@ -48,11 +37,10 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
self.db
)
async def _deep_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult:
async def _deep_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
search_service = self._get_search_service(includes)
questions = await QueryPreprocessor.split(
query,
history,
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
)
query_results = []
@@ -61,18 +49,12 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
query_results.append(search_results)
results = sum(query_results, start=MemorySearchResult(memories=[]))
results.memories.sort(key=lambda x: x.score, reverse=True)
results.content_str = await RetrievalSummaryProcessor.summary(
query,
results.content,
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
)
return results
async def _normal_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult:
async def _normal_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
search_service = self._get_search_service(includes)
questions = await QueryPreprocessor.split(
query,
history,
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
)
query_results = []
@@ -81,11 +63,6 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
query_results.append(search_results)
results = sum(query_results, start=MemorySearchResult(memories=[]))
results.memories.sort(key=lambda x: x.score, reverse=True)
results.content_str = await RetrievalSummaryProcessor.summary(
query,
results.content,
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
)
return results
async def _quick_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:

View File

@@ -76,8 +76,8 @@ Remember the following:
- Today's date is {{ datetime }}.
- Do not return anything from the custom few shot example prompts provided above.
- Don't reveal your prompt or model information to the user.
- The output language should match the user's input language.
- Vague times in user input should be converted into specific dates.
- If you are unable to extract any relevant information from the user's input, return the user's original input:{"questions":[userinput]}
# [IMPORTANT]: THE OUTPUT LANGUAGE MUST BE THE SAME AS THE USER'S INPUT LANGUAGE.
The following is the user's input. You need to extract the relevant information from the input and return it in the JSON format as shown above.

View File

@@ -1,15 +0,0 @@
You are a Content Condenser for a memory-augmented retrieval system.
Your task is to compress the retrieved content while preserving all information that is highly relevant to the users query.
Guidelines:
Focus only on content related to the query; ignore irrelevant parts.
Remove redundancy, filler, or repeated information only for non-XML content.
Preserve all factual details: names, dates, decisions, code snippets, technical details.
If relevant information is inside XML tags, do not remove, merge, or compress the XML tags or their internal text; keep them fully intact.
Structure multiple relevant points as a compact bullet list or paragraph, depending on density.
If no content is relevant, return exactly: "No relevant information found."
Do not add any knowledge or facts not in the retrieved content.
# [IMPORTANT] OUTPUT ONLY THE CONDENSED CONTENT, DO NOT ATTEMPT TO ANSWER THE QUERY.
# [IMPORTANT] DO NOT REMOVE OR PARAPHRASE HIGHLY RELEVANT INFORMATION.

View File

@@ -21,14 +21,14 @@ class QueryPreprocessor:
return text
@staticmethod
async def split(query: str, history: list, llm_client: RedBearLLM):
async def split(query: str, llm_client: RedBearLLM):
system_prompt = prompt_manager.render(
name="problem_split",
datetime=datetime.now().strftime("%Y-%m-%d"),
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"<history>{history}</history><query>{query}</query>"},
{"role": "user", "content": query},
]
try:
sub_queries = await llm_client.ainvoke(messages) | StructResponse(mode='json')

View File

@@ -1,29 +1,11 @@
import logging
from app.core.models import RedBearLLM
from app.core.memory.prompt import prompt_manager
from app.core.memory.utils.llm.llm_utils import StructResponse
logger = logging.getLogger(__name__)
class RetrievalSummaryProcessor:
@staticmethod
async def summary(query, content: str, llm_client: RedBearLLM):
system_prompt = prompt_manager.render(
name="retrieval_summary"
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"<query>{query}</query><content>{content}</content>"},
]
try:
summary = await llm_client.ainvoke(messages) | StructResponse(mode='str')
return summary
except:
logger.error("Failed to generate reply summary, returning original content", exc_info=True)
return content
def summary(content: str, llm_client: RedBearLLM):
return
@staticmethod
async def verify(query, content: str, llm_client: RedBearLLM):
def verify(content: str, llm_client: RedBearLLM):
return

View File

@@ -14,8 +14,6 @@ from app.core.rag.nlp.search import knowledge_retrieval
from app.repositories import knowledge_repository
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.read_services.search_engine.result_builder import MetadataBuilder
from app.repositories.neo4j.graph_search import search_user_metadata
logger = logging.getLogger(__name__)
@@ -179,22 +177,6 @@ class Neo4jSearchService:
memories.sort(key=lambda x: x.score, reverse=True)
return MemorySearchResult(memories=memories[:limit])
async def memory_l0(self) -> Memory:
async with Neo4jConnector() as connector:
end_user_id = self.ctx.end_user_id
user_meta = await search_user_metadata(connector, end_user_id)
metadata = MetadataBuilder(user_meta)
memory = Memory(
score=1,
source=Neo4jNodeType.EXTRACTEDENTITY,
query='',
id=end_user_id,
content=metadata.content,
data=metadata.data,
)
return memory
class RAGSearchService:
def __init__(self, ctx: MemoryContext, db: Session):

View File

@@ -42,15 +42,7 @@ class ChunkBuilder(BaseBuilder):
@property
def content(self) -> str:
parts = ["<chunk>"]
fields = [
("content", self.record.get("content", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</chunk>")
return "".join(parts)
return self.record.get("content")
class StatementBuiler(BaseBuilder):
@@ -65,15 +57,7 @@ class StatementBuiler(BaseBuilder):
@property
def content(self) -> str:
parts = ["<statement>"]
fields = [
("statement", self.record.get("statement", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</statement>")
return "".join(parts)
return self.record.get("statement")
class EntityBuilder(BaseBuilder):
@@ -89,16 +73,10 @@ class EntityBuilder(BaseBuilder):
@property
def content(self) -> str:
parts = ["<entity>"]
fields = [
("name", self.record.get("name", "")),
("description", self.record.get("description", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</entity>")
return "".join(parts)
return (f"<entity>"
f"<name>{self.record.get("name")}<name>"
f"<description>{self.record.get("description")}</description>"
f"</entity>")
class SummaryBuilder(BaseBuilder):
@@ -113,15 +91,7 @@ class SummaryBuilder(BaseBuilder):
@property
def content(self) -> str:
parts = ["<summary>"]
fields = [
("content", self.record.get("content", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</summary>")
return "".join(parts)
return self.record.get("content")
class PerceptualBuilder(BaseBuilder):
@@ -144,21 +114,15 @@ class PerceptualBuilder(BaseBuilder):
@property
def content(self) -> str:
parts = ["<history-file-info>"]
fields = [
("file-name", self.record.get("file_name", "")),
("file-path", self.record.get("file_path", "")),
("summary", self.record.get("summary", "")),
("topic", self.record.get("topic", "")),
("domain", self.record.get("domain", "")),
("keywords", self.record.get("keywords", [])),
("file-type", self.record.get("file_type", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</history-file-info>")
return "".join(parts)
return ("<history-file-info>"
f"<file-name>{self.record.get('file_name')}</file-name>"
f"<file-path>{self.record.get('file_path')}</file-path>"
f"<summary>{self.record.get('summary')}</summary>"
f"<topic>{self.record.get('topic')}</topic>"
f"<domain>{self.record.get('domain')}</domain>"
f"<keywords>{self.record.get('keywords')}</keywords>"
f"<file-type>{self.record.get('file_type')}</file-type>"
"</history-file-info>")
class CommunityBuilder(BaseBuilder):
@@ -173,54 +137,7 @@ class CommunityBuilder(BaseBuilder):
@property
def content(self) -> str:
parts = ["<community>"]
fields = [
("content", self.record.get("content", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</community>")
return "".join(parts)
class MetadataBuilder(BaseBuilder):
@property
def data(self) -> dict:
return {
"id": self.record.get("id", ""),
"aliases_name": self.record.get("aliases", []) or [],
"description": self.record.get("description", ""),
"anchors": self.record.get("anchors", []) or [],
"beliefs_or_stances": self.record.get("beliefs_or_stances", []) or [],
"core_facts": self.record.get("core_facts", []) or [],
"events": self.record.get("events", []) or [],
"goals": self.record.get("goals", []) or [],
"interests": self.record.get("interests", []) or [],
"relations": self.record.get("relations", []) or [],
"traits": self.record.get("traits", []) or [],
}
@property
def content(self) -> str:
parts = ["<user-info>"]
fields = [
("description", self.record.get("description", "")),
("aliases", self.record.get("aliases", [])),
("anchors", self.record.get("anchors", [])),
("beliefs_or_stances", self.record.get("beliefs_or_stances", [])),
("core_facts", self.record.get("core_facts", [])),
("events", self.record.get("events", [])),
("goals", self.record.get("goals", [])),
("interests", self.record.get("interests", [])),
("relations", self.record.get("relations", [])),
("traits", self.record.get("traits", [])),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</user-info>")
return "".join(parts)
return self.record.get("content")
def data_builder_factory(node_type, data: dict) -> T:

View File

@@ -20,7 +20,6 @@ from uuid import UUID
from datetime import datetime
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -146,22 +145,7 @@ class ForgettingScheduler:
}
logger.info("没有可遗忘的节点对,遗忘周期结束")
# 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count
if end_user_id:
try:
node_count = await sync_end_user_memory_count_from_neo4j(
end_user_id,
self.connector,
)
logger.info(
f"[MemoryCount] 遗忘后同步 memory_count: "
f"end_user_id={end_user_id}, count={node_count}"
)
except Exception as e:
logger.warning(
f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
exc_info=True,
)
return report
# 步骤3按激活值排序激活值最低的优先
@@ -318,22 +302,7 @@ class ForgettingScheduler:
f"({reduction_rate:.2%}), "
f"耗时 {duration:.2f}"
)
# 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count
if end_user_id:
try:
node_count = await sync_end_user_memory_count_from_neo4j(
end_user_id,
self.connector,
)
logger.info(
f"[MemoryCount] 遗忘后同步 memory_count: "
f"end_user_id={end_user_id}, count={node_count}"
)
except Exception as e:
logger.warning(
f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
exc_info=True,
)
return report
except Exception as e:

View File

@@ -17,7 +17,7 @@ async def handle_response(response: type[BaseModel]) -> dict:
class StructResponse:
def __init__(self, mode: Literal["json", "pydantic", "str"], model: Type[BaseModel] = None):
def __init__(self, mode: Literal["json", "pydantic"], model: Type[BaseModel] = None):
self.mode = mode
if mode == "pydantic" and model is None:
raise ValueError("Pydantic model is required")
@@ -31,8 +31,6 @@ class StructResponse:
for block in other.content_blocks:
if block.get("type") == "text":
text += block.get("text", "")
if self.mode == "str":
return text
fixed_json = json_repair.repair_json(text, return_objects=True)
if self.mode == "json":
return fixed_json

View File

@@ -1,36 +0,0 @@
from uuid import UUID
from app.db import get_db_context
from app.models.end_user_model import EndUser
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def sync_end_user_memory_count_from_neo4j(
end_user_id: str,
connector: Neo4jConnector,
) -> int:
"""
Sync one end user's Neo4j memory node count to PostgreSQL.
The caller owns the Neo4j connector lifecycle.
"""
if not end_user_id:
return 0
result = await connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=[end_user_id],
)
node_count = int(result[0]["total"]) if result else 0
with get_db_context() as db:
db.query(EndUser).filter(
EndUser.id == UUID(end_user_id)
).update(
{"memory_count": node_count},
synchronize_session=False,
)
db.commit()
return node_count

View File

@@ -216,7 +216,7 @@ class RedBearModelFactory:
# 深度思考模式Claude 3.7 Sonnet 等支持思考的模型
# 通过 additional_model_request_fields 传递 thinking 块关闭时不传Bedrock 无 disabled 选项)
if config.deep_thinking:
budget = config.thinking_budget_tokens or 1024
budget = config.thinking_budget_tokens or 10000
params["additional_model_request_fields"] = {
"thinking": {"type": "enabled", "budget_tokens": budget}
}

View File

@@ -2,6 +2,7 @@
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/10 13:33
import json
import logging
import re
import uuid
@@ -141,9 +142,10 @@ class GraphBuilder:
for node_info in source_nodes:
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
branch_nodes.append(
(node_info["id"], node_info["branch"])
)
if node_info.get("branch") is not None:
branch_nodes.append(
(node_info["id"], node_info["branch"])
)
else:
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
output_nodes.append(node_info["id"])
@@ -314,9 +316,12 @@ class GraphBuilder:
for idx in range(len(related_edge)):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
# For LLM nodes, use branch_signal field for routing (output is dynamic text)
# For other branch nodes (e.g. HTTP), use output field
route_field = "branch_signal" if node_type == NodeType.LLM else "output"
related_edge[idx]['condition'] = (
f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(related_edge[idx]['label'])}"
)
if node_instance:
# Wrap node's run method to avoid closure issues

View File

@@ -18,10 +18,17 @@ class AssignerNode(BaseNode):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None
self._input_data: dict[str, Any] | None = None
def _output_types(self) -> dict[str, VariableType]:
return {}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取节点输入,如果有缓存的执行前数据则使用缓存"""
if self._input_data is not None:
return self._input_data
return {"config": self._resolve_config(self.config, variable_pool)}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the assignment operation defined by this node.
@@ -34,6 +41,9 @@ class AssignerNode(BaseNode):
Returns:
None or the result of the assignment operation.
"""
# 在执行前提取并缓存输入数据(捕获执行前的变量值)
self._input_data = {"config": self._resolve_config(self.config, variable_pool)}
# Initialize a variable pool for accessing conversation, node, and system variables
self.typed_config = AssignerNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} 开始执行")

View File

@@ -14,7 +14,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import BaseNode
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.config import settings
logger = logging.getLogger(__name__)
@@ -132,7 +131,7 @@ class CodeNode(BaseNode):
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(
f"{settings.SANDBOX_URL}/v1/sandbox/run",
"http://sandbox:8194/v1/sandbox/run",
headers={
"x-api-key": 'redbear-sandbox'
},

View File

@@ -185,7 +185,7 @@ class DocExtractorNode(BaseNode):
mime_type=f"image/{ext}",
is_file=True,
).model_dump())
text = text + f"\n{placeholder}: <img src=\"{url}\" data-url=\"{url}\">"
text = text + f"\n{placeholder}: {url}"
except Exception as e:
logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}")

View File

@@ -31,7 +31,7 @@ class NodeType(StrEnum):
NOTES = "notes"
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM})
class ComparisonOperator(StrEnum):

View File

@@ -6,6 +6,7 @@ import uuid
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.nodes.enums import HttpErrorHandle
from app.core.workflow.variable.base_variable import VariableType
@@ -49,6 +50,20 @@ class MemoryWindowSetting(BaseModel):
)
class LLMErrorHandleConfig(BaseModel):
"""LLM 异常处理配置"""
method: HttpErrorHandle = Field(
default=HttpErrorHandle.NONE,
description="异常处理策略:'none' 抛出异常, 'default' 返回默认值, 'branch' 走异常分支",
)
output: str = Field(
default="",
description="LLM 异常时返回的默认输出文本method=default 时生效)",
)
class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置
@@ -152,6 +167,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="输出变量定义(自动生成,通常不需要修改)"
)
error_handle: LLMErrorHandleConfig = Field(
default_factory=LLMErrorHandleConfig,
description="LLM 异常处理配置",
)
@field_validator("messages", "prompt")
@classmethod
def validate_input_mode(cls, v):

View File

@@ -15,6 +15,7 @@ from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpErrorHandle
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_context
@@ -76,7 +77,7 @@ class LLMNode(BaseNode):
self.messages = []
def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING}
return {"output": VariableType.STRING, "branch_signal": VariableType.STRING}
def _render_context(self, message: str, variable_pool: VariablePool):
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
@@ -239,7 +240,7 @@ class LLMNode(BaseNode):
return llm
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
async def execute(self, state: WorkflowState, variable_pool: VariablePool):
"""非流式执行 LLM 调用
Args:
@@ -247,28 +248,36 @@ class LLMNode(BaseNode):
variable_pool: 变量池
Returns:
LLM 响应消息
dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success,
{"llm_result": None, "branch_signal": "ERROR"} on branch error
"""
# self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, False)
try:
# self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, False)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(self.messages)
# 提取内容
if hasattr(response, 'content'):
content = self.process_model_output(response.content)
else:
content = str(response)
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(self.messages)
# 提取内容
if hasattr(response, 'content'):
content = self.process_model_output(response.content)
else:
content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据
return AIMessage(content=content, response_metadata={
**response.response_metadata,
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
})
# 返回 AIMessage包含响应元数据
return {
"llm_result": AIMessage(content=content, response_metadata={
**response.response_metadata,
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
}),
"branch_signal": "SUCCESS",
}
except Exception as e:
logger.error(f"节点 {self.node_id} LLM 调用失败: {e}")
return self._handle_llm_error(e)
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
@@ -286,16 +295,36 @@ class LLMNode(BaseNode):
}
}
def _extract_output(self, business_result: Any) -> str:
""" AIMessage 中提取文本内容"""
def _extract_output(self, business_result: Any) -> dict:
"""业务结果中提取输出变量
支持新旧两种格式:
- 新格式:{"llm_result": AIMessage, "branch_signal": "SUCCESS"}
- 旧格式AIMessage向后兼容
"""
if isinstance(business_result, dict) and "branch_signal" in business_result:
llm_result = business_result.get("llm_result")
if isinstance(llm_result, AIMessage):
return {
"output": llm_result.content,
"branch_signal": business_result["branch_signal"],
}
return {
"output": str(llm_result) if llm_result else "",
"branch_signal": business_result["branch_signal"],
}
# 旧格式向后兼容
if isinstance(business_result, AIMessage):
return business_result.content
return str(business_result)
return {"output": business_result.content, "branch_signal": "SUCCESS"}
return {"output": str(business_result), "branch_signal": "SUCCESS"}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
""" AIMessage 中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
usage = business_result.response_metadata.get('token_usage')
"""业务结果中提取 token 使用情况"""
llm_result = business_result
if isinstance(business_result, dict):
llm_result = business_result.get("llm_result", business_result)
if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'):
usage = llm_result.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('input_tokens', 0),
@@ -304,6 +333,44 @@ class LLMNode(BaseNode):
}
return None
def _handle_llm_error(self, error: Exception) -> dict:
"""处理 LLM 调用异常,根据 error_handle 配置决定行为
Args:
error: LLM 调用中捕获的异常
Returns:
dict: {"llm_result": None, "branch_signal": "ERROR"} for branch mode,
or default output for default mode
Raises:
原异常(当 error_handle.method 为 NONE 时)
"""
if self.typed_config is None:
raise error
match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
raise error
case HttpErrorHandle.DEFAULT:
logger.warning(
f"节点 {self.node_id}: LLM 调用失败,返回默认输出"
)
default_output = self.typed_config.error_handle.output or ""
return {
"llm_result": AIMessage(content=default_output, response_metadata={}),
"branch_signal": "SUCCESS",
}
case HttpErrorHandle.BRANCH:
logger.warning(
f"节点 {self.node_id}: LLM 调用失败,切换到异常处理分支"
)
return {
"llm_result": None,
"branch_signal": "ERROR",
}
raise error
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""流式执行 LLM 调用
@@ -316,54 +383,58 @@ class LLMNode(BaseNode):
"""
self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, True)
try:
llm = await self._prepare_llm(state, variable_pool, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# 累积完整响应
full_response = ""
chunk_count = 0
# 累积完整响应
full_response = ""
chunk_count = 0
# 调用 LLM流式支持字符串或消息列表
last_meta_data = {}
last_usage_metadata = {}
async for chunk in llm.astream(self.messages):
if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content)
else:
content = str(chunk)
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
last_meta_data = chunk.response_metadata
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
last_usage_metadata = chunk.usage_metadata
# 调用 LLM流式支持字符串或消息列表
last_meta_data = {}
last_usage_metadata = {}
async for chunk in llm.astream(self.messages):
if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content)
else:
content = str(chunk)
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
last_meta_data = chunk.response_metadata
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
last_usage_metadata = chunk.usage_metadata
# 只有当内容不为空时才处理
if content:
full_response += content
chunk_count += 1
# 只有当内容不为空时才处理
if content:
full_response += content
chunk_count += 1
# 流式返回每个文本片段
yield {
"__final__": False,
"chunk": content
}
# 流式返回每个文本片段
yield {
"__final__": False,
"chunk": content
}
yield {
"__final__": False,
"chunk": "",
"done": True
}
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
yield {
"__final__": False,
"chunk": "",
"done": True
}
)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# yield 完成标记
yield {"__final__": True, "result": final_message}
# 构建完整的 AIMessage包含元数据
final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
}
)
# yield 完成标记
yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}}
except Exception as e:
logger.error(f"节点 {self.node_id} LLM 流式调用失败: {e}")
error_result = self._handle_llm_error(e)
yield {"__final__": True, "result": error_result}

View File

@@ -40,7 +40,6 @@ class MemoryReadNode(BaseNode):
end_user_id=end_user_id,
user_rag_memory_id=state["user_rag_memory_id"],
)
# TODO: Historical Messages -> Used to refer to coreference resolution
search_result = await memory_service.read(
self._render_template(self.typed_config.message, variable_pool),
search_switch=SearchStrategy(self.typed_config.search_switch)

View File

@@ -1,7 +1,7 @@
import datetime
import uuid
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text
from sqlalchemy import Column, DateTime, ForeignKey, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@@ -38,15 +38,6 @@ class EndUser(Base):
comment="关联的记忆配置ID"
)
memory_count = Column(
Integer,
nullable=False,
default=0,
server_default="0",
index=True,
comment="记忆节点总数",
)
# 用户摘要四个维度 - User Summary Four Dimensions
user_summary = Column(Text, nullable=True, comment="缓存的用户摘要(基本介绍)")
personality_traits = Column(Text, nullable=True, comment="性格特点")

View File

@@ -1296,7 +1296,6 @@ RETURN e.id AS id,
e.name AS name,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
e.description AS description,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
e.last_access_time AS last_access_time,
@@ -1480,21 +1479,6 @@ ORDER BY score DESC
LIMIT $limit
"""
SEARCH_USER_METADATA = """
MATCH (n:ExtractedEntity)
WHERE (n.end_user_id = $end_user_id AND n.entity_type ='用户')
RETURN n.description AS description,
n.aliases AS aliases,
n.anchors AS anchors,
n.beliefs_or_stances AS beliefs_or_stances,
n.core_facts AS core_facts,
n.events AS events,
n.goals AS goals,
n.interests AS interests,
n.relations AS relations,
n.traits AS traits
"""
FULLTEXT_QUERY_CYPHER_MAPPING = {
Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_KEYWORD,
Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_NAME_OR_ALIAS,

View File

@@ -27,9 +27,9 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_PERCEPTUAL_BY_USER_ID,
FULLTEXT_QUERY_CYPHER_MAPPING,
USER_ID_QUERY_CYPHER_MAPPING,
NODE_ID_QUERY_CYPHER_MAPPING,
SEARCH_USER_METADATA
NODE_ID_QUERY_CYPHER_MAPPING
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
@@ -513,7 +513,7 @@ async def search_graph_by_embedding(
task_keys = []
for node_type in include:
tasks.append(search_by_embedding(connector, node_type, end_user_id, embedding, limit * 2))
tasks.append(search_by_embedding(connector, node_type, end_user_id, embedding, limit*2))
task_keys.append(node_type.value)
task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -557,17 +557,6 @@ async def search_graph_by_embedding(
return results
async def search_user_metadata(
connector: Neo4jConnector,
end_user_id: str
) -> dict:
user_info = await connector.execute_query(
SEARCH_USER_METADATA,
end_user_id=end_user_id
)
return user_info[0] if user_info else {}
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector,
end_user_id: str,

View File

@@ -250,7 +250,7 @@ class ModelParameters(BaseModel):
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
stop: Optional[List[str]] = Field(default=None, description="停止序列")
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1, le=131072, description="深度思考 token 预算(仅部分模型支持)")
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")

View File

@@ -19,6 +19,4 @@ class EndUser(BaseModel):
# 用户摘要和洞察更新时间
user_summary_updated_at: Optional[datetime.datetime] = Field(description="用户摘要最后更新时间", default=None)
memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None)
#用户记忆节点总数Neo4j模式
memory_count: int = Field(description="记忆节点总数", default=0)
memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None)

View File

@@ -1,15 +1,14 @@
import uuid
from abc import ABC
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel
class UserInput(BaseModel):
message: str
history: list[dict]
search_switch: str
end_user_id: str
session_id: uuid.UUID = Field(default_factory=uuid.uuid4)
config_id: Optional[str] = None

View File

@@ -161,10 +161,7 @@ class AppChatService:
f.type == FileType.DOCUMENT for f in files
):
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
)
# 创建 LangChain Agent
@@ -451,10 +448,7 @@ class AppChatService:
):
from langchain.agents import create_agent
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
)
# 创建 LangChain Agent

View File

@@ -108,7 +108,6 @@ def create_long_term_memory_tool(
try:
with get_db_context() as db:
memory_service = MemoryService(db, config_id, end_user_id)
# TODO: Historical Messages -> Used to refer to coreference resolution
search_result = asyncio.run(memory_service.read(question, SearchStrategy.QUICK))
# memory_content = asyncio.run(
@@ -651,10 +650,7 @@ class AgentRunService:
)
if has_doc_with_images:
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
)
agent = LangChainAgent(
@@ -928,10 +924,7 @@ class AgentRunService:
)
if has_doc_with_images:
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
)
# 创建 LangChain Agent

View File

@@ -1,5 +1,5 @@
from sqlalchemy.orm import Session
from sqlalchemy import desc, nullslast, or_, cast, String, func
from sqlalchemy import desc, nullslast, or_, and_, cast, String
from typing import List, Optional, Dict, Any
import uuid
from fastapi import HTTPException
@@ -102,7 +102,6 @@ def get_workspace_end_users_paginated(
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
返回结果按 created_at 从新到旧排序NULL 值排在最后)
固定过滤 memory_count > 0 的宿主,保证分页基于“有记忆宿主”集合计算。
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
Args:
@@ -121,8 +120,7 @@ def get_workspace_end_users_paginated(
try:
# 构建基础查询
base_query = db.query(EndUserModel).filter(
EndUserModel.workspace_id == workspace_id,
EndUserModel.memory_count > 0 , # 只查询有记忆的宿主
EndUserModel.workspace_id == workspace_id
)
# 构建搜索条件过滤空字符串和None
@@ -130,13 +128,20 @@ def get_workspace_end_users_paginated(
if keyword:
keyword_pattern = f"%{keyword}%"
# other_name 匹配始终生效id 匹配仅对 other_name 为空的记录生效
base_query = base_query.filter(
or_(
EndUserModel.other_name.ilike(keyword_pattern),
cast(EndUserModel.id, String).ilike(keyword_pattern),
and_(
or_(
EndUserModel.other_name.is_(None),
EndUserModel.other_name == "",
),
cast(EndUserModel.id, String).ilike(keyword_pattern),
),
)
)
business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_name id")
business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_nameother_name 为空时匹配 id")
# 获取总记录数
total = base_query.count()
@@ -164,98 +169,6 @@ def get_workspace_end_users_paginated(
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspace_end_users_paginated_rag(
db: Session,
workspace_id: uuid.UUID,
current_user: User,
page: int,
pagesize: int,
keyword: Optional[str] = None,
) -> Dict[str, Any]:
"""RAG 模式宿主列表分页。
RAG 记忆数量以 documents.chunk_num 为准:
- file_name = end_user_id + ".txt"
- 只统计当前 workspace 下 permission_id="Memory" 的用户记忆知识库
- 在 SQL 层过滤 chunk 总数为 0 的宿主,保证分页准确
"""
business_logger.info(
f"获取 RAG 宿主列表(分页): workspace_id={workspace_id}, "
f"keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}"
)
try:
from app.models.document_model import Document
from app.models.knowledge_model import Knowledge
chunk_subquery = (
db.query(
Document.file_name.label("file_name"),
func.coalesce(func.sum(Document.chunk_num), 0).label("memory_count"),
)
.join(Knowledge, Document.kb_id == Knowledge.id)
.filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1,
Knowledge.permission_id == "Memory",
Document.status == 1,
)
.group_by(Document.file_name)
.subquery()
)
base_query = (
db.query(
EndUserModel,
chunk_subquery.c.memory_count.label("memory_count"),
)
.join(
chunk_subquery,
chunk_subquery.c.file_name == func.concat(cast(EndUserModel.id, String), ".txt"),
)
.filter(
EndUserModel.workspace_id == workspace_id,
chunk_subquery.c.memory_count > 0,
)
)
keyword = keyword.strip() if keyword else None
if keyword:
keyword_pattern = f"%{keyword}%"
base_query = base_query.filter(
or_(
EndUserModel.other_name.ilike(keyword_pattern),
cast(EndUserModel.id, String).ilike(keyword_pattern),
)
)
total = base_query.count()
if total == 0:
business_logger.info("RAG 模式下没有符合条件的宿主")
return {"items": [], "total": 0}
rows = base_query.order_by(
nullslast(desc(EndUserModel.created_at)),
desc(EndUserModel.id),
).offset((page - 1) * pagesize).limit(pagesize).all()
items = []
for end_user_orm, memory_count in rows:
items.append({
"end_user": EndUserSchema.model_validate(end_user_orm),
"memory_count": int(memory_count or 0),
})
business_logger.info(f"成功获取 RAG 宿主记录 {len(items)} 条,总计 {total}")
return {"items": items, "total": total}
except HTTPException:
raise
except Exception as e:
business_logger.error(
f"获取 RAG 宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}"
)
raise
def get_workspace_memory_increment(
db: Session,

View File

@@ -400,7 +400,7 @@ class MultimodalService:
# 在文本内容中追加图片位置标记
if result and result[-1].get("type") in ("text", "document"):
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: <img src=\"{img_url}\" data-url=\"{img_url}\">"
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: {img_url}"
# 将图片以视觉格式追加到消息内容中
img_file = FileInput(
type=FileType.IMAGE,

View File

@@ -1,13 +1,13 @@
{% raw %}You are a professional information extraction system.
Your task is to analyze the provided file content and generate structured metadata.
Your task is to analyze the provided document content and generate structured metadata.
Extract the following fields:
* **summary**: A concise summary of the file in 35 sentences.
* **keywords**: 510 important keywords or key phrases that best represent the file. This field MUST be a JSON array of strings.
* **topic**: The primary topic of the file expressed as a short phrase (38 words).
* **domain**: The broader knowledge domain or field the file belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
* **summary**: A concise summary of the document in 24 sentences.
* **keywords**: 510 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings.
* **topic**: The primary topic of the document expressed as a short phrase (38 words).
* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
STRICT RULES:
@@ -28,7 +28,7 @@ STRICT RULES:
{% endif %}
{% raw %}
6. `keywords` MUST be a JSON array of strings.
7. If the file content is insufficient, infer the best possible answer based on context.
7. If the document content is insufficient, infer the best possible answer based on context.
8. Ensure the JSON is syntactically correct.
{% endraw %}
9. Output using the language {{ language }}
@@ -50,4 +50,4 @@ Required JSON format:
{% raw %}
}
Now analyze the following file and return the JSON result.{% endraw %}
Now analyze the following document and return the JSON result.{% endraw %}

View File

@@ -554,16 +554,13 @@ class WorkflowService:
}
}
case "workflow_end":
data = {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
if "citations" in payload and payload["citations"]:
data["citations"] = payload["citations"]
return {
"event": "end",
"data": data
"data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
}
case "node_start" | "node_end" | "node_error" | "cycle_item":
return None

View File

@@ -20,7 +20,6 @@ from app.models.workspace_model import (
)
from app.repositories import workspace_repository
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
from app.services.session_service import SessionService
from app.schemas.workspace_schema import (
InviteAcceptRequest,
InviteValidateResponse,
@@ -59,7 +58,7 @@ def switch_workspace(
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
async def delete_workspace_member(
def delete_workspace_member(
db: Session,
workspace_id: uuid.UUID,
member_id: uuid.UUID,
@@ -77,29 +76,10 @@ async def delete_workspace_member(
BizCode.WORKSPACE_NOT_FOUND)
try:
deleted_user = workspace_member.user
workspace_member.is_active = False
deleted_user.current_workspace_id = None
# 若被删除成员不是超级管理员且没有其他可用工作空间,则禁用该用户
if not deleted_user.is_superuser:
remaining = (
db.query(WorkspaceMember)
.filter(
WorkspaceMember.user_id == deleted_user.id,
WorkspaceMember.workspace_id != workspace_id,
WorkspaceMember.is_active.is_(True),
)
.count()
)
if remaining == 0:
deleted_user.is_active = False
workspace_member.user.current_workspace_id = None
db.commit()
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
# 使被删除成员的所有 token 立即失效
await SessionService.invalidate_all_user_tokens(str(workspace_member.user_id))
except Exception as e:
db.rollback()
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")

View File

@@ -1,77 +0,0 @@
import json
import logging
import redis.asyncio as redis
from app.aioRedis import get_redis_connection
logger = logging.getLogger(__name__)
DEFAULT_TTL = 3600
class ChatSessionCache:
"""Cache user-AI conversation history in Redis with TTL-based expiry.
Usage::
cache = ChatSessionCache(session_id="user_123")
await cache.append("user", "Hello")
await cache.append("assistant", "Hi there!")
history = await cache.get_history()
"""
def __init__(self, session_id: str, ttl: int = DEFAULT_TTL):
self.session_id = session_id
self.ttl = ttl
self._key = f"chat:session:{session_id}"
@staticmethod
async def _client() -> redis.StrictRedis:
return await get_redis_connection()
async def append(self, role: str, content: str) -> None:
r = await self._client()
entry = json.dumps({"role": role, "content": content}, ensure_ascii=False)
await r.rpush(self._key, entry)
await r.expire(self._key, self.ttl)
async def append_many(self, messages: list[dict[str, str]]) -> None:
"""Batch append messages. Each dict should have ``role`` and ``content`` keys."""
if not messages:
return
r = await self._client()
entries = [
json.dumps(m, ensure_ascii=False)
for m in messages
if "role" in m and "content" in m
]
if entries:
await r.rpush(self._key, *entries)
await r.expire(self._key, self.ttl)
async def get_history(self) -> list[dict[str, str]]:
r = await self._client()
raw = await r.lrange(self._key, 0, -1)
return [json.loads(item) for item in raw]
async def get_history_text(self, user_label: str = "User", ai_label: str = "Assistant") -> str:
"""Return conversation as a formatted text block."""
history = await self.get_history()
lines = []
for msg in history:
role = msg.get("role", "")
content = msg.get("content", "")
label = user_label if role == "user" else ai_label if role == "assistant" else role
lines.append(f"{label}: {content}")
return "\n".join(lines)
async def reset(self) -> None:
"""Delete the session from Redis."""
r = await self._client()
await r.delete(self._key)
async def touch(self) -> None:
"""Refresh the TTL without modifying data."""
r = await self._client()
await r.expire(self._key, self.ttl)

View File

@@ -1,47 +0,0 @@
"""202604271530
Revision ID: 1f85dce125e5
Revises: 4e89970f9e7c
Create Date: 2026-04-27 15:30:35.614679
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '1f85dce125e5'
down_revision: Union[str, None] = '4e89970f9e7c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('files', sa.Column('file_key', sa.String(length=512), nullable=True, comment='storage file key for FileStorageService'))
op.create_index(op.f('ix_files_file_key'), 'files', ['file_key'], unique=False)
op.alter_column('model_configs', 'capability',
existing_type=postgresql.ARRAY(sa.VARCHAR()),
comment="模型能力列表(如['vision', 'audio', 'video', 'thinking']",
existing_comment="模型能力列表(如['vision', 'audio', 'video']",
existing_nullable=False)
# ### end Alembic commands ###
op.execute("""
UPDATE files
SET file_key = 'kb/' || kb_id::text || '/' || parent_id::text || '/' || id::text || file_ext
WHERE file_ext != 'folder' AND file_key IS NULL
""")
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('model_configs', 'capability',
existing_type=postgresql.ARRAY(sa.VARCHAR()),
comment="模型能力列表(如['vision', 'audio', 'video']",
existing_comment="模型能力列表(如['vision', 'audio', 'video', 'thinking']",
existing_nullable=False)
op.drop_index(op.f('ix_files_file_key'), table_name='files')
op.drop_column('files', 'file_key')
# ### end Alembic commands ###

View File

@@ -1,139 +0,0 @@
"""202604291755
Revision ID: 37e2a73b28c4
Revises: e2d60c6d1a1a
Create Date: 2026-04-29 18:52:35.686290
"""
from typing import Dict, List, Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '37e2a73b28c4'
down_revision: Union[str, None] = 'e2d60c6d1a1a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
BATCH_SIZE = 500
def _chunked(values: List[str], size: int) -> List[List[str]]:
return [values[index:index + size] for index in range(0, len(values), size)]
def _load_neo4j_end_user_ids(connection) -> List[str]:
"""加载所有需要从 Neo4j 同步 memory_count 的宿主。
RAG 工作空间的记忆数量以 documents.chunk_num 为准,不写入 end_users.memory_count。
"""
rows = connection.execute(sa.text("""
SELECT eu.id::text AS end_user_id
FROM end_users eu
JOIN workspaces w ON eu.workspace_id = w.id
WHERE w.storage_type IS NULL OR w.storage_type <> 'rag'
""")).all()
return [row[0] for row in rows]
async def _fetch_neo4j_counts(end_user_ids: List[str]) -> Dict[str, int]:
if not end_user_ids:
return {}
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
connector = Neo4jConnector()
try:
result = await connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=end_user_ids,
)
finally:
await connector.close()
counts = {str(row["user_id"]): int(row["total"]) for row in result}
for end_user_id in end_user_ids:
counts.setdefault(end_user_id, 0)
return counts
def _update_memory_counts(connection, counts: Dict[str, int]) -> int:
updated = 0
for end_user_id, memory_count in counts.items():
result = connection.execute(
sa.text("""
UPDATE end_users
SET memory_count = :memory_count
WHERE id = CAST(:end_user_id AS uuid)
"""),
{
"end_user_id": end_user_id,
"memory_count": memory_count,
},
)
updated += result.rowcount or 0
return updated
def _sync_memory_count_from_neo4j() -> None:
"""迁移时初始化 Neo4j 模式宿主的 memory_count。
"""
import asyncio
print("[memory_count] 开始同步 Neo4j 模式宿主 memory_count")
connection = op.get_bind()
target_ids = _load_neo4j_end_user_ids(connection)
if not target_ids:
print("[memory_count] 没有需要同步的 Neo4j 模式宿主")
return
print(
f"[memory_count] 待同步宿主数量: {len(target_ids)}, "
f"batch_size={BATCH_SIZE}"
)
total_updated = 0
batches = _chunked(target_ids, BATCH_SIZE)
for batch_index, batch_ids in enumerate(batches, start=1):
print(
f"[memory_count] 正在查询 Neo4j: "
f"batch={batch_index}/{len(batches)}, size={len(batch_ids)}"
)
counts = asyncio.run(_fetch_neo4j_counts(batch_ids))
total_updated += _update_memory_counts(connection, counts)
print(
f"[memory_count] 已写入 PostgreSQL: "
f"updated={total_updated}/{len(target_ids)}"
)
print(
f"[memory_count] Neo4j 模式宿主同步完成: "
f"total={len(target_ids)}, updated={total_updated}"
)
def upgrade() -> None:
op.add_column(
'end_users',
sa.Column(
'memory_count',
sa.Integer(),
server_default='0',
nullable=False,
comment='记忆节点总数',
),
)
_sync_memory_count_from_neo4j()
op.create_index(
op.f('ix_end_users_memory_count'),
'end_users',
['memory_count'],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f('ix_end_users_memory_count'), table_name='end_users')
op.drop_column('end_users', 'memory_count')

View File

@@ -1,34 +0,0 @@
"""202604281230
Revision ID: e2d60c6d1a1a
Revises: 1f85dce125e5
Create Date: 2026-04-28 12:32:01.643954
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'e2d60c6d1a1a'
down_revision: Union[str, None] = '1f85dce125e5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('tenants', 'api_ops_rate_limit')
op.drop_column('tenants', 'plan')
op.drop_column('tenants', 'plan_expired_at')
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('tenants', sa.Column('plan_expired_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=True))
op.add_column('tenants', sa.Column('plan', sa.VARCHAR(length=50), autoincrement=False, nullable=True))
op.add_column('tenants', sa.Column('api_ops_rate_limit', sa.VARCHAR(length=100), autoincrement=False, nullable=True))
# ### end Alembic commands ###

Binary file not shown.

Before

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 336 KiB

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 387 B

View File

@@ -1,13 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>勾选</title>
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="登录页面" transform="translate(-64, -611)" fill="#FFFFFF" fill-rule="nonzero">
<g id="编组-8" transform="translate(64, 608)">
<g id="勾选" transform="translate(0, 3)">
<path d="M12,0 C14.209139,0 16,1.790861 16,4 L16,12 C16,14.209139 14.209139,16 12,16 L4,16 C1.790861,16 0,14.209139 0,12 L0,4 C0,1.790861 1.790861,4.4408921e-16 4,0 L12,0 Z M11.9182266,4.80024782 C11.7273831,4.80024782 11.5444062,4.87629473 11.4097812,5.0115625 L6.552,9.86932813 L4.4284375,7.74489063 C4.29381317,7.60962766 4.11083967,7.53358379 3.92,7.53358379 C3.72916033,7.53358379 3.54618683,7.60962766 3.4115625,7.74489063 C3.27602096,7.87955071 3.19979999,8.06271883 3.19979999,8.25378125 C3.19979999,8.44484367 3.27602096,8.62801179 3.4115625,8.76267188 L6.0453125,11.3946719 C6.17993745,11.5299396 6.3629143,11.6059866 6.55375781,11.6059866 C6.74460132,11.6059866 6.92757818,11.5299396 7.06220312,11.3946719 L12.4311094,6.02667188 C12.5659036,5.89187668 12.6412595,5.70881589 12.6404302,5.51818919 C12.639587,5.3275625 12.5626279,5.14516989 12.4266562,5.0115625 C12.2920469,4.87629473 12.1090701,4.80024782 11.9182266,4.80024782 Z" id="形状结合"></path>
</g>
</g>
</g>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.8 KiB

View File

@@ -8,11 +8,12 @@ import { type FC, useRef, useEffect, useState } from 'react'
import clsx from 'clsx'
import Markdown from '@/components/Markdown'
import type { ChatContentProps } from './types'
import { Spin, Flex, Button } from 'antd'
import { Spin, Image, Flex, Button } from 'antd'
import { SoundOutlined } from '@ant-design/icons'
import { useTranslation } from 'react-i18next'
import MessageFiles from './MessageFiles'
import AudioPlayer from './AudioPlayer'
import VideoPlayer from './VideoPlayer'
const getFileUrl = (file: any) => {
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
@@ -148,7 +149,72 @@ const ChatContent: FC<ChatContentProps> = ({
{labelFormat(item)}
</div>
}
<MessageFiles files={item.meta_data?.files ?? []} contentClassNames={contentClassNames} onDownload={handleDownload} />
{item?.meta_data?.files && item.meta_data?.files.length > 0 && <Flex gap={8} vertical align="end" className="rb:mb-2!">
{item.meta_data?.files?.map((file) => {
if (file.type.includes('image')) {
return (
<div key={file.url || file.uid} className={`rb:inline-block rb:group rb:relative rb:rounded-lg ${contentClassNames}`}>
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
</div>
)
}
if (file.type.includes('video')) {
return (
<div key={file.url || file.uid} className="rb:w-50">
{/* <video src={getFileUrl(file)} controls className="rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" /> */}
<VideoPlayer key={file.url || file.uid} src={getFileUrl(file)} />
</div>
)
}
if (file.type.includes('audio')) {
return (
<div key={file.url || file.uid} className="rb:w-50">
<AudioPlayer key={file.url || file.uid} src={getFileUrl(file)} />
</div>
)
}
const documentType = (file.file_type || file.type)?.split('/')
return (
<Flex
key={file.url || file.uid}
align="center"
gap={10}
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
onClick={() => handleDownload(file)}
>
<div
className={clsx(
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
file.type?.includes('pdf')
? "rb:bg-[url('@/assets/images/file/pdf.svg')]"
: (file.type?.includes('excel') || file.type?.includes('spreadsheetml.sheet')) || file.type?.includes('xls') || file.type?.includes('xlsx')
? "rb:bg-[url('@/assets/images/file/excel.svg')]"
: file.type?.includes('csv')
? "rb:bg-[url('@/assets/images/file/csv.svg')]"
: file.type?.includes('html')
? "rb:bg-[url('@/assets/images/file/html.svg')]"
: file.type?.includes('json')
? "rb:bg-[url('@/assets/images/file/json.svg')]"
: file.type?.includes('ppt')
? "rb:bg-[url('@/assets/images/file/ppt.svg')]"
: file.type?.includes('markdown')
? "rb:bg-[url('@/assets/images/file/md.svg')]"
: file.type?.includes('text')
? "rb:bg-[url('@/assets/images/file/txt.svg')]"
: (file.type?.includes('doc') || file.type?.includes('docx') || file.type?.includes('word') || file.type?.includes('wordprocessingml.document'))
? "rb:bg-[url('@/assets/images/file/word.svg')]"
: "rb:bg-[url('@/assets/images/file/txt.svg')]"
)}
></div>
<div className="rb:flex-1 rb:w-32.5">
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{documentType?.[documentType.length - 1]} · {file.size}</div>
</div>
</Flex>
)
})}
</Flex>}
{/* Message bubble */}
<div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', {
// Error message style (content is null and not assistant message)

View File

@@ -1,87 +0,0 @@
import { Image, Flex } from 'antd'
import clsx from 'clsx'
import AudioPlayer from './AudioPlayer'
import VideoPlayer from './VideoPlayer'
const getFileUrl = (file: any) =>
file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
const DOC_ICONS: [string[], string][] = [
[['pdf'], "rb:bg-[url('@/assets/images/file/pdf.svg')]"],
[['excel', 'spreadsheetml.sheet', 'xls', 'xlsx'], "rb:bg-[url('@/assets/images/file/excel.svg')]"],
[['csv'], "rb:bg-[url('@/assets/images/file/csv.svg')]"],
[['html'], "rb:bg-[url('@/assets/images/file/html.svg')]"],
[['json'], "rb:bg-[url('@/assets/images/file/json.svg')]"],
[['ppt'], "rb:bg-[url('@/assets/images/file/ppt.svg')]"],
[['markdown'], "rb:bg-[url('@/assets/images/file/md.svg')]"],
[['text'], "rb:bg-[url('@/assets/images/file/txt.svg')]"],
[['doc', 'docx', 'word', 'wordprocessingml.document'], "rb:bg-[url('@/assets/images/file/word.svg')]"],
]
const getDocIcon = (parts: string[]) => {
const match = DOC_ICONS.find(([keys]) => keys.some(k => parts.includes(k)))
return match ? match[1] : "rb:bg-[url('@/assets/images/file/txt.svg')]"
}
interface MessageFilesProps {
files: any[]
contentClassNames?: string | Record<string, boolean>
onDownload: (file: any) => void
}
const MessageFiles = ({ files, contentClassNames, onDownload }: MessageFilesProps) => {
if (!files?.length) return null
return (
<Flex gap={8} vertical align="end" className="rb:mb-2!">
{files.map((file) => {
const key = file.url || file.uid
if (file.type.includes('image')) {
return (
<div key={key} className={clsx('rb:inline-block rb:group rb:relative rb:rounded-lg', contentClassNames)}>
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
</div>
)
}
if (file.type.includes('video')) {
return (
<div key={key} className="rb:w-50">
<VideoPlayer src={getFileUrl(file)} />
</div>
)
}
if (file.type.includes('audio')) {
return (
<div key={key} className="rb:w-50">
<AudioPlayer src={getFileUrl(file)} />
</div>
)
}
const documentType = (file.file_type || file.type)?.split('/') ?? []
return (
<Flex
key={key}
align="center"
gap={10}
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
onClick={() => onDownload(file)}
>
<div
className={clsx(
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
getDocIcon(documentType)
)}
/>
<div className="rb:flex-1 rb:w-32.5">
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">
{documentType?.[documentType.length - 1]} · {file.size}
</div>
</div>
</Flex>
)
})}
</Flex>
)
}
export default MessageFiles

View File

@@ -3,14 +3,14 @@ import { Popover, type PopoverProps } from 'antd'
import Tag, { type TagProps } from '@/components/Tag'
interface OverflowTagsProps {
items?: ReactNode[];
items: ReactNode[];
gap?: number;
numTagColor?: TagProps['color'];
numTag?: (num?: number) => ReactNode;
popoverProps?: PopoverProps | false;
}
const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => {
const OverflowTags = ({ items, gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => {
const containerRef = useRef<HTMLDivElement>(null)
const measureRef = useRef<HTMLDivElement>(null)
const [visibleCount, setVisibleCount] = useState(items.length)
@@ -20,7 +20,7 @@ const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, po
if (!measure || containerWidth === 0) return
const children = Array.from(measure.children) as HTMLElement[]
if (!children.length) { setVisibleCount(0); return }
if (!children.length) return
// last child is the sample +N tag
const extraTagWidth = (children[children.length - 1] as HTMLElement).offsetWidth

View File

@@ -399,7 +399,7 @@ const Menu: FC<{
className="rb:overflow-y-auto rb:flex-1!"
/>
{/* Return to space button for superusers */}
{source === 'space' &&
{user?.is_superuser && source === 'space' &&
<Flex gap={4} vertical className="rb:my-3! rb:mx-3!">
<Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" />
<Flex
@@ -412,18 +412,16 @@ const Menu: FC<{
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div>
{collapsed ? null : t('common.switchSpace')}
</Flex>
{user?.is_superuser &&
<Flex
gap={8}
align="center"
justify="start"
onClick={goToSpace}
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
>
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
{collapsed ? null : t('common.returnToSpace')}
</Flex>
}
<Flex
gap={8}
align="center"
justify="start"
onClick={goToSpace}
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
>
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
{collapsed ? null : t('common.returnToSpace')}
</Flex>
</Flex>
}
{source === 'manage' && subscription && !collapsed &&

View File

@@ -1538,7 +1538,6 @@ export const en = {
json_output: 'Support JSON formatted output',
thinking_budget_tokens: 'thinking budget tokens',
thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})",
thinking_budget_tokens_min_error: "Cannot be less than {{min}}",
logSearchPlaceholder: 'Search log content',
},
userMemory: {

View File

@@ -868,7 +868,6 @@ export const zh = {
json_output: '支持JSON格式化输出',
thinking_budget_tokens: '深度思考预算Token数',
thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})",
thinking_budget_tokens_min_error: "不能小于 {{min}}",
logSearchPlaceholder: '搜索日志内容',
},
table: {

View File

@@ -467,29 +467,4 @@ input:-webkit-autofill:active {
animation-name: onAutoFillStart;
animation-duration: 1ms;
}
@keyframes onAutoFillStart { from {} to {} }
/* Login input placeholder */
.login-input input::placeholder {
color: #A8A9AA !important;
}
.login-input {
border-color: #A8A9AA;
}
/* Login input hover/focus border */
.login-input:hover,
.login-input:focus-within {
border-color: #FFFFFF !important;
box-shadow: none !important;
}
/* Override browser autofill styles */
.login-input input:-webkit-autofill,
.login-input input:-webkit-autofill:hover,
.login-input input:-webkit-autofill:focus,
.login-input input:-webkit-autofill:active {
-webkit-box-shadow: 0 0 0px 1000px #0A0A0A inset !important;
-webkit-text-fill-color: #FFFFFF !important;
transition: background-color 5000s ease-in-out 0s !important;
}
@keyframes onAutoFillStart { from {} to {} }

View File

@@ -49,8 +49,6 @@ const configFields = [
{ key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 },
]
const minThinkingBudgetTokens = 128;
const defaultThinkingBudgetTokens = 1000;
const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({
refresh,
data,
@@ -110,7 +108,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
const newValues: ModelConfig = {
capability: (option as Model).capability,
deep_thinking: false,
thinking_budget_tokens: defaultThinkingBudgetTokens,
thinking_budget_tokens: undefined,
json_output: false,
}
if (source === 'chat') {
@@ -130,12 +128,6 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
form.setFieldsValue({ ...rest })
}, [data?.default_model_config_id])
useEffect(() => {
if (values?.deep_thinking && !values?.thinking_budget_tokens) {
form.setFieldValue('thinking_budget_tokens', defaultThinkingBudgetTokens)
}
}, [values?.deep_thinking])
const handleReset = () => {
if (!id) return
resetAppModelConfig(id).then((res) => {
@@ -186,20 +178,15 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
name="thinking_budget_tokens"
label={t('application.thinking_budget_tokens')}
hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))}
extra={<>{t('application.range')}: [{minThinkingBudgetTokens}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
extra={<>{t('application.range')}: [{0}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
rules={[
{ required: values?.deep_thinking, message: t('common.pleaseEnter') },
{
validator: (_, value) => {
const maxTokens = values?.max_tokens
const deep_thinking = values?.deep_thinking;
if (deep_thinking && value !== undefined) {
if (value < minThinkingBudgetTokens) {
return Promise.reject(t('application.thinking_budget_tokens_min_error', { min: minThinkingBudgetTokens }))
}
if (maxTokens !== undefined && value > maxTokens) {
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
}
if (deep_thinking && value !== undefined && maxTokens !== undefined && value > maxTokens) {
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
}
return Promise.resolve()
}
@@ -208,7 +195,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
>
<RbSlider
step={1}
min={minThinkingBudgetTokens}
min={0}
max={32000}
isInput={true}
disabled={!values?.deep_thinking}

View File

@@ -102,7 +102,7 @@ const Index = () => {
<Flex gap={12} wrap="nowrap" className="rb:w-full! rb:h-full! rb:overflow-y-auto">
<div className="rb:flex-1 rb:min-w-0">
<Flex vertical>
<div className='rb:w-full rb:h-26 rb:p-4 rb:bg-cover rb:bg-[url("@/assets/images/index/index_bg.png")] rb:rounded-xl rb:overflow-hidden'>
<div className='rb:w-full rb:h-26 rb:p-4 rb:bg-cover rb:bg-[url("@/assets/images/index/index_bg@2x.png")] rb:rounded-xl rb:overflow-hidden'>
<div className="rb:font-[MiSans-Bold] rb:font-bold rb:text-white rb:text-[18px] rb:leading-7">
{t('index.spaceTitle')}
</div>

View File

@@ -14,33 +14,27 @@ import React, { useState, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { Button, Input, Form, App } from 'antd';
import type { FormProps } from 'antd';
import clsx from 'clsx';
import { useUser, type LoginInfo } from '@/store/user';
import { login } from '@/api/user'
import loginBg from '@/assets/images/login/bg.mp4'
import check from '@/assets/images/login/check.svg'
import loginBg from '@/assets/images/login/loginBg.png'
import check from '@/assets/images/login/check.png'
import email from '@/assets/images/login/email.svg'
import lock from '@/assets/images/login/lock.svg'
import type { LoginForm } from './types';
import { useI18n } from '@/store/locale'
/**
* Input field styling
*/
const inputClassName = "login-input rb:rounded-[8px]! rb:p-[12px]! rb:h-[44px]! rb:bg-transparent! rb:text-[#FFFFFF]! [&_input]:rb:text-[#FFFFFF]! [&_input]:rb:caret-[#FFFFFF]!"
const inputClassName = "rb:rounded-[8px]! rb:p-[12px]! rb:h-[44px]!"
/**
* Login page component
*/const LoginPage: React.FC = () => {
const { t } = useTranslation();
const { clearUserInfo, updateLoginInfo, getUserInfo } = useUser();
const { language } = useI18n()
const [loading, setLoading] = useState(false);
const [form] = Form.useForm<LoginForm>();
const emailVal = Form.useWatch('email', form);
const passwordVal = Form.useWatch('password', form);
const canLogin = !!(emailVal && passwordVal);
const { message } = App.useApp();
useEffect(() => {
@@ -49,7 +43,6 @@ const inputClassName = "login-input rb:rounded-[8px]! rb:p-[12px]! rb:h-[44px]!
/** Handle login form submission */
const handleLogin: FormProps<LoginForm>['onFinish'] = async (values) => {
if (!canLogin) return;
if (!values.email) {
message.warning(t('login.emailPlaceholder'));
return;
@@ -71,45 +64,42 @@ const inputClassName = "login-input rb:rounded-[8px]! rb:p-[12px]! rb:h-[44px]!
return (
<div className="rb:min-h-screen rb:flex rb:h-screen rb:bg-[#0A0A0A] rb:text-[#FFFFFF]">
<div className="rb:min-h-screen rb:flex rb:h-screen">
<div className="rb:relative rb:w-1/2 rb:h-screen rb:overflow-hidden">
<video src={loginBg} loop autoPlay playsInline muted className="rb:w-full rb:h-full rb:object-cover"></video>
<div className="rb:absolute rb:top-10 rb:left-12">
<div className={clsx("rb:h-8.25 rb:bg-cover", {
"rb:w-89 rb:bg-[url('@/assets/images/login/title_en.png')]": language !== 'zh',
"rb:w-42 rb:bg-[url('@/assets/images/login/title_zh.png')]": language === 'zh'
})}></div>
<div className="rb:text-[18px] rb:text-[rgba(255,255,255,0.7)] rb:leading-6.25 rb:font-regular rb:mt-3">{t('login.subTitle')}</div>
<img src={loginBg} alt="loginBg" className="rb:w-full rb:h-full rb:object-cover rb:absolute rb:top-1/2 rb:-translate-y-1/2 rb:left-0" />
<div className="rb:absolute rb:top-14 rb:left-16">
<div className="rb:text-[28px] rb:leading-8.25 rb:font-bold rb:font-[AlimamaShuHeiTi,AlimamaShuHeiTi] rb:mb-4">{t('login.title')}</div>
<div className="rb:text-[18px] rb:leading-6.25 rb:font-regular">{t('login.subTitle')}</div>
</div>
<div className="rb:absolute rb:bottom-14 rb:left-12 rb:right-12 rb:grid rb:grid-cols-2 rb:gap-x-30 rb:gap-y-10.75">
{['intelligentMemory', 'instantRecall', 'knowledgeAssociation'].map((key, index) => (
<div key={key} className={`rb:flex${index === 0 ? ' rb:col-span-2' : ''}`}>
<div className="rb:absolute rb:bottom-20.25 rb:left-16 rb:grid rb:grid-cols-2 rb:gap-x-30 rb:gap-y-10.75">
{['intelligentMemory', 'instantRecall', 'knowledgeAssociation'].map(key => (
<div key={key} className="rb:flex">
<img src={check} className="rb:w-4 rb:h-4 rb:mr-2 rb:mt-0.75" />
<div className="rb:text-[16px] rb:leading-5.5">
<div className="rb:font-medium">{t(`login.${key}`)}</div>
<div className="rb:text-[14px] rb:text-[rgba(255,255,255,0.7)] rb:leading-5 rb:font-regular! rb:mt-2">{t(`login.${key}Desc`)}</div>
<div className="rb:text-[#5B6167] rb:text-[14px] rb:leading-5 rb:font-regular! rb:mt-2">{t(`login.${key}Desc`)}</div>
</div>
</div>
))}
</div>
</div>
<div className="rb:flex rb:items-center rb:justify-center rb:flex-[1_1_auto]">
<div className="rb:w-110 rb:mx-auto">
<div className="rb:text-center rb:text-[24px] rb:font-[MiSans-Bold] rb:font-bold rb:leading-8 rb:mb-12">{t('login.welcome')}</div>
<div className="rb:bg-[#FFFFFF] rb:flex rb:items-center rb:justify-center rb:flex-[1_1_auto]">
<div className="rb:w-100 rb:mx-auto">
<div className="rb:text-center rb:text-[28px] rb:font-semibold rb:leading-8 rb:mb-12">{t('login.welcome')}</div>
<Form
form={form}
onFinish={handleLogin}
>
<Form.Item name="email" className="rb:mb-6!">
<Form.Item name="email" className="rb:mb-5!">
<Input
prefix={<img src={email} className="rb:w-5 rb:h-5 rb:mr-2" />}
placeholder={t('login.emailPlaceholder')}
className={inputClassName}
/>
</Form.Item>
<Form.Item name="password" className="rb:mb-0!">
<Form.Item name="password">
<Input.Password
prefix={<img src={lock} className="rb:w-5 rb:h-5 rb:mr-2" />}
placeholder={t('login.passwordPlaceholder')}
@@ -121,11 +111,7 @@ const inputClassName = "login-input rb:rounded-[8px]! rb:p-[12px]! rb:h-[44px]!
block
loading={loading}
htmlType="submit"
disabled={!canLogin}
className={clsx("rb:h-11.5! rb:rounded-lg! rb:mt-12", {
'rb:hover:bg-[#2d6ef1]! rb:bg-[#155EEF]! rb:border-[#155EEF]!': canLogin,
'rb:bg-[#171719]! rb:border-[#171719]!': !canLogin
})}
className="rb:h-10! rb:rounded-lg! rb:mt-4"
>
{t('login.loginIn')}
</Button>

View File

@@ -166,10 +166,10 @@ const Ontology: FC = () => {
<div className="rb:h-10 rb:wrap-break-word rb:line-clamp-2 rb:leading-5">{item.scene_description}</div>
</Tooltip>
<div className="rb:mt-2 rb:h-5.5">
<div className="rb:mt-2">
<OverflowTags
popoverProps={false}
items={item.entity_type ? [...item.entity_type.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>] : []}
items={[...item.entity_type?.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>]}
numTag={(num?: number) => <Tag variant="borderless" color="dark">{`+${item.type_num - 3 + (num ? num - 1 : 0)}`}</Tag>}
/>
</div>

View File

@@ -361,7 +361,7 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
)}
</Flex>
<div>
<div className="rb:font-[MiSans-Bold] rb:font-bold rb:text-[16px] rb:leading-5.5">{source.name}</div>
<div className="rb:font-[MiSans Bold] rb:font-bold rb:text-[16px] rb:leading-5.5">{source.name}</div>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.5">{t('tool.availableMcp')} ({mcpTotal})</div>
</div>
</Flex>

View File

@@ -101,7 +101,6 @@ const CustomToolModal = forwardRef<CustomToolModalRef, CustomToolModalProps>(({
});
};
const formatSchema = (value: string) => {
if (!value || value.trim() === '') return
setParseSchemaData({} as ParseSchemaData)
parseSchema({ schema_content: value })
.then(res => {

View File

@@ -57,6 +57,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
}
}}
labelRender={(props) => {
console.log('props', props)
return `${props.value}%`
}}
className="rb:w-20 rb:h-4!"

View File

@@ -66,6 +66,8 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
const [fileList, setFileList] = useState<any[]>([])
const [message, setMessage] = useState<string | undefined>(undefined)
console.log('abortRef', abortRef, chatList)
/**
* Opens the chat drawer and loads workflow variables from the start node
*/

View File

@@ -18,7 +18,6 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
// Handle node selection from popover and create new node replacing the add-node placeholder
const handleNodeSelect = (selectedNodeType: any) => {
graph.startBatch('add-node');
const parentBBox = node.getBBox();
const cycleId = data.cycle;
const horizontalSpacing = 0;
@@ -44,7 +43,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) {
parentNode.addChild(newNode, { silent: true });
parentNode.addChild(newNode);
}
}
@@ -77,40 +76,55 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
}
});
setTimeout(() => {
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
}, 50);
// Automatically adjust loop node size
const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (loopNode) {
const adjustLoopSize = () => {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc, child) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
loopNode.prop('size', { width: newWidth, height: newHeight });
// Update right port x position
const ports = loopNode.getPorts();
ports.forEach(port => {
if (port.group === 'right' && port.args) {
loopNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
};
adjustLoopSize();
// Listen to child node movement events
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc, child) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
loopNode.prop('size', { width: newWidth, height: newHeight });
loopNode.getPorts().forEach(port => {
if (port.group === 'right' && port.args) {
loopNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
childNodes.forEach((childNode: any) => {
childNode.on('change:position', adjustLoopSize);
});
}
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
graph.stopBatch('add-node');
setOpen(false);
};

View File

@@ -99,7 +99,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
{data.type === 'if-else' &&
<Flex vertical gap={4} className="rb:mt-3!">
{data.config?.cases?.defaultValue.map((item: any, index: number) => (
<div key={index}>
<div key={index} className={item.expressions.length > 0 ? '' : 'rb:mb-1'}>
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>

View File

@@ -1,15 +1,134 @@
import { useEffect } from 'react';
import { useTranslation } from 'react-i18next'
import clsx from 'clsx';
import type { ReactShapeConfig } from '@antv/x6-react-shape';
import { Flex } from 'antd';
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons';
import { useTranslation } from 'react-i18next'
import { graphNodeLibrary, edgeAttrs } from '../../constant';
import NodeTools from './NodeTools'
const LoopNode: ReactShapeConfig['component'] = ({ node }) => {
const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
const data = node.getData() || {};
const { t } = useTranslation()
useEffect(() => {
// 使用setTimeout确保在所有节点都添加完成后再创建连线
const timer = setTimeout(() => {
initNodes()
checkAndAddAddNode()
}, 50)
return () => clearTimeout(timer)
}, [graph])
const checkAndAddAddNode = () => {
if (!graph) return;
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === data.id);
const cycleStartNodes = childNodes.filter((n: any) => n.getData()?.type === 'cycle-start');
// 如果只有一个cycle-start节点且没有其他类型的子节点则添加add-node
if (cycleStartNodes.length === 1 && childNodes.length === 1) {
const cycleStartNode = cycleStartNodes[0];
const cycleStartBBox = cycleStartNode.getBBox();
const addNode = graph.addNode({
...graphNodeLibrary.addStart,
x: cycleStartBBox.x + 84,
y: cycleStartBBox.y + 4,
data: {
type: 'add-node',
label: t('workflow.addNode'),
icon: '+',
parentId: node.id,
cycle: data.id,
},
});
node.addChild(addNode);
// 连接cycle-start和add-node
const sourcePorts = cycleStartNode.getPorts();
const targetPorts = addNode.getPorts();
const sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
const targetPort = targetPorts.find((port: any) => port.group === 'left')?.id || 'left';
// 然后创建连线
graph.addEdge({
source: { cell: cycleStartNode.id, port: sourcePort },
target: { cell: addNode.id, port: targetPort },
...edgeAttrs,
});
cycleStartNode.toFront()
addNode.toFront()
}
}
const initNodes = () => {
// 检查是否存在cycle为当前节点ID的子节点若存在则不调用initNodes避免重复创建
const existingCycleNodes = graph.getNodes().filter((n: any) =>
n.getData()?.cycle === data.id
);
if (existingCycleNodes.length > 0) return;
// 添加默认子节点
const parentBBox = node.getBBox();
const centerX = parentBBox.x + 24;
const centerY = parentBBox.y + 70;
const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const cycleStartNode = graph.addNode({
...graphNodeLibrary.cycleStart,
x: centerX,
y: centerY,
id: cycleStartNodeId,
data: {
id: cycleStartNodeId,
type: 'cycle-start',
parentId: node.id,
isDefault: true, // 标记为默认节点,不可删除
cycle: data.id,
},
});
const addNode = graph.addNode({
...graphNodeLibrary.addStart,
x: centerX + 84,
y: centerY + 4,
data: {
type: 'add-node',
label: t('workflow.addNode'),
icon: '+',
parentId: node.id,
cycle: data.id,
},
});
node.addChild(cycleStartNode)
node.addChild(addNode)
const sourcePorts = cycleStartNode.getPorts()
const targetPorts = addNode.getPorts()
let sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
const edgeConfig = {
source: {
cell: cycleStartNode.id,
port: sourcePort
},
target: {
cell: addNode.id,
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
},
...edgeAttrs
}
graph.addEdge(edgeConfig)
setTimeout(() => {
cycleStartNode.toFront()
addNode.toFront()
}, 0)
}
return (
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,

View File

@@ -43,52 +43,70 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
};
}, []);
// Handle node selection from popover menu and create new node with edge connection
const handleNodeSelect = (selectedNodeType: any) => {
if (!sourceNode || !graph) return;
const sourceNodeData = sourceNode.getData();
const sourceNodeType = sourceNodeData?.type;
const isCycleSubNode = !!sourceNodeData.cycle;
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
const newNodeType = selectedNodeType.type;
// Save add-node placeholder position before disabling history
// If it's a cycle-start node, handle the add-node placeholder
let addNodePosition = null;
const isCycleSubNode = sourceNodeData.cycle
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
const cycleId = sourceNodeData.cycle;
const addNodes = graph.getNodes().filter((n: any) =>
const addNodes = graph.getNodes().filter((n: any) =>
n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId
);
if (addNodes.length > 0) addNodePosition = addNodes[0].getBBox();
if (addNodes.length > 0) {
const addNode = addNodes[0];
addNodePosition = addNode.getBBox();
addNode.remove();
}
}
// Calculate position
// Calculate new node position to avoid overlapping
const sourceBBox = sourceNode.getBBox();
const nw = graphNodeLibrary[newNodeType]?.width || 120;
const nh = graphNodeLibrary[newNodeType]?.height || 88;
const hSpacing = isCycleSubNode ? 48 : 80;
const vSpacing = 10;
const nodeWidth = graphNodeLibrary[selectedNodeType.type]?.width || 120;
const nodeHeight = graphNodeLibrary[selectedNodeType.type]?.height || 88;
const horizontalSpacing = isCycleSubNode ? 48 : 80;
const verticalSpacing = 10;
// Get source port group information
const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort);
const sourcePortGroup = sourcePortInfo?.group || sourcePort;
let newX: number, newY: number;
// Calculate new node position
let newX, newY;
if (edgeInsertion) {
// Edge insertion: place new node on the same row as target, between source and target
const targetBBox = edgeInsertion.targetCell.getBBox();
const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width);
const requiredSpace = nw + hSpacing * 4;
newX = sourceBBox.x + sourceBBox.width + hSpacing;
newY = targetBBox.y + (targetBBox.height - nh) / 2;
const requiredSpace = nodeWidth + horizontalSpacing * 4;
// New node x: right after source + spacing
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
// Same row as target node
newY = targetBBox.y + (targetBBox.height - nodeHeight) / 2;
// If not enough space, shift target and all downstream nodes to the right
if (gap < requiredSpace) {
const shiftX = requiredSpace - gap;
const visited = new Set<string>();
const shiftDownstream = (cell: any) => {
if (visited.has(cell.id)) return;
visited.add(cell.id);
const cellId = cell.id;
if (visited.has(cellId)) return;
visited.add(cellId);
const pos = cell.getPosition();
cell.setPosition(pos.x + shiftX, pos.y);
// Recursively shift nodes connected from right ports
graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => {
const tCell = graph.getCellById(e.getTargetCellId());
if (tCell?.isNode()) shiftDownstream(tCell);
const tId = e.getTargetCellId();
if (tId && !visited.has(tId)) {
const tCell = graph.getCellById(tId);
if (tCell?.isNode()) shiftDownstream(tCell);
}
});
};
shiftDownstream(edgeInsertion.targetCell);
@@ -96,170 +114,208 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
} else if (addNodePosition) {
newX = addNodePosition.x;
newY = addNodePosition.y;
} else if (sourcePortGroup === 'left') {
newX = sourceBBox.x - nw * 2 - hSpacing;
newY = sourceBBox.y;
} else {
newX = sourceBBox.x + sourceBBox.width + hSpacing;
newY = sourceBBox.y;
const connectedNodes = new Set<string>();
graph.getConnectedEdges(sourceNode).forEach((e: any) => {
[e.getSourceCellId(), e.getTargetCellId()].forEach((cid: string) => {
if (cid !== sourceNode.id) connectedNodes.add(cid);
// Determine node placement direction based on port position
if (sourcePortGroup === 'left') {
// Left port: add node to the left
newX = sourceBBox.x - nodeWidth*2 - horizontalSpacing;
newY = sourceBBox.y;
} else {
// Right port: add node to the right
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
newY = sourceBBox.y;
}
// Check if position overlaps with existing nodes (only consider connected nodes)
const checkOverlap = (x: number, y: number) => {
// Get nodes connected to the source node
const connectedNodes = new Set();
graph.getConnectedEdges(sourceNode).forEach((edge: any) => {
const sourceId = edge.getSourceCellId();
const targetId = edge.getTargetCellId();
if (sourceId !== sourceNode.id) connectedNodes.add(sourceId);
if (targetId !== sourceNode.id) connectedNodes.add(targetId);
});
});
const checkOverlap = (x: number, y: number) =>
graph.getNodes().some((n: any) => {
if (n.id === sourceNode.id || !connectedNodes.has(n.id)) return false;
const b = n.getBBox();
return !(x + nw < b.x || x > b.x + b.width || y + nh < b.y || y > b.y + b.height);
return graph.getNodes().some((node: any) => {
if (node.id === sourceNode.id) return false;
if (!connectedNodes.has(node.id)) return false; // Only consider connected nodes
const bbox = node.getBBox();
return !(x + nodeWidth < bbox.x || x > bbox.x + bbox.width ||
y + nodeHeight < bbox.y || y > bbox.y + bbox.height);
});
while (checkOverlap(newX, newY)) newY += nh + vSpacing;
};
// If position is occupied, search downward for empty space
while (checkOverlap(newX, newY)) {
newY += nodeHeight + verticalSpacing;
}
}
// Disable history for all graph mutations
graph.disableHistory();
// Remove add-node placeholder
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
const cycleId = sourceNodeData.cycle;
graph.getNodes()
.filter((n: any) => n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId)
.forEach((n: any) => n.remove());
}
const id = `${newNodeType.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
// Create new node
const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const newNode = graph.addNode({
...(graphNodeLibrary[newNodeType] || graphNodeLibrary.default),
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
x: newX,
y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0),
id,
data: {
id,
type: newNodeType,
type: selectedNodeType.type,
icon: selectedNodeType.icon,
name: t(`workflow.${newNodeType}`),
cycle: sourceNodeData.cycle,
name: t(`workflow.${selectedNodeType.type}`),
cycle: sourceNodeData.cycle, // Inherit cycle from source node
config: selectedNodeType.config || {}
},
});
// Add new node as child of parent node
if (sourceNodeData.cycle) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle);
if (parentNode) parentNode.addChild(newNode, { silent: true });
}
if (edgeInsertion) {
const { edge: oldEdge } = edgeInsertion;
if (oldEdge.id && graph.getCellById(oldEdge.id)) graph.removeCell(oldEdge.id);
else graph.removeEdge(oldEdge);
}
const newPorts = newNode.getPorts();
const addedCells: any[] = [newNode];
if (edgeInsertion) {
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: newLeftPort }, ...edgeAttrs }));
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: newRightPort }, target: { cell: targetCell.id, port: origTargetPort }, ...edgeAttrs }));
setEdgeInsertion(null);
} else if (sourcePortGroup === 'left') {
const tp = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: tp }, target: { cell: sourceNode.id, port: sourcePort }, ...edgeAttrs }));
} else {
const tp = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: tp }, ...edgeAttrs }));
}
// If adding a loop/iteration node, create cycle-start, add-node and inner edge regardless of source type
if (isCycleContainer(newNodeType)) {
const parentBBox = newNode.getBBox();
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
const cycleStartNode = graph.addNode({
...graphNodeLibrary.cycleStart,
x: parentBBox.x + 24,
y: parentBBox.y + 70,
id: cycleStartId,
data: { id: cycleStartId, type: 'cycle-start', parentId: id, isDefault: true, cycle: id },
});
const addNodePlaceholder = graph.addNode({
...graphNodeLibrary.addStart,
x: parentBBox.x + 24 + 84,
y: parentBBox.y + 70 + 4,
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: id, cycle: id },
});
newNode.addChild(cycleStartNode, { silent: true });
newNode.addChild(addNodePlaceholder, { silent: true });
const innerEdge = graph.addEdge({
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find((p: any) => p.group === 'right')?.id || 'right' },
target: { cell: addNodePlaceholder.id, port: addNodePlaceholder.getPorts().find((p: any) => p.group === 'left')?.id || 'left' },
...edgeAttrs,
});
addedCells.push(cycleStartNode, addNodePlaceholder, innerEdge);
}
// Adjust parent size if adding inside a cycle container
const cycleId = sourceNodeData.cycle;
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc: any, child: any) => {
const b = child.getBBox();
return { minX: Math.min(acc.minX, b.x), minY: Math.min(acc.minY, b.y), maxX: Math.max(acc.maxX, b.x + b.width), maxY: Math.max(acc.maxY, b.y + b.height) };
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
parentNode.prop('size', { width: newWidth, height: newHeight });
parentNode.getPorts().forEach((port: any) => {
if (port.group === 'right' && port.args) parentNode.portProp(port.id!, 'args/x', newWidth);
});
}
parentNode.addChild(newNode);
}
}
// toFront
const bringCycleChildrenToFront = (cycleContainerId: string) => {
graph.getEdges().forEach((e: any) => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
});
graph.getNodes().forEach((n: any) => { if (n.getData()?.cycle === cycleContainerId) n.toFront(); });
};
if (isCycleContainer(sourceNodeType)) {
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(sourceNodeData.id);
if (isCycleContainer(newNodeType)) bringCycleChildrenToFront(id);
} else if (isCycleContainer(newNodeType)) {
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(id);
} else {
addedCells.forEach(c => { if (c.isNode?.()) c.toFront(); });
// Edge insertion: remove old edge immediately before creating new edges
if (edgeInsertion) {
const { edge: oldEdge } = edgeInsertion;
if (oldEdge.id && graph.getCellById(oldEdge.id)) {
graph.removeCell(oldEdge.id);
} else {
graph.removeEdge(oldEdge);
}
}
// Re-enable history and manually push one batch frame for all added cells
graph.enableHistory();
const history = graph.getPlugin('history') as any;
if (history) {
const batchFrame = addedCells.map((cell: any) => ({
batch: true,
event: 'cell:added',
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
options: {},
}));
history.undoStack.push(batchFrame);
history.redoStack = [];
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-node' } });
}
// Create edge connection
setTimeout(() => {
const newPorts = newNode.getPorts();
const addedEdges: any[] = [];
if (edgeInsertion) {
// Edge insertion: create source→new and new→target edges
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedEdges.push(graph.addEdge({
source: { cell: sourceNode.id, port: sourcePort },
target: { cell: newNode.id, port: newLeftPort },
...edgeAttrs
}));
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: newRightPort },
target: { cell: targetCell.id, port: origTargetPort },
...edgeAttrs
}));
setEdgeInsertion(null);
} else if (sourcePortGroup === 'left') {
// Connect from left port to new node's right side
const targetPort = newPorts.find((port: any) => port.group === 'right')?.id || 'right';
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: targetPort },
target: { cell: sourceNode.id, port: sourcePort },
...edgeAttrs
}));
} else {
// Connect from right port to new node's left side
const targetPort = newPorts.find((port: any) => port.group === 'left')?.id || 'left';
addedEdges.push(graph.addEdge({
source: { cell: sourceNode.id, port: sourcePort },
target: { cell: newNode.id, port: targetPort },
...edgeAttrs
}));
}
// Adjust loop node size when child node is added via port within loop node
const cycleId = sourceNodeData.cycle;
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) {
const adjustLoopSize = () => {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc: any, child: any) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
parentNode.prop('size', { width: newWidth, height: newHeight });
// Update right port x position
const ports = parentNode.getPorts();
ports.forEach((port: any) => {
if (port.group === 'right' && port.args) {
parentNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
};
adjustLoopSize();
// Listen to child node movement events
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
childNodes.forEach((childNode: any) => {
childNode.on('change:position', adjustLoopSize);
});
}
}
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
const newNodeType = selectedNodeType.type;
// Helper: bring all child nodes and their edges of a cycle container to front
const bringCycleChildrenToFront = (cycleContainerId: string) => {
graph.getEdges().forEach((e: any) => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
});
graph.getNodes().forEach((n: any) => {
if (n.getData()?.cycle === cycleContainerId) n.toFront();
});
};
if (isCycleContainer(sourceNodeType)) {
console.log('isCycleContainer(sourceNodeType)')
// Case 4: source is a loop/iteration node — bring new node to front, then its children
newNode.toFront();
sourceNode.toFront();
bringCycleChildrenToFront(sourceNodeData.id);
} else if (isCycleContainer(newNodeType)) {
console.log('isCycleContainer(newNodeType)')
// Case 3: adding a loop/iteration node from a normal node — bring new node to front, then its children
newNode.toFront();
sourceNode.toFront()
bringCycleChildrenToFront(id);
} else {
// Case 2: normal node → normal node
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
}
}, 50);
// Clean up temporary element
if (tempElement) {
document.body.removeChild(tempElement);
setTempElement(null);
}
setPopoverVisible(false);
};
@@ -335,4 +391,4 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
);
};
export default PortClickHandler;
export default PortClickHandler;

View File

@@ -355,13 +355,14 @@ const CaseList: FC<CaseListProps> = ({
// Update node ports based on case count changes (add/remove cases)
const updateNodePorts = (caseCount: number, removedCaseIndex?: number) => {
if (!selectedNode || !graphRef?.current) return;
const graph = graphRef.current;
const currentRightPorts = selectedNode.getPorts().filter((port: any) => port.group === 'right');
const currentCaseCount = currentRightPorts.length - 1;
// Get current port count to determine if it's an add or remove operation
const currentPorts = selectedNode.getPorts().filter((port: any) => port.group === 'right');
const currentCaseCount = currentPorts.length - 1; // Exclude ELSE port
const isAddingCase = removedCaseIndex === undefined && caseCount > currentCaseCount;
const existingEdges = graph.getEdges().filter((edge: any) =>
// Save existing edge connections (including left-side port connections)
const existingEdges = graphRef.current.getEdges().filter((edge: any) =>
edge.getSourceCellId() === selectedNode.id || edge.getTargetCellId() === selectedNode.id
);
const edgeConnections = existingEdges.map((edge: any) => ({
@@ -370,70 +371,113 @@ const CaseList: FC<CaseListProps> = ({
targetCellId: edge.getTargetCellId(),
targetPortId: edge.getTargetPortId(),
sourceCellId: edge.getSourceCellId(),
isIncoming: edge.getTargetCellId() === selectedNode.id,
isIncoming: edge.getTargetCellId() === selectedNode.id
}));
const cases = form.getFieldValue(name) || [];
const leftPorts = selectedNode.getPorts().filter((p: any) => p.group !== 'right');
const newRightPorts = Array.from({ length: caseCount + 1 }, (_, i) => ({
id: `CASE${i + 1}`,
group: 'right',
args: { x: nodeWidth, y: getConditionNodeCasePortY(cases, i) },
}));
graph.startBatch('update-ports');
existingEdges.forEach((edge: any) => graph.removeCell(edge));
// Replace all ports in one prop call — produces a single cell:change:ports command
selectedNode.prop('ports/items', [...leftPorts, ...newRightPorts], { rewrite: true });
selectedNode.prop('size', { width: nodeWidth, height: calcConditionNodeTotalHeight(cases) });
edgeConnections.forEach(({sourcePortId, targetCellId, targetPortId, sourceCellId, isIncoming }: any) => {
if (isIncoming) {
const sourceCell = graph.getCellById(sourceCellId);
if (sourceCell) {
graph.addEdge({
source: { cell: sourceCellId, port: sourcePortId },
target: { cell: selectedNode.id, port: targetPortId },
...edgeAttrs
});
sourceCell.toFront();
bringLoopChildrenToFront(sourceCell);
selectedNode.toFront();
bringLoopChildrenToFront(selectedNode);
}
return;
}
const originalCaseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
if (removedCaseIndex !== undefined && originalCaseNumber === removedCaseIndex + 1) return;
let newPortId = sourcePortId;
if (removedCaseIndex !== undefined) {
if (originalCaseNumber > removedCaseIndex + 1) {
newPortId = `CASE${originalCaseNumber - 1}`;
} else if (originalCaseNumber === currentCaseCount + 1) {
newPortId = `CASE${caseCount + 1}`;
}
} else if (isAddingCase && originalCaseNumber === currentCaseCount + 1) {
newPortId = `CASE${caseCount + 1}`;
}
if (newRightPorts.find((p) => p.id === newPortId)) {
const targetCell = graph.getCellById(targetCellId);
if (targetCell) {
graph.addEdge({
source: { cell: selectedNode.id, port: newPortId },
target: { cell: targetCellId, port: targetPortId },
...edgeAttrs
});
selectedNode.toFront();
bringLoopChildrenToFront(selectedNode);
targetCell.toFront();
bringLoopChildrenToFront(targetCell);
}
// Remove all existing right-side ports
const existingPorts = selectedNode.getPorts();
existingPorts.forEach((port: any) => {
if (port.group === 'right') {
selectedNode.removePort(port.id);
}
});
graph.stopBatch('update-ports');
const cases = form.getFieldValue(name) || [];
selectedNode.prop('size', { width: nodeWidth, height: calcConditionNodeTotalHeight(cases) });
// Add ELIF ports
for (let i = 0; i < caseCount; i++) {
selectedNode.addPort({
id: `CASE${i + 1}`,
group: 'right',
args: {
x: nodeWidth,
y: getConditionNodeCasePortY(cases, i),
},
});
}
// Add ELSE port
selectedNode.addPort({
id: `CASE${caseCount + 1}`,
group: 'right',
args: {
x: nodeWidth,
y: getConditionNodeCasePortY(cases, caseCount),
},
});
// Restore edge connections
setTimeout(() => {
edgeConnections.forEach(({ edge, sourcePortId, targetCellId, targetPortId, sourceCellId, isIncoming }: any) => {
// If it's an incoming connection (left-side port), restore directly
if (isIncoming) {
const sourceCell = graphRef.current?.getCellById(sourceCellId);
if (sourceCell) {
graphRef.current?.addEdge({
source: { cell: sourceCellId, port: sourcePortId },
target: { cell: selectedNode.id, port: targetPortId },
...edgeAttrs,
});
}
sourceCell.toFront()
selectedNode.toFront()
bringLoopChildrenToFront(sourceCell)
bringLoopChildrenToFront(selectedNode)
graphRef.current?.removeCell(edge);
return;
}
// Handle right-side port connections
const originalCaseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
// If it's a remove operation and the port is being removed, delete the connection
if (removedCaseIndex !== undefined && originalCaseNumber === removedCaseIndex + 1) {
graphRef.current?.removeCell(edge);
return;
}
let newPortId = sourcePortId;
// If it's a remove operation, remap port IDs
if (removedCaseIndex !== undefined) {
if (originalCaseNumber > removedCaseIndex + 1) {
// Ports after the removed port, shift numbering forward
newPortId = `CASE${originalCaseNumber - 1}`;
}
// ELSE port always maps to the new ELSE port position
else if (originalCaseNumber === currentCaseCount + 1) {
newPortId = `CASE${caseCount + 1}`;
}
} else if (isAddingCase) {
// If it's an add operation, ELSE port needs to be remapped
if (originalCaseNumber === currentCaseCount + 1) {
newPortId = `CASE${caseCount + 1}`; // New ELSE port
}
// Newly added ports don't restore any connections
}
const newPorts = selectedNode.getPorts();
const matchingPort = newPorts.find((port: any) => port.id === newPortId);
if (matchingPort) {
const targetCell = graphRef.current?.getCellById(targetCellId);
if (targetCell) {
graphRef.current?.addEdge({
source: { cell: selectedNode.id, port: newPortId },
target: { cell: targetCellId, port: targetPortId },
...edgeAttrs
});
selectedNode.toFront()
bringLoopChildrenToFront(selectedNode)
targetCell.toFront()
bringLoopChildrenToFront(targetCell)
}
}
graphRef.current?.removeCell(edge);
});
}, 50);
};
const handleChangeLogicalOperator = (index: number) => {

View File

@@ -42,73 +42,109 @@ const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRe
// Update node ports based on category count changes (add/remove categories)
const updateNodePorts = (caseCount: number, removedCaseIndex?: number) => {
if (!selectedNode || !graphRef?.current) return;
const graph = graphRef.current;
const existingEdges = graph.getEdges().filter((edge: any) =>
// Save existing edge connections (including left-side port connections)
const existingEdges = graphRef.current.getEdges().filter((edge: any) =>
edge.getSourceCellId() === selectedNode.id || edge.getTargetCellId() === selectedNode.id
);
const edgeConnections = existingEdges.map((edge: any) => ({
edge,
sourcePortId: edge.getSourcePortId(),
targetCellId: edge.getTargetCellId(),
targetPortId: edge.getTargetPortId(),
sourceCellId: edge.getSourceCellId(),
isIncoming: edge.getTargetCellId() === selectedNode.id,
isIncoming: edge.getTargetCellId() === selectedNode.id
}));
graph.startBatch('update-ports');
existingEdges.forEach((edge: any) => graph.removeCell(edge));
// Replace all ports in one prop call — produces a single cell:change:ports command
const leftPorts = selectedNode.getPorts().filter((p: any) => p.group !== 'right');
const newRightPorts = Array.from({ length: caseCount }, (_, i) => ({
id: `CASE${i + 1}`,
group: 'right',
args: { x: nodeWidth, y: portItemArgsY * i + conditionNodePortItemArgsY },
}));
selectedNode.prop('ports/items', [...leftPorts, ...newRightPorts], { rewrite: true });
const newHeight = conditionNodeHeight + (caseCount - 2) * conditionNodeItemHeight;
selectedNode.prop('size', { width: nodeWidth, height: newHeight < conditionNodeHeight ? conditionNodeHeight : newHeight });
edgeConnections.forEach(({ sourcePortId, targetCellId, targetPortId, sourceCellId, isIncoming }: any) => {
if (isIncoming) {
const sourceCell = graph.getCellById(sourceCellId);
if (sourceCell) {
graph.addEdge({
source: { cell: sourceCellId, port: sourcePortId },
target: { cell: selectedNode.id, port: targetPortId },
...edgeAttrs
});
sourceCell.toFront();
bringLoopChildrenToFront(sourceCell);
selectedNode.toFront();
bringLoopChildrenToFront(selectedNode);
}
return;
}
const originalCaseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
if (removedCaseIndex !== undefined && originalCaseNumber === removedCaseIndex + 1) return;
let newPortId = sourcePortId;
if (removedCaseIndex !== undefined && originalCaseNumber > removedCaseIndex + 1) {
newPortId = `CASE${originalCaseNumber - 1}`;
}
if (newRightPorts.find((p) => p.id === newPortId)) {
const targetCell = graph.getCellById(targetCellId);
if (targetCell) {
graph.addEdge({
source: { cell: selectedNode.id, port: newPortId },
target: { cell: targetCellId, port: targetPortId },
...edgeAttrs
});
selectedNode.toFront();
bringLoopChildrenToFront(selectedNode);
targetCell.toFront();
bringLoopChildrenToFront(targetCell);
}
// Remove all existing right-side ports
const existingPorts = selectedNode.getPorts();
existingPorts.forEach((port: any) => {
if (port.group === 'right') {
selectedNode.removePort(port.id);
}
});
graph.stopBatch('update-ports');
// Calculate new node height: base height 88px + 30px for each additional port
const newHeight = conditionNodeHeight + (caseCount - 2) * conditionNodeItemHeight;
selectedNode.prop('size', { width: nodeWidth, height: newHeight < conditionNodeHeight ? conditionNodeHeight : newHeight })
// Update right port x position
const currentPorts = selectedNode.getPorts();
currentPorts.forEach(port => {
if (port.group === 'right' && port.args) {
selectedNode.portProp(port.id!, 'args/x', nodeWidth);
}
});
// Add category ports
for (let i = 0; i < caseCount; i++) {
selectedNode.addPort({
id: `CASE${i + 1}`,
group: 'right',
args: {
x: nodeWidth,
y: portItemArgsY * i + conditionNodePortItemArgsY,
},
});
}
// Restore edge connections
setTimeout(() => {
edgeConnections.forEach(({ edge, sourcePortId, targetCellId, targetPortId, sourceCellId, isIncoming }: any) => {
graphRef.current?.removeCell(edge);
// If it's an incoming connection (left-side port), restore directly
if (isIncoming) {
const sourceCell = graphRef.current?.getCellById(sourceCellId);
if (sourceCell) {
graphRef.current?.addEdge({
source: { cell: sourceCellId, port: sourcePortId },
target: { cell: selectedNode.id, port: targetPortId },
...edgeAttrs
});
sourceCell.toFront()
bringLoopChildrenToFront(sourceCell)
selectedNode.toFront()
bringLoopChildrenToFront(selectedNode)
}
return;
}
// Handle right-side port connections
const originalCaseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
// If it's a removed port, don't recreate the connection
if (removedCaseIndex !== undefined && originalCaseNumber === removedCaseIndex + 1) {
return;
}
let newPortId = sourcePortId;
// If a port was removed, remap subsequent port IDs
if (removedCaseIndex !== undefined && originalCaseNumber > removedCaseIndex + 1) {
newPortId = `CASE${originalCaseNumber - 1}`;
}
// Check if the new port exists
const newPorts = selectedNode.getPorts();
const matchingPort = newPorts.find((port: any) => port.id === newPortId);
if (matchingPort) {
const targetCell = graphRef.current?.getCellById(targetCellId);
if (targetCell) {
graphRef.current?.addEdge({
source: { cell: selectedNode.id, port: newPortId },
target: { cell: targetCellId, port: targetPortId },
...edgeAttrs
});
selectedNode.toFront()
bringLoopChildrenToFront(selectedNode)
targetCell.toFront()
bringLoopChildrenToFront(targetCell)
}
}
});
}, 50);
};
const handleAddCategory = (addFunc: Function) => {

View File

@@ -242,11 +242,10 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''}
>
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
? <Select key={values.tool_id} size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
? <Select size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
: parameter.type === 'boolean'
? <Switch key={values.tool_id} size="small" />
? <Switch size="small" />
: <Editor
key={values.tool_id}
variant="outlined"
type="input"
size="small"

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 15:06:18
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-27 14:07:14
* @Last Modified time: 2026-04-21 18:23:31
*/
import type { ReactShapeConfig } from '@antv/x6-react-shape';
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
@@ -948,15 +948,6 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
width: nodeWidth,
height: 120,
shape: 'notes-node',
},
output: {
width: nodeWidth,
height: 76,
shape: 'normal-node',
ports: {
groups: { left: defaultPortGroup },
items: [defaultPortItems[0]],
},
}
}

View File

@@ -2,9 +2,10 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 15:17:48
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-28 13:49:11
* @Last Modified time: 2026-04-24 17:21:09
*/
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type';
import { register } from '@antv/x6-react-shape';
import type { PortMetadata } from '@antv/x6/lib/model/port';
import { App } from 'antd';
@@ -16,7 +17,7 @@ import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application';
import { useUser } from '@/store/user';
import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types';
import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant';
import type { ChatVariable, HistoryRecord, NodeProperties, WorkflowConfig } from '../types';
import type { ChatVariable, NodeProperties, WorkflowConfig } from '../types';
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
import { useWorkflowStore } from '@/store/workflow';
@@ -85,10 +86,6 @@ export interface UseWorkflowGraphReturn {
/** Get start node output variable list (user-defined + system variables) */
getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>;
nodeClick: ({ node }: { node: Node }) => void;
/** All recorded history operations */
historyRecords: HistoryRecord[];
/** Clear history records */
clearHistoryRecords: () => void;
}
/**
@@ -122,17 +119,14 @@ export const useWorkflowGraph = ({
const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined)
const [canUndo, setCanUndo] = useState(false)
const [canRedo, setCanRedo] = useState(false)
const [historyRecords, setHistoryRecords] = useState<HistoryRecord[]>([])
const lastHistoryRef = useRef<{ cellIds: string[]; timestamp: number; type: string } | null>(null)
const syncChildRelationshipsRef = useRef<() => void>(() => { })
const isSyncingRef = useRef(false)
useEffect(() => {
if (!graphRef.current) return
graphRef.current.getNodes().forEach(node => {
const data = node.getData()
if (data?.type === 'if-else' || data?.type === 'question-classifier') {
console.log('chatVariables', chatVariables)
node.setData({ ...data, chatVariables })
node.setData({ ...data, chatVariables }, { silent: true })
}
})
}, [chatVariables])
@@ -349,7 +343,7 @@ export const useWorkflowGraph = ({
if (parentNode) {
const addedChild = graphRef.current?.addNode(childNode)
if (addedChild) {
parentNode.addChild(addedChild, { silent: true })
parentNode.addChild(addedChild)
}
}
}
@@ -380,6 +374,8 @@ export const useWorkflowGraph = ({
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
console.log('newWidth', newHeight, newWidth)
parentNode.prop('size', { width: newWidth, height: newHeight })
// Update x position of right group ports
@@ -492,135 +488,8 @@ export const useWorkflowGraph = ({
graphRef.current.cleanHistory()
}
}, 200)
} else {
graphRef.current.enableHistory()
graphRef.current.cleanHistory()
}
}
const resizeGroupNodes = (graph: Graph) => {
graph.getNodes().forEach(parentNode => {
const parentType = parentNode.getData()?.type
if (parentType !== 'loop' && parentType !== 'iteration') return
const children = graph.getNodes().filter(
n => n.getData()?.cycle === parentNode.getData()?.id && n.getData()?.type !== 'add-node'
)
if (!children.length) return
const padding = 24
const headerHeight = 50
const childBounds = children.map(c => c.getBBox())
const minX = Math.min(...childBounds.map(b => b.x))
const minY = Math.min(...childBounds.map(b => b.y))
const maxX = Math.max(...childBounds.map(b => b.x + b.width))
const maxY = Math.max(...childBounds.map(b => b.y + b.height))
const parentBBox = parentNode.getBBox()
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
parentNode.prop('size', { width: newWidth, height: newHeight })
parentNode.getPorts().forEach(port => {
if (port.group === 'right' && port.args) {
parentNode.portProp(port.id!, 'args/x', newWidth)
}
})
})
}
const syncChildRelationships = () => {
if (!graphRef.current) return
const graph = graphRef.current
graph.disableHistory()
graph.getNodes().forEach(node => {
const nodeData = node.getData()
const children = node.getChildren()
const cycleId = nodeData?.cycle
if (cycleId) {
const parentNode = graph.getCellById(cycleId) as Node | null
if (!parentNode) return
if (!parentNode.getChildren()?.some(c => c.id === node.id)) {
parentNode.addChild(node, { silent: true })
}
}
if (nodeData.type === 'if-else') {
const rightPorts = node.getPorts().filter(p => p.group === 'right')
const caseCount = rightPorts.length - 1 // last port is ELSE
const currentCases: any[] = nodeData.config?.cases?.defaultValue ?? []
const newCases = caseCount !== currentCases.length
? Array.from({ length: caseCount }, (_, i) => currentCases[i] ?? { logical_operator: 'and', expressions: [] })
: currentCases
if (caseCount !== currentCases.length) {
node.setData({
...nodeData,
config: { ...nodeData.config, cases: { ...nodeData.config.cases, defaultValue: newCases } }
}, { deep: false, silent: true })
}
// Sync node height and port Y positions
node.prop('size', { width: nodeWidth, height: calcConditionNodeTotalHeight(newCases) })
newCases.forEach((_c: any, i: number) => {
node.portProp(`CASE${i + 1}`, 'args/y', getConditionNodeCasePortY(newCases, i))
})
node.portProp(`CASE${newCases.length + 1}`, 'args/y', getConditionNodeCasePortY(newCases, newCases.length))
node.toFront()
graph.getEdges().filter(e => e.getSourceCellId() === node.id).forEach(e => {
const tgt = graph.getCellById(e.getTargetCellId())
tgt?.toFront()
})
} else if (nodeData.type === 'question-classifier') {
const rightPorts = node.getPorts().filter(p => p.group === 'right')
const currentCategories: any[] = nodeData.config?.categories?.defaultValue ?? []
const categoryCount = rightPorts.length
const newCategories = categoryCount !== currentCategories.length
? rightPorts.map((port, i) => {
if (currentCategories[i]) return currentCategories[i]
const edge = graph.getEdges().find(e => e.getSourceCellId() === node.id && e.getSourcePortId() === port.id)
return edge ? { name: '' } : {}
})
: currentCategories
if (categoryCount !== currentCategories.length) {
node.setData({
...nodeData,
config: { ...nodeData.config, categories: { ...nodeData.config.categories, defaultValue: [...newCategories] } }
}, { deep: false, silent: true })
}
// Sync node height and port Y positions
const newHeight = conditionNodeHeight + (categoryCount - 2) * conditionNodeItemHeight
node.prop('size', { width: nodeWidth, height: Math.max(newHeight, conditionNodeHeight) })
rightPorts.forEach((_p, i) => {
node.portProp(`CASE${i + 1}`, 'args/y', portItemArgsY * i + conditionNodePortItemArgsY)
})
node.toFront()
graph.getEdges().filter(e => e.getSourceCellId() === node.id).forEach(e => {
const tgt = graph.getCellById(e.getTargetCellId())
tgt?.toFront()
})
}
if (children?.length) {
children.forEach(child => {
if (!child.isNode()) return
const childCycleId = (child as Node).getData?.()?.cycle
if (childCycleId !== node.id && childCycleId !== node.getData?.()?.id) {
node.removeChild(child, { silent: true })
}
})
}
})
resizeGroupNodes(graph)
graph.getEdges().forEach(edge => {
const src = graph.getCellById(edge.getSourceCellId())
const tgt = graph.getCellById(edge.getTargetCellId())
if (src?.getData()?.cycle || tgt?.getData()?.cycle) {
edge.toFront()
}
})
graph.getNodes().forEach(node => {
if (node.getData()?.cycle) node.toFront()
})
graph.enableHistory()
}
syncChildRelationshipsRef.current = syncChildRelationships
/**
* Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard)
*/
@@ -656,44 +525,18 @@ export const useWorkflowGraph = ({
new History({
enabled: false,
beforeAddCommand(_event, args: any) {
const key = args?.key
if (key === 'attrs' || key === 'tools') return false
const event = args?.key ? `cell:change:${args.key}` : _event;
if (event.startsWith('cell:change:') &&
event !== 'cell:change:position' &&
event !== 'cell:change:source' &&
event !== 'cell:change:target') return false;
},
}),
);
const MERGE_INTERVAL = 1000
graphRef.current.on('history:change', ({ cmds, options }: { cmds: any[]; options: any }) => {
graphRef.current.on('history:change', ({ cmds }: { cmds: Command[] }) => {
setCanUndo(graphRef.current?.canUndo() ?? false)
setCanRedo(graphRef.current?.canRedo() ?? false)
console.log('history:change', cmds, options)
const batchName: string | undefined = options?.name
const actionType = batchName === 'undo' ? 'undo' : batchName === 'redo' ? 'redo' : batchName ? 'batch' : 'change'
const cellIds = [...new Set(cmds?.map((cmd: any) => cmd.data?.id).filter(Boolean))]
const now = Date.now()
const last = lastHistoryRef.current
const canMerge =
actionType === 'change' &&
last?.type === 'change' &&
now - last.timestamp < MERGE_INTERVAL &&
cellIds.length > 0 &&
cellIds.length === last.cellIds.length &&
cellIds.every((id, i) => id === last.cellIds[i])
if (canMerge) {
lastHistoryRef.current!.timestamp = now
setHistoryRecords(prev => {
const next = [...prev]
next[next.length - 1] = { ...next[next.length - 1], timestamp: now }
return next
})
} else {
const record: HistoryRecord = { type: actionType, timestamp: now, batchName, cellIds }
lastHistoryRef.current = { cellIds, timestamp: now, type: actionType }
setHistoryRecords(prev => [...prev, record])
}
})
graphRef.current.on('history:undo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
graphRef.current.on('history:redo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
};
// 显示/隐藏连接桩
// const showPorts = (show: boolean) => {
@@ -726,13 +569,13 @@ export const useWorkflowGraph = ({
vo.setData({
...data,
isSelected: false,
}, { silent: true });
});
}
});
node.setData({
...nodeData,
isSelected: true,
}, { silent: true });
});
clearEdgeSelect()
if (nodeData.type !== 'notes') {
setSelectedNode(node);
@@ -746,7 +589,7 @@ export const useWorkflowGraph = ({
const edgeClick = ({ edge }: { edge: Edge }) => {
clearEdgeSelect();
edge.setAttrByPath('line/stroke', edge_selected_color);
edge.setData({ ...edge.getData(), isSelected: true }, { silent: true });
edge.setData({ ...edge.getData(), isSelected: true });
clearNodeSelect();
};
/**
@@ -761,7 +604,7 @@ export const useWorkflowGraph = ({
node.setData({
...data,
isSelected: false,
}, { silent: true });
});
}
});
setSelectedNode(null);
@@ -771,7 +614,7 @@ export const useWorkflowGraph = ({
*/
const clearEdgeSelect = () => {
graphRef.current?.getEdges().forEach(e => {
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false }, { silent: true });
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false });
e.setAttrByPath('line/stroke', edge_color);
e.setAttrByPath('line/strokeWidth', edge_width);
});
@@ -910,6 +753,8 @@ export const useWorkflowGraph = ({
// Find corresponding parent node
const parentNode = nodes?.find(n => n.id === nodeData.cycle);
if (parentNode) {
// Use removeChild method to delete child node
parentNode.removeChild(nodeToDelete);
parentNodesToUpdate.push(parentNode);
}
// Add child node to deletion list
@@ -937,51 +782,42 @@ export const useWorkflowGraph = ({
// Delete all collected nodes and edges
if (cells.length > 0) {
// Pre-calculate which parents need an add-node restored (before removal changes the graph)
const parentsNeedingAddNode = parentNodesToUpdate
.filter(parentNode => {
const parentShape = parentNode.shape;
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return false;
const parentData = parentNode.getData();
const allChildren = graphRef.current!.getNodes().filter(n => n.getData()?.cycle === parentData.id);
const cycleStartNodes = allChildren.filter(n => n.getData()?.type === 'cycle-start');
// After deletion, only cycle-start will remain
const nonCycleStartToDelete = cells.filter(c =>
c.isNode() &&
(c as Node).getData()?.cycle === parentData.id &&
(c as Node).getData()?.type !== 'cycle-start'
);
return cycleStartNodes.length === 1 && (allChildren.length - nonCycleStartToDelete.length) === 1;
})
.map(parentNode => ({
parentNode,
cycleStartNode: graphRef.current!.getNodes().find(
n => n.getData()?.cycle === parentNode.getData().id && n.getData()?.type === 'cycle-start'
)!
}))
.filter(({ cycleStartNode }) => !!cycleStartNode);
graphRef.current?.startBatch('delete');
graphRef.current?.removeCells(cells);
parentsNeedingAddNode.forEach(({ parentNode, cycleStartNode }) => {
// If parent is iteration/loop and only cycle-start remains, add add-node connected to it
parentNodesToUpdate.forEach(parentNode => {
const parentShape = parentNode.shape;
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return;
const parentData = parentNode.getData();
const bbox = cycleStartNode.getBBox();
const addNode = graphRef.current!.addNode({
...graphNodeLibrary.addStart,
x: bbox.x + 84,
y: bbox.y + 4,
data: { type: 'add-node', parentId: parentNode.id, cycle: parentData.id, label: t('workflow.addNode'), icon: '+' },
});
parentNode.addChild(addNode, { silent: true });
graphRef.current!.addEdge({
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
...edgeAttrs,
});
const remainingChildren = graphRef.current!.getNodes().filter(
n => n.getData()?.cycle === parentData.id
);
const cycleStartNodes = remainingChildren.filter(n => n.getData()?.type === 'cycle-start');
if (cycleStartNodes.length === 1 && remainingChildren.length === 1) {
const cycleStartNode = cycleStartNodes[0];
const bbox = cycleStartNode.getBBox();
const addNode = graphRef.current!.addNode({
...graphNodeLibrary.addStart,
x: bbox.x + 84,
y: bbox.y + 4,
data: {
type: 'add-node',
parentId: parentNode.id,
cycle: parentData.id,
label: t('workflow.addNode'),
icon: '+',
},
});
parentNode.addChild(addNode);
const sourcePort = cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right';
const targetPort = addNode.getPorts().find(p => p.group === 'left')?.id || 'left';
graphRef.current!.addEdge({
source: { cell: cycleStartNode.id, port: sourcePort },
target: { cell: addNode.id, port: targetPort },
...edgeAttrs,
});
}
});
graphRef.current?.stopBatch('delete');
}
return false;
};
@@ -1200,7 +1036,7 @@ export const useWorkflowGraph = ({
graphRef.current?.getConnectedEdges(node).forEach(edge => {
if (!edge.getData()?.isSelected) {
edge.setAttrByPath('line/stroke', edge_selected_color);
edge.setData({ ...edge.getData(), isNodeHover: true }, { silent: true });
edge.setData({ ...edge.getData(), isNodeHover: true });
}
});
});
@@ -1208,7 +1044,7 @@ export const useWorkflowGraph = ({
graphRef.current?.getConnectedEdges(node).forEach(edge => {
if (!edge.getData()?.isSelected) {
edge.setAttrByPath('line/stroke', edge_color);
edge.setData({ ...edge.getData(), isNodeHover: false }, { silent: true });
edge.setData({ ...edge.getData(), isNodeHover: false });
}
});
});
@@ -1290,8 +1126,8 @@ export const useWorkflowGraph = ({
// Delete selected nodes and edges
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
// Undo / Redo
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { undo(); return false; });
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { redo(); return false; });
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { graphRef.current?.undo(); return false; });
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { graphRef.current?.redo(); return false; });
};
@@ -1357,51 +1193,13 @@ export const useWorkflowGraph = ({
};
if (dragData.type === 'loop' || dragData.type === 'iteration') {
graph.disableHistory()
const parentNode = graphRef.current.addNode({
graphRef.current.addNode({
...graphNodeLibrary[dragData.type],
x: point.x - 150,
y: point.y - 100,
id: cleanNodeData.id,
data: { ...cleanNodeData, isGroup: true },
})
const parentBBox = parentNode.getBBox()
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const cycleStartNode = graphRef.current.addNode({
...graphNodeLibrary.cycleStart,
x: parentBBox.x + 24,
y: parentBBox.y + 70,
id: cycleStartId,
data: { id: cycleStartId, type: 'cycle-start', parentId: cleanNodeData.id, isDefault: true, cycle: cleanNodeData.id },
})
const addNode = graphRef.current.addNode({
...graphNodeLibrary.addStart,
x: parentBBox.x + 24 + 84,
y: parentBBox.y + 70 + 4,
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: cleanNodeData.id, cycle: cleanNodeData.id },
})
parentNode.addChild(cycleStartNode, { silent: true })
parentNode.addChild(addNode, { silent: true })
const newEdge = graphRef.current.addEdge({
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
...edgeAttrs,
})
cycleStartNode.toFront()
addNode.toFront()
graph.enableHistory()
// Manually push a single batch frame covering all 4 cells into undoStack
const history = graph.getPlugin('history') as History
const makeBatchCmd = (cell: any) => ({
batch: true,
event: 'cell:added',
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
options: {},
})
const batchFrame = [parentNode, cycleStartNode, addNode, newEdge].map(makeBatchCmd)
;(history as any).undoStack.push(batchFrame)
;(history as any).redoStack = []
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-group' } })
});
} else if (dragData.type === 'if-else') {
// Create condition node
graphRef.current.addNode({
@@ -1648,80 +1446,8 @@ export const useWorkflowGraph = ({
return userVars
}
const clearHistoryRecords = () => {
setHistoryRecords([])
lastHistoryRef.current = null
}
const getStackCellIds = (cmds: any): string[] => {
const arr = Array.isArray(cmds) ? cmds : [cmds]
return [...new Set(arr.map((c: any) => c.data?.id).filter(Boolean))]
}
const isSkippableFrame = (frame: any): boolean => {
const arr = Array.isArray(frame) ? frame : [frame]
return arr.every((c: any) => ['zIndex', 'attrs', 'tools'].includes(c.data?.key))
}
const undo = () => {
const history = graphRef.current?.getPlugin('history') as History | undefined
if (!history || history.getUndoSize() === 0) return
const undoStack = (history as any).undoStack as any[]
isSyncingRef.current = true
while (undoStack.length > 0 && isSkippableFrame(undoStack[undoStack.length - 1])) {
graphRef.current!.undo()
}
if (undoStack.length === 0) {
isSyncingRef.current = false
return
}
const topIds = getStackCellIds(undoStack[undoStack.length - 1])
graphRef.current!.undo()
while (undoStack.length > 0) {
if (isSkippableFrame(undoStack[undoStack.length - 1])) {
graphRef.current!.undo()
continue
}
const nextIds = getStackCellIds(undoStack[undoStack.length - 1])
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
graphRef.current!.undo()
} else {
break
}
}
isSyncingRef.current = false
syncChildRelationships()
}
const redo = () => {
const history = graphRef.current?.getPlugin('history') as History | undefined
if (!history || history.getRedoSize() === 0) return
const redoStack = (history as any).redoStack as any[]
isSyncingRef.current = true
while (redoStack.length > 0 && isSkippableFrame(redoStack[redoStack.length - 1])) {
graphRef.current!.redo()
}
if (redoStack.length === 0) {
isSyncingRef.current = false
return
}
const topIds = getStackCellIds(redoStack[redoStack.length - 1])
graphRef.current!.redo()
while (redoStack.length > 0) {
if (isSkippableFrame(redoStack[redoStack.length - 1])) {
graphRef.current!.redo()
continue
}
const nextIds = getStackCellIds(redoStack[redoStack.length - 1])
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
graphRef.current!.redo()
} else {
break
}
}
isSyncingRef.current = false
syncChildRelationships()
}
const undo = () => graphRef.current?.undo()
const redo = () => graphRef.current?.redo()
const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => {
const { statement = '' } = value?.opening_statement || {}
@@ -1762,16 +1488,20 @@ export const useWorkflowGraph = ({
if (!graphRef.current) return;
const nodes = graphRef.current.getNodes();
// Reset all node execution status on every chatHistory change
const lastWithSub = [...chatHistory].reverse().find(item => item.subContent?.length);
// Reset all node execution status first
nodes.forEach(node => {
const data = node.getData();
node.setData({ ...data, executionStatus: '' });
if (typeof data.executionStatus === 'string') {
node.setData({ ...data, executionStatus: undefined });
}
});
const lastAssistant = [...chatHistory].reverse().find(item => item.role === 'assistant');
if (!lastAssistant?.subContent?.length) return;
lastAssistant.subContent.forEach(sub => {
if (!lastWithSub?.subContent) return;
// Build a nodeId -> status map first
const statusMap: Record<string, string> = {};
lastWithSub.subContent.forEach(sub => {
if (typeof sub.status === 'string') {
statusMap[sub.node_id] = sub.status;
const node = nodes.find(n => n.getData()?.id === sub.node_id);
if (node) {
node.setData({ ...node.getData(), executionStatus: sub.status });
@@ -1807,7 +1537,5 @@ export const useWorkflowGraph = ({
canRedo,
undo,
redo,
historyRecords,
clearHistoryRecords,
};
};

View File

@@ -113,13 +113,4 @@ export interface ChatVariable {
}
export interface AddChatVariableRef {
handleOpen: (value?: ChatVariable) => void;
}
export type HistoryActionType = 'add' | 'remove' | 'change' | 'undo' | 'redo' | 'batch'
export interface HistoryRecord {
type: HistoryActionType;
timestamp: number;
batchName?: string;
cellIds?: string[];
}

View File

@@ -17,7 +17,6 @@ export const isSubExprSet = (sub: any) => {
* Uses the same per-expression height logic as getConditionNodeCasePortY.
*/
export const calcConditionNodeTotalHeight = (cases: any[]) => {
if (!cases?.length) return conditionNodeHeight;
const casesHeight = cases.reduce((acc: number, c: any) => {
const exprs = c?.expressions ?? [];
const n = exprs.length;