Merge branch 'develop' into feature/ui_upgrade_zy

This commit is contained in:
yingzhao
2026-04-03 20:49:53 +08:00
committed by GitHub
63 changed files with 2643 additions and 684 deletions

View File

@@ -292,10 +292,19 @@ def get_opening(
): ):
"""返回开场白文本和预设问题,供前端对话界面初始化时展示""" """返回开场白文本和预设问题,供前端对话界面初始化时展示"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
features = cfg.features or {} # 根据应用类型获取 features
if hasattr(features, "model_dump"): from app.models.app_model import App as AppModel
features = features.model_dump() app = db.get(AppModel, app_id)
if app and app.type == "workflow":
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
features = cfg.features or {}
else:
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
features = cfg.features or {}
if hasattr(features, "model_dump"):
features = features.model_dump()
opening = features.get("opening_statement", {}) opening = features.get("opening_statement", {})
return success(data=app_schema.OpeningResponse( return success(data=app_schema.OpeningResponse(
enabled=opening.get("enabled", False), enabled=opening.get("enabled", False),
@@ -1070,6 +1079,14 @@ async def update_workflow_config(
current_user: Annotated[User, Depends(get_current_user)] current_user: Annotated[User, Depends(get_current_user)]
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
if payload.variables:
from app.services.workflow_service import WorkflowService
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
[v.model_dump() for v in payload.variables]
)
# Patch default values back into VariableDefinition objects
for var_def, resolved_def in zip(payload.variables, resolved):
var_def.default = resolved_def.get("default", var_def.default)
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id) cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=WorkflowConfigSchema.model_validate(cfg)) return success(data=WorkflowConfigSchema.model_validate(cfg))

View File

@@ -53,22 +53,24 @@ async def login_for_access_token(
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password) user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})") auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
if form_data.invite: if form_data.invite:
auth_service.bind_workspace_with_invite(db=db, auth_service.bind_workspace_with_invite(
user=user, db=db,
invite_token=form_data.invite, user=user,
workspace_id=invite_info.workspace_id) invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
)
except BusinessException as e: except BusinessException as e:
# 用户不存在且有邀请码,尝试注册 # 用户不存在且有邀请码,尝试注册
if e.code == BizCode.USER_NOT_FOUND: if e.code == BizCode.USER_NOT_FOUND:
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}") auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
user = auth_service.register_user_with_invite( user = auth_service.register_user_with_invite(
db=db, db=db,
email=form_data.email, email=form_data.email,
username=form_data.username, username=form_data.username,
password=form_data.password, password=form_data.password,
invite_token=form_data.invite, invite_token=form_data.invite,
workspace_id=invite_info.workspace_id workspace_id=invite_info.workspace_id
) )
elif e.code == BizCode.PASSWORD_ERROR: elif e.code == BizCode.PASSWORD_ERROR:
# 用户存在但密码错误 # 用户存在但密码错误
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}") auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")

View File

@@ -314,8 +314,10 @@ async def parse_documents(
) )
# 4. Check if the file exists # 4. Check if the file exists
api_logger.debug(f"Constructed file path: {file_path}")
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
if not os.path.exists(file_path): if not os.path.exists(file_path):
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}") api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)" detail="File not found (possibly deleted)"

View File

@@ -475,7 +475,7 @@ class LangChainAgent:
history: Optional[List[Dict[str, str]]] = None, history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None, context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None files: Optional[List[Dict[str, Any]]] = None
) -> AsyncGenerator[str | int, None]: ) -> AsyncGenerator[str | int | dict[str, str], None]:
"""执行流式对话 """执行流式对话
Args: Args:

View File

@@ -0,0 +1,408 @@
"""
Perceptual Memory Retrieval Node & Service
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
with BM25+embedding fusion reranking.
Also provides the perceptual_retrieve_node for use as a LangGraph node.
"""
import asyncio
import math
from typing import List, Dict, Any, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.repositories.neo4j.graph_search import (
search_perceptual,
search_perceptual_by_embedding,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = get_agent_logger(__name__)
class PerceptualSearchService:
"""
感知记忆检索服务。
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
调用方只需提供 query / keywords、end_user_id、memory_config即可获得
格式化并排序后的感知记忆列表和拼接文本。
Usage:
service = PerceptualSearchService(end_user_id=..., memory_config=...)
results = await service.search(query="...", keywords=[...], limit=10)
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
"""
DEFAULT_ALPHA = 0.6
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
def __init__(
self,
end_user_id: str,
memory_config: Any,
alpha: float = DEFAULT_ALPHA,
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
):
self.end_user_id = end_user_id
self.memory_config = memory_config
self.alpha = alpha
self.content_score_threshold = content_score_threshold
async def search(
self,
query: str,
keywords: Optional[List[str]] = None,
limit: int = 10,
) -> Dict[str, Any]:
"""
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
Args:
query: 原始用户查询(用于向量检索和 BM25 补查)
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
limit: 最大返回数量
Returns:
{
"memories": [格式化后的记忆 dict, ...],
"content": "拼接的纯文本摘要",
"keyword_raw": int,
"embedding_raw": int,
}
"""
if keywords is None:
keywords = [query] if query else []
connector = Neo4jConnector()
try:
kw_task = self._keyword_search(connector, keywords, limit)
emb_task = self._embedding_search(connector, query, limit)
kw_results, emb_results = await asyncio.gather(
kw_task, emb_task, return_exceptions=True
)
if isinstance(kw_results, Exception):
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
kw_results = []
if isinstance(emb_results, Exception):
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
emb_results = []
# 补查 BM25找出 embedding 命中但 keyword 未命中的 id
# 用原始 query 对这些节点补查全文索引拿 BM25 score
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
if emb_only_ids and query:
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
# 把补查到的 BM25 score 注入到 embedding 结果中
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
for r in emb_results:
rid = r.get("id", "")
if rid in backfill_map:
r["bm25_backfill_score"] = backfill_map[rid]
logger.info(
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
f"{len(backfill_map)} got BM25 scores"
)
reranked = self._rerank(kw_results, emb_results, limit)
memories = []
content_parts = []
for record in reranked:
fmt = self._format_result(record)
fmt["score"] = round(record.get("content_score", 0), 4)
memories.append(fmt)
content_parts.append(self._build_content_text(fmt))
logger.info(
f"[PerceptualSearch] {len(memories)} results after rerank "
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
)
return {
"memories": memories,
"content": "\n\n".join(content_parts),
"keyword_raw": len(kw_results),
"embedding_raw": len(emb_results),
}
finally:
await connector.close()
async def _bm25_backfill(
self,
connector: Neo4jConnector,
query: str,
target_ids: set,
limit: int,
) -> List[dict]:
"""
对指定 id 集合补查全文索引 BM25 score。
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
"""
escaped = escape_lucene_query(query)
if not escaped.strip():
return []
try:
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id,
limit=limit * 5, # 多查一些以提高命中率
)
all_hits = r.get("perceptuals", [])
return [h for h in all_hits if h.get("id") in target_ids]
except Exception as e:
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
return []
async def _keyword_search(
self,
connector: Neo4jConnector,
keywords: List[str],
limit: int,
) -> List[dict]:
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
seen_ids: set = set()
all_results: List[dict] = []
async def _one(kw: str):
escaped = escape_lucene_query(kw)
if not escaped.strip():
return []
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id, limit=limit,
)
return r.get("perceptuals", [])
tasks = [_one(kw) for kw in keywords[:10]]
batch = await asyncio.gather(*tasks, return_exceptions=True)
for result in batch:
if isinstance(result, Exception):
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
continue
for rec in result:
rid = rec.get("id", "")
if rid and rid not in seen_ids:
seen_ids.add(rid)
all_results.append(rec)
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
return all_results[:limit]
async def _embedding_search(
self,
connector: Neo4jConnector,
query_text: str,
limit: int,
) -> List[dict]:
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
try:
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
with get_db_context() as db:
cfg = MemoryConfigService(db).get_embedder_config(
str(self.memory_config.embedding_model_id)
)
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
r = await search_perceptual_by_embedding(
connector=connector, embedder_client=client,
query_text=query_text, end_user_id=self.end_user_id,
limit=limit,
)
return r.get("perceptuals", [])
except Exception as e:
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
return []
def _rerank(
self,
keyword_results: List[dict],
embedding_results: List[dict],
limit: int,
) -> List[dict]:
"""BM25 + embedding 融合排序。
对 embedding 结果中带有 bm25_backfill_score 的条目,
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
"""
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
emb_backfill_items = []
for item in embedding_results:
backfill_score = item.get("bm25_backfill_score")
if backfill_score is not None and item.get("id"):
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
# 合并后统一归一化 BM25 scores
all_bm25_items = keyword_results + emb_backfill_items
all_bm25_items = self._normalize_scores(all_bm25_items)
# 建立 id -> normalized BM25 score 的映射
bm25_norm_map: Dict[str, float] = {}
for item in all_bm25_items:
item_id = item.get("id", "")
if item_id:
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
# 归一化 embedding scores
embedding_results = self._normalize_scores(embedding_results)
# 合并
combined: Dict[str, dict] = {}
for item in keyword_results:
item_id = item.get("id", "")
if not item_id:
continue
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = 0.0
for item in embedding_results:
item_id = item.get("id", "")
if not item_id:
continue
if item_id in combined:
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
for item in combined.values():
bm25 = float(item.get("bm25_score", 0) or 0)
emb = float(item.get("embedding_score", 0) or 0)
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
results = list(combined.values())
before = len(results)
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
results.sort(key=lambda x: x["content_score"], reverse=True)
results = results[:limit]
logger.info(
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
)
return results
@staticmethod
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
"""Z-score + sigmoid 归一化。"""
if not items:
return items
scores = [float(it.get(field, 0) or 0) for it in items]
if len(scores) <= 1:
for it in items:
it[f"normalized_{field}"] = 1.0
return items
mean = sum(scores) / len(scores)
var = sum((s - mean) ** 2 for s in scores) / len(scores)
std = math.sqrt(var)
if std == 0:
for it in items:
it[f"normalized_{field}"] = 1.0
else:
for it, s in zip(items, scores):
z = (s - mean) / std
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
return items
@staticmethod
def _format_result(record: dict) -> dict:
return {
"id": record.get("id", ""),
"perceptual_type": record.get("perceptual_type", ""),
"file_name": record.get("file_name", ""),
"file_path": record.get("file_path", ""),
"summary": record.get("summary", ""),
"topic": record.get("topic", ""),
"domain": record.get("domain", ""),
"keywords": record.get("keywords", []),
"created_at": str(record.get("created_at", "")),
"file_type": record.get("file_type", ""),
"score": record.get("score", 0),
}
@staticmethod
def _build_content_text(formatted: dict) -> str:
parts = []
if formatted["summary"]:
parts.append(formatted["summary"])
if formatted["topic"]:
parts.append(f"[主题: {formatted['topic']}]")
if formatted["keywords"]:
kw_list = formatted["keywords"]
if isinstance(kw_list, list):
parts.append(f"[关键词: {', '.join(kw_list)}]")
if formatted["file_name"]:
parts.append(f"[文件: {formatted['file_name']}]")
return " ".join(parts)
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
"""Extract search keywords from problem extension results."""
keywords = []
context = problem_extension.get("context", {})
if isinstance(context, dict):
for original_q, extended_qs in context.items():
keywords.append(original_q)
if isinstance(extended_qs, list):
keywords.extend(extended_qs)
return keywords
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
"""
LangGraph node: perceptual memory retrieval.
Uses PerceptualSearchService to run keyword + embedding search with
BM25 fusion reranking, then writes results to state['perceptual_data'].
"""
end_user_id = state.get("end_user_id", "")
problem_extension = state.get("problem_extension", {})
original_query = state.get("data", "")
memory_config = state.get("memory_config", None)
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
keywords = _extract_keywords_from_problems(problem_extension)
if not keywords:
keywords = [original_query] if original_query else []
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
search_result = await service.search(
query=original_query,
keywords=keywords,
limit=10,
)
result = {
"memories": search_result["memories"],
"content": search_result["content"],
"_intermediate": {
"type": "perceptual_retrieve",
"title": "感知记忆检索",
"data": search_result["memories"],
"query": original_query,
"result_count": len(search_result["memories"]),
},
}
return {"perceptual_data": result}

View File

@@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
logger.info(f"Problem extension result: {aggregated_dict}") logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend # Emit intermediate output for frontend
print(time.time() - start)
result = { result = {
"context": aggregated_dict, "context": aggregated_dict,
"original": data, "original": data,

View File

@@ -1,7 +1,11 @@
import asyncio
import os import os
import time import time
from app.core.logging_config import get_agent_logger, log_time from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
PerceptualSearchService,
)
from app.core.memory.agent.models.summary_models import ( from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse, RetrieveSummaryResponse,
SummaryResponse, SummaryResponse,
@@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState:
try: try:
if storage_type != "rag": if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
async def _perceptual_search():
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
return await service.search(query=data, limit=5)
hybrid_task = SearchService().execute_hybrid_search(
**search_params, **search_params,
memory_config=memory_config, memory_config=memory_config,
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement expand_communities=False,
) )
perceptual_task = _perceptual_search()
gather_results = await asyncio.gather(
hybrid_task, perceptual_task, return_exceptions=True
)
hybrid_result = gather_results[0]
perceptual_results = gather_results[1]
# 处理 hybrid search 异常
if isinstance(hybrid_result, Exception):
raise hybrid_result
retrieve_info, question, raw_results = hybrid_result
# 处理感知记忆结果
if isinstance(perceptual_results, Exception):
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
perceptual_results = []
# 拼接感知记忆内容到 retrieve_info
if perceptual_results and isinstance(perceptual_results, dict):
perceptual_content = perceptual_results.get("content", "")
if perceptual_content:
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
count = len(perceptual_results.get("memories", []))
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
# 调试:打印 community 检索结果数量 # 调试:打印 community 检索结果数量
if raw_results and isinstance(raw_results, dict): if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {}) reranked = raw_results.get('reranked_results', {})
@@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
"error": str(e) "error": str(e)
} }
end = time.time() end = time.time()
try: duration = end - start
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration) log_time('检索', duration)
return {"summary": summary} return {"summary": summary}
@@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve_info_str = list(set(retrieve_info_str)) retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str) retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages = await summary_llm(state, history, retrieve_info_str, # Merge perceptual memory content
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1") perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
aimessages = await summary_llm(
state,
history,
retrieve_info_str,
'direct_summary_prompt.jinja2',
'retrieve_summary', RetrieveSummaryResponse,
"1"
)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
if aimessages == '': if aimessages == '':
@@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
retrieve_info_str += i + '\n' retrieve_info_str += i + '\n'
history = await summary_history(state) history = await summary_history(state)
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = { data = {
"query": query, "query": query,
"history": history, "history": history,
@@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
if key == 'answer_small': if key == 'answer_small':
for i in value: for i in value:
retrieve_info_str += i + '\n' retrieve_info_str += i + '\n'
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = { data = {
"query": query, "query": query,
"history": history, "history": history,

View File

@@ -17,6 +17,9 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve_nodes, retrieve_nodes,
) )
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
perceptual_retrieve_node,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary, Input_Summary,
Retrieve_Summary, Retrieve_Summary,
@@ -55,6 +58,7 @@ async def make_read_graph():
workflow.add_node("Input_Summary", Input_Summary) workflow.add_node("Input_Summary", Input_Summary)
workflow.add_node("Retrieve", retrieve_nodes) workflow.add_node("Retrieve", retrieve_nodes)
# workflow.add_node("Retrieve", retrieve) # workflow.add_node("Retrieve", retrieve)
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
workflow.add_node("Verify", Verify) workflow.add_node("Verify", Verify)
workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary) workflow.add_node("Summary", Summary)
@@ -65,14 +69,15 @@ async def make_read_graph():
workflow.add_conditional_edges("content_input", Split_continue) workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END) workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension") workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve") # After Problem_Extension, retrieve perceptual memory first, then main Retrieve
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue) workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END) workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue) workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END) workflow.add_edge("Summary", END)
'''-----'''
# workflow.add_edge("Retrieve", END) # workflow.add_edge("Retrieve", END)
# Compile workflow # Compile workflow
@@ -80,7 +85,5 @@ async def make_read_graph():
yield graph yield graph
except Exception as e: except Exception as e:
print(f"创建工作流失败: {e}") logger.error(f"创建工作流失败: {e}")
raise raise
finally:
print("工作流创建完成")

View File

@@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
# 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化) # 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化)
@@ -31,10 +30,10 @@ def _clean_expand_fields(obj):
async def expand_communities_to_statements( async def expand_communities_to_statements(
community_results: List[dict], community_results: List[dict],
end_user_id: str, end_user_id: str,
existing_content: str = "", existing_content: str = "",
limit: int = 10, limit: int = 10,
) -> Tuple[List[dict], List[str]]: ) -> Tuple[List[dict], List[str]]:
""" """
社区展开 helper给定命中的 community 列表,拉取关联 Statement。 社区展开 helper给定命中的 community 列表,拉取关联 Statement。
@@ -76,7 +75,8 @@ async def expand_communities_to_statements(
if s.get("statement") and s["statement"] not in existing_lines if s.get("statement") and s["statement"] not in existing_lines
] ]
cleaned = _clean_expand_fields(expanded_stmts) cleaned = _clean_expand_fields(expanded_stmts)
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}") logger.info(
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
return cleaned, new_texts return cleaned, new_texts
@@ -117,9 +117,9 @@ class SearchService:
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定 # Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要 # 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
is_community = ( is_community = (
node_type == "community" node_type == "community"
or 'member_count' in result or 'member_count' in result
or 'core_entities' in result or 'core_entities' in result
) )
if is_community: if is_community:
name = result.get('name', '') name = result.get('name', '')
@@ -158,7 +158,7 @@ class SearchService:
# Remove wrapping quotes # Remove wrapping quotes
if (q.startswith("'") and q.endswith("'")) or ( if (q.startswith("'") and q.endswith("'")) or (
q.startswith('"') and q.endswith('"') q.startswith('"') and q.endswith('"')
): ):
q = q[1:-1] q = q[1:-1]
@@ -171,17 +171,17 @@ class SearchService:
return q return q
async def execute_hybrid_search( async def execute_hybrid_search(
self, self,
end_user_id: str, end_user_id: str,
question: str, question: str,
limit: int = 5, limit: int = 5,
search_type: str = "hybrid", search_type: str = "hybrid",
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
rerank_alpha: float = 0.4, rerank_alpha: float = 0.4,
output_path: str = "search_results.json", output_path: str = "search_results.json",
return_raw_results: bool = False, return_raw_results: bool = False,
memory_config = None, memory_config=None,
expand_communities: bool = True, expand_communities: bool = True,
) -> Tuple[str, str, Optional[dict]]: ) -> Tuple[str, str, Optional[dict]]:
""" """
Execute hybrid search and return clean content. Execute hybrid search and return clean content.
@@ -269,7 +269,6 @@ class SearchService:
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else "" ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
content_list.append(self.extract_content_from_result(ans, node_type=ntype)) content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines # Filter out empty strings and join with newlines
clean_content = '\n'.join([c for c in content_list if c]) clean_content = '\n'.join([c for c in content_list if c])

View File

@@ -1,4 +1,3 @@
import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Annotated, TypedDict from typing import Annotated, TypedDict
@@ -52,6 +51,7 @@ class ReadState(TypedDict):
embedding_id: str embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象 memory_config: object # 新增字段用于传递内存配置对象
retrieve: dict retrieve: dict
perceptual_data: dict
RetrieveSummary: dict RetrieveSummary: dict
InputSummary: dict InputSummary: dict
verify: dict verify: dict

View File

@@ -152,6 +152,24 @@ async def write(
# Step 3: Save all data to Neo4j database # Step 3: Save all data to Neo4j database
step_start = time.time() step_start = time.time()
# Neo4j 写入前:清洗用户/AI助手实体之间的别名交叉污染
# 从 Neo4j 查询已有的 AI 助手别名,与本轮实体中的 AI 助手别名合并,
# 确保用户实体的 aliases 不包含 AI 助手的名字
try:
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
clean_cross_role_aliases,
fetch_neo4j_assistant_aliases,
)
neo4j_assistant_aliases = set()
if all_entity_nodes:
_eu_id = all_entity_nodes[0].end_user_id
if _eu_id:
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, _eu_id)
clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases)
logger.info(f"Neo4j 写入前别名清洗完成AI助手别名排除集大小: {len(neo4j_assistant_aliases)}")
except Exception as e:
logger.warning(f"Neo4j 写入前别名清洗失败(不影响主流程): {e}")
# 添加死锁重试机制 # 添加死锁重试机制
max_retries = 3 max_retries = 3
retry_delay = 1 # 秒 retry_delay = 1 # 秒

View File

@@ -43,6 +43,7 @@ load_dotenv()
logger = get_memory_logger(__name__) logger = get_memory_logger(__name__)
def _parse_datetime(value: Any) -> Optional[datetime]: def _parse_datetime(value: Any) -> Optional[datetime]:
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'.""" """Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
if value is None: if value is None:
@@ -94,7 +95,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
item[f"normalized_{score_field}"] = None item[f"normalized_{score_field}"] = None
return results return results
if len(valid_scores) == 1: # Single valid score, set to 1.0 if len(valid_scores) == 1: # Single valid score, set to 1.0
for item, score in zip(results, scores): for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value": if score_field in item or score_field == "activation_value":
if score is None: if score is None:
@@ -132,7 +133,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
return results return results
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
""" """
Remove duplicate items from search results based on content. Remove duplicate items from search results based on content.
@@ -157,11 +157,11 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# Extract content from various possible fields # Extract content from various possible fields
content = ( content = (
item.get("text") or item.get("text") or
item.get("content") or item.get("content") or
item.get("statement") or item.get("statement") or
item.get("name") or item.get("name") or
"" ""
) )
# Normalize content for comparison (strip whitespace and lowercase) # Normalize content for comparison (strip whitespace and lowercase)
@@ -189,13 +189,14 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def rerank_with_activation( def rerank_with_activation(
keyword_results: Dict[str, List[Dict[str, Any]]], keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]], embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6, alpha: float = 0.6,
limit: int = 10, limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None, forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8, activation_boost_factor: float = 0.8,
now: datetime | None = None, now: datetime | None = None,
content_score_threshold: float = 0.5,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
两阶段排序:先按内容相关性筛选,再按激活值排序。 两阶段排序:先按内容相关性筛选,再按激活值排序。
@@ -222,6 +223,8 @@ def rerank_with_activation(
forgetting_config: 遗忘引擎配置(当前未使用) forgetting_config: 遗忘引擎配置(当前未使用)
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8) activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
now: 当前时间(用于遗忘计算) now: 当前时间(用于遗忘计算)
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score
低于此阈值的结果会被过滤。默认 0.5。
返回: 返回:
带评分元数据的重排序结果,按 final_score 排序 带评分元数据的重排序结果,按 final_score 排序
@@ -391,7 +394,19 @@ def rerank_with_activation(
# 无激活值:使用内容相关性分数 # 无激活值:使用内容相关性分数
item["final_score"] = item.get("base_score", 0) item["final_score"] = item.get("base_score", 0)
# 最终去重确保没有重复项 if content_score_threshold > 0:
before_count = len(sorted_items)
sorted_items = [
item for item in sorted_items
if float(item.get("content_score", 0) or 0) >= content_score_threshold
]
filtered_count = before_count - len(sorted_items)
if filtered_count > 0:
logger.info(
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
f"items below content_score_threshold={content_score_threshold}"
)
sorted_items = _deduplicate_results(sorted_items) sorted_items = _deduplicate_results(sorted_items)
reranked[category] = sorted_items reranked[category] = sorted_items
@@ -399,7 +414,8 @@ def rerank_with_activation(
return reranked return reranked
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None): def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
log_file: str = None):
"""Log search query information using the logger. """Log search query information using the logger.
Args: Args:
@@ -439,8 +455,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
def apply_reranker_placeholder( def apply_reranker_placeholder(
results: Dict[str, List[Dict[str, Any]]], results: Dict[str, List[Dict[str, Any]]],
query_text: str, query_text: str,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Placeholder for a cross-encoder reranker. Placeholder for a cross-encoder reranker.
@@ -673,17 +689,17 @@ def apply_reranker_placeholder(
async def run_hybrid_search( async def run_hybrid_search(
query_text: str, query_text: str,
search_type: str, search_type: str,
end_user_id: str | None, end_user_id: str | None,
limit: int, limit: int,
include: List[str], include: List[str],
output_path: str | None, output_path: str | None,
memory_config: "MemoryConfig", memory_config: "MemoryConfig",
rerank_alpha: float = 0.6, rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8, activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False, use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False, use_llm_rerank: bool = False,
): ):
""" """
@@ -788,7 +804,7 @@ async def run_hybrid_search(
if keyword_task: if keyword_task:
keyword_results = await keyword_task keyword_results = await keyword_task
keyword_latency = time.time() - keyword_start keyword_latency = time.time() - search_start_time
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4) latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s") logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword": if search_type == "keyword":
@@ -798,7 +814,7 @@ async def run_hybrid_search(
if embedding_task: if embedding_task:
embedding_results = await embedding_task embedding_results = await embedding_task
embedding_latency = time.time() - embedding_start embedding_latency = time.time() - search_start_time
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4) latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s") logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding": if search_type == "embedding":
@@ -810,7 +826,8 @@ async def run_hybrid_search(
if search_type == "hybrid": if search_type == "hybrid":
results["combined_summary"] = { results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), "total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"search_query": query_text, "search_query": query_text,
"search_timestamp": datetime.now().isoformat() "search_timestamp": datetime.now().isoformat()
} }
@@ -866,7 +883,8 @@ async def run_hybrid_search(
results["reranked_results"] = reranked_results results["reranked_results"] = reranked_results
results["combined_summary"] = { results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), "total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()), "total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
"search_query": query_text, "search_query": query_text,
"search_timestamp": datetime.now().isoformat(), "search_timestamp": datetime.now().isoformat(),
@@ -908,8 +926,10 @@ async def run_hybrid_search(
# Log search completion with result count # Log search completion with result count
if search_type == "hybrid": if search_type == "hybrid":
result_counts = { result_counts = {
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()}, "keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()} keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
embedding_results.items()}
} }
else: else:
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()} result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
@@ -927,12 +947,12 @@ async def run_hybrid_search(
async def search_by_temporal( async def search_by_temporal(
end_user_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 1, limit: int = 1,
): ):
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -968,13 +988,13 @@ async def search_by_temporal(
async def search_by_keyword_temporal( async def search_by_keyword_temporal(
query_text: str, query_text: str,
end_user_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 1, limit: int = 1,
): ):
""" """
Temporal keyword search across Statements. Temporal keyword search across Statements.
@@ -1011,9 +1031,9 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id( async def search_chunk_by_chunk_id(
chunk_id: str, chunk_id: str,
end_user_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
limit: int = 1, limit: int = 1,
): ):
""" """
Search for Chunks by chunk_id. Search for Chunks by chunk_id.
@@ -1026,4 +1046,3 @@ async def search_chunk_by_chunk_id(
limit=limit limit=limit
) )
return {"chunks": chunks} return {"chunks": chunks}

View File

@@ -4,6 +4,7 @@
import asyncio import asyncio
import difflib # 提供字符串相似度计算工具 import difflib # 提供字符串相似度计算工具
import importlib import importlib
import logging
import os import os
import re import re
from datetime import datetime from datetime import datetime
@@ -16,6 +17,8 @@ from app.core.memory.models.graph_models import (
) )
from app.core.memory.models.variate_config import DedupConfig from app.core.memory.models.variate_config import DedupConfig
logger = logging.getLogger(__name__)
# 模块级类型统一工具函数 # 模块级类型统一工具函数
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None: def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
@@ -198,6 +201,161 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
except Exception: except Exception:
pass pass
# 用户和AI助手的占位名称集合用于名称标准化
_USER_PLACEHOLDER_NAMES = {"用户", "", "user", "i"}
_ASSISTANT_PLACEHOLDER_NAMES = {"ai助手", "助手", "人工智能助手", "智能助手", "智能体", "ai assistant", "assistant"}
# 标准化后的规范名称和类型
_CANONICAL_USER_NAME = "用户"
_CANONICAL_USER_TYPE = "用户"
_CANONICAL_ASSISTANT_NAME = "AI助手"
_CANONICAL_ASSISTANT_TYPE = "Agent"
# 用户和AI助手的所有可能名称用于判断实体是否为特殊角色实体
_ALL_USER_NAMES = _USER_PLACEHOLDER_NAMES
_ALL_ASSISTANT_NAMES = _ASSISTANT_PLACEHOLDER_NAMES
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
"""判断实体是否为用户实体name 或 entity_type 匹配)"""
name = (getattr(ent, "name", "") or "").strip().lower()
etype = (getattr(ent, "entity_type", "") or "").strip()
return name in _ALL_USER_NAMES or etype == _CANONICAL_USER_TYPE
def _is_assistant_entity(ent: ExtractedEntityNode) -> bool:
"""判断实体是否为AI助手实体name 或 entity_type 匹配)"""
name = (getattr(ent, "name", "") or "").strip().lower()
etype = (getattr(ent, "entity_type", "") or "").strip()
return name in _ALL_ASSISTANT_NAMES or etype == _CANONICAL_ASSISTANT_TYPE
def _would_merge_cross_role(a: ExtractedEntityNode, b: ExtractedEntityNode) -> bool:
"""判断两个实体的合并是否会跨越用户/AI助手角色边界。
用户实体和AI助手实体永远不应该被合并在一起。
如果一方是用户实体、另一方是AI助手实体返回 True阻止合并
"""
return (
(_is_user_entity(a) and _is_assistant_entity(b))
or (_is_assistant_entity(a) and _is_user_entity(b))
)
def _normalize_special_entity_names(
entity_nodes: List[ExtractedEntityNode],
) -> None:
"""标准化用户和AI助手实体的名称和类型。
多轮对话中LLM 对同一角色可能使用不同的名称变体(如"用户"/""/"User"
"AI助手"/"助手"/"Assistant"),导致精确匹配无法合并。
此函数在去重前将这些变体统一为规范名称,并强制绑定 entity_type确保
- name="用户" 的实体 entity_type 一定为 "用户"
- name="AI助手" 的实体 entity_type 一定为 "Agent"
Args:
entity_nodes: 实体节点列表(原地修改)
"""
for ent in entity_nodes:
name = (getattr(ent, "name", "") or "").strip()
name_lower = name.lower()
if name_lower in _USER_PLACEHOLDER_NAMES:
ent.name = _CANONICAL_USER_NAME
ent.entity_type = _CANONICAL_USER_TYPE
elif name_lower in _ASSISTANT_PLACEHOLDER_NAMES:
ent.name = _CANONICAL_ASSISTANT_NAME
ent.entity_type = _CANONICAL_ASSISTANT_TYPE
# 第二步:清洗用户/AI助手之间的别名交叉污染复用 clean_cross_role_aliases
clean_cross_role_aliases(entity_nodes)
async def fetch_neo4j_assistant_aliases(neo4j_connector, end_user_id: str) -> set:
"""从 Neo4j 查询 AI 助手实体的所有别名(小写归一化)。
这是助手别名查询的唯一入口,供 write_tools 和 extraction_orchestrator 共用,
避免多处维护相同的 Cypher 和名称列表。
Args:
neo4j_connector: Neo4j 连接器实例(需提供 execute_query 方法)
end_user_id: 终端用户 ID
Returns:
小写归一化后的助手别名集合
"""
# 查询名称列表:规范名称 + 常见变体(与 _normalize_special_entity_names 标准化后一致)
query_names = [_CANONICAL_ASSISTANT_NAME, *_ASSISTANT_PLACEHOLDER_NAMES]
# 去重保序
query_names = list(dict.fromkeys(query_names))
cypher = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND e.name IN $names
RETURN e.aliases AS aliases
"""
try:
result = await neo4j_connector.execute_query(
cypher, end_user_id=end_user_id, names=query_names
)
assistant_aliases: set = set()
for record in (result or []):
for alias in (record.get("aliases") or []):
assistant_aliases.add(alias.strip().lower())
if assistant_aliases:
logger.debug(f"Neo4j 中 AI 助手别名: {assistant_aliases}")
return assistant_aliases
except Exception as e:
logger.warning(f"查询 Neo4j AI 助手别名失败: {e}")
return set()
def clean_cross_role_aliases(
entity_nodes: List[ExtractedEntityNode],
external_assistant_aliases: set = None,
) -> None:
"""清洗用户实体和AI助手实体之间的别名交叉污染。
在 Neo4j 写入前调用,确保:
- 用户实体的 aliases 不包含 AI 助手的别名
- AI 助手实体的 aliases 不包含用户的别名
Args:
entity_nodes: 实体节点列表(原地修改)
external_assistant_aliases: 外部传入的 AI 助手别名集合(如从 Neo4j 查询),
与本轮实体中的 AI 助手别名合并使用
"""
# 收集本轮 AI 助手实体的所有别名
assistant_aliases = set(external_assistant_aliases or set())
user_aliases = set()
for ent in entity_nodes:
if _is_assistant_entity(ent):
for alias in (getattr(ent, "aliases", []) or []):
assistant_aliases.add(alias.strip().lower())
elif _is_user_entity(ent):
for alias in (getattr(ent, "aliases", []) or []):
user_aliases.add(alias.strip().lower())
# 从用户实体的 aliases 中移除 AI 助手别名
if assistant_aliases:
for ent in entity_nodes:
if _is_user_entity(ent):
original = getattr(ent, "aliases", []) or []
cleaned = [a for a in original if a.strip().lower() not in assistant_aliases]
if len(cleaned) < len(original):
ent.aliases = cleaned
# 从 AI 助手实体的 aliases 中移除用户别名
if user_aliases:
for ent in entity_nodes:
if _is_assistant_entity(ent):
original = getattr(ent, "aliases", []) or []
cleaned = [a for a in original if a.strip().lower() not in user_aliases]
if len(cleaned) < len(original):
ent.aliases = cleaned
def accurate_match( def accurate_match(
entity_nodes: List[ExtractedEntityNode] entity_nodes: List[ExtractedEntityNode]
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]: ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
@@ -261,6 +419,10 @@ def accurate_match(
canonical = alias_index.get((ent_uid, ent_name)) canonical = alias_index.get((ent_uid, ent_name))
# 确保不是自身 # 确保不是自身
if canonical is not None and canonical.id != ent.id: if canonical is not None and canonical.id != ent.id:
# 保护禁止跨角色合并用户实体和AI助手实体不能互相合并
if _would_merge_cross_role(canonical, ent):
i += 1
continue
_merge_attribute(canonical, ent) _merge_attribute(canonical, ent)
id_redirect[ent.id] = canonical.id id_redirect[ent.id] = canonical.id
for k, v in list(id_redirect.items()): for k, v in list(id_redirect.items()):
@@ -704,6 +866,11 @@ def fuzzy_match(
# 条件A快速通道alias_match_merge = True # 条件A快速通道alias_match_merge = True
# 条件B标准通道s_name ≥ tn AND s_type ≥ type_threshold AND overall ≥ tover # 条件B标准通道s_name ≥ tn AND s_type ≥ type_threshold AND overall ≥ tover
if alias_match_merge or (s_name >= tn and s_type >= type_threshold and overall >= tover): if alias_match_merge or (s_name >= tn and s_type >= type_threshold and overall >= tover):
# 保护禁止跨角色合并用户实体和AI助手实体不能互相合并
if _would_merge_cross_role(a, b):
j += 1
continue
# ========== 第六步:执行实体合并 ========== # ========== 第六步:执行实体合并 ==========
# 6.1 合并别名 # 6.1 合并别名
@@ -813,6 +980,12 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
b = entity_by_id.get(losing_id) b = entity_by_id.get(losing_id)
if not a or not b: # 若不存在 a 或 b可能已在精确或模糊阶段合并在之前阶段合并之后不会再处理但是处于审计的目的会记录 if not a or not b: # 若不存在 a 或 b可能已在精确或模糊阶段合并在之前阶段合并之后不会再处理但是处于审计的目的会记录
continue continue
# 保护禁止跨角色合并用户实体和AI助手实体不能互相合并
if _would_merge_cross_role(a, b):
llm_records.append(
f"[LLM阻断] 跨角色合并被阻止: {a.id} ({a.name}) 与 {b.id} ({b.name})"
)
continue
_merge_attribute(a, b) _merge_attribute(a, b)
# ID 重定向 # ID 重定向
try: try:
@@ -934,6 +1107,9 @@ async def deduplicate_entities_and_edges(
返回:去重后的实体、语句→实体边、实体↔实体边。 返回:去重后的实体、语句→实体边、实体↔实体边。
""" """
local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化保留为了之后对于LLM决策追溯 local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化保留为了之后对于LLM决策追溯
# 0) 标准化用户和AI助手实体名称确保多轮对话中的变体名称统一
_normalize_special_entity_names(entity_nodes)
# 1) 精确匹配 # 1) 精确匹配
deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes) deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes)

View File

@@ -15,6 +15,7 @@ from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges, deduplicate_entities_and_edges,
clean_cross_role_aliases,
) )
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import ( from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
second_layer_dedup_and_merge_with_neo4j, second_layer_dedup_and_merge_with_neo4j,
@@ -100,6 +101,10 @@ async def dedup_layers_and_merge_and_return(
except Exception as e: except Exception as e:
print(f"Second-layer dedup failed: {e}") print(f"Second-layer dedup failed: {e}")
# 第二层去重后,清洗用户/AI助手之间的别名交叉污染
# 第二层从 Neo4j 合并了旧实体,可能带入历史脏数据
clean_cross_role_aliases(fused_entity_nodes)
return ( return (
dialogue_nodes, dialogue_nodes,
chunk_nodes, chunk_nodes,

View File

@@ -44,6 +44,10 @@ from app.core.memory.models.variate_config import (
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import ( from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return, dedup_layers_and_merge_and_return,
) )
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
_USER_PLACEHOLDER_NAMES,
fetch_neo4j_assistant_aliases,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import ( from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation, embedding_generation,
generate_entity_embeddings_from_triplets, generate_entity_embeddings_from_triplets,
@@ -1341,14 +1345,20 @@ class ExtractionOrchestrator:
dialog_data_list: List[DialogData] dialog_data_list: List[DialogData]
) -> None: ) -> None:
""" """
从 Neo4j 读取用户实体的最终 aliases同步到 end_user 和 end_user_info 表 将本轮提取的用户别名同步到 end_user 和 end_user_info 表
注意: 注意:此方法在 Neo4j 写入之前调用,因此不能依赖 Neo4j 作为别名的权威数据源。
1. other_name 使用本次对话提取的第一个别名(保持时间顺序) 改为直接使用内存中去重后的 entity_nodes 的 aliases与 PgSQL 已有的 aliases 合并。
2. aliases 从 Neo4j 读取(保持完整性)
策略:
1. 从内存中的 entity_nodes 提取本轮用户别名current_aliases
2. 从去重后的 entity_nodes 中提取完整别名(含 Neo4j 二层去重合并的历史别名)
3. 从 PgSQL end_user_info 读取已有的 aliasesdb_aliases
4. 合并 db_aliases + deduped_aliases + current_aliases去重保序
5. 写回 PgSQL
Args: Args:
entity_nodes: 实体节点列表 entity_nodes: 去重后的实体节点列表(内存中,含二层去重合并结果)
dialog_data_list: 对话数据列表 dialog_data_list: 对话数据列表
""" """
try: try:
@@ -1361,23 +1371,40 @@ class ExtractionOrchestrator:
logger.warning("end_user_id 为空,跳过用户别名同步") logger.warning("end_user_id 为空,跳过用户别名同步")
return return
# 1. 提取本对话的用户别名(保持 LLM 提取的原始顺序,不排序) # 1. 提取本对话的用户别名(保持 LLM 提取的原始顺序,不排序)
current_aliases = self._extract_current_aliases(entity_nodes) current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list)
# 2. 从 Neo4j 获取完整 aliases权威数据源 # 1.5 从去重后的 entity_nodes 中提取完整别名
neo4j_aliases = await self._fetch_neo4j_user_aliases(end_user_id) # 二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 中,
# 这里提取出来确保 PgSQL 与 Neo4j 的别名保持同步
deduped_aliases = self._extract_deduped_entity_aliases(entity_nodes)
if not neo4j_aliases: # 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源
# Neo4j 中没有别名,使用本次对话提取的别名 # (防止 LLM 未提取出 AI 助手实体时AI 别名泄漏到用户别名中)
neo4j_aliases = current_aliases neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id)
if not neo4j_aliases: if neo4j_assistant_aliases:
logger.debug(f"aliases 为空,跳过同步: end_user_id={end_user_id}") before_count = len(current_aliases)
return current_aliases = [
a for a in current_aliases
if a.strip().lower() not in neo4j_assistant_aliases
]
if len(current_aliases) < before_count:
logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名")
# 同样过滤 deduped_aliases
deduped_aliases = [
a for a in deduped_aliases
if a.strip().lower() not in neo4j_assistant_aliases
]
logger.info(f"本次对话提取的 aliases: {current_aliases}") if not current_aliases and not deduped_aliases:
logger.info(f"Neo4j 中的完整 aliases: {neo4j_aliases}") logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}")
return
# 3. 同步到数据库 logger.info(f"本轮对话提取的 aliases: {current_aliases}")
if deduped_aliases:
logger.info(f"去重后实体的完整 aliases含历史: {deduped_aliases}")
# 2. 同步到数据库
end_user_uuid = uuid.UUID(end_user_id) end_user_uuid = uuid.UUID(end_user_id)
with get_db_context() as db: with get_db_context() as db:
# 更新 end_user 表 # 更新 end_user 表
@@ -1386,7 +1413,38 @@ class ExtractionOrchestrator:
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录") logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
return return
new_name = self._resolve_other_name(end_user.other_name, current_aliases, neo4j_aliases) # 3. 从 PgSQL 读取已有 aliases 并与本轮合并
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
db_aliases = (info.aliases if info and info.aliases else [])
# 过滤掉占位名称
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
# 合并:已有 + 去重后完整别名 + 本轮新增,去重保序
merged_aliases = list(db_aliases)
seen_lower = {a.strip().lower() for a in merged_aliases}
# 先合并去重后实体的完整别名(含 Neo4j 历史别名)
for alias in deduped_aliases:
if alias.strip().lower() not in seen_lower:
merged_aliases.append(alias)
seen_lower.add(alias.strip().lower())
# 再合并本轮新提取的别名
for alias in current_aliases:
if alias.strip().lower() not in seen_lower:
merged_aliases.append(alias)
seen_lower.add(alias.strip().lower())
# 最终过滤:从合并结果中排除 AI 助手别名(清理历史脏数据)
if neo4j_assistant_aliases:
merged_aliases = [
a for a in merged_aliases
if a.strip().lower() not in neo4j_assistant_aliases
]
logger.info(f"PgSQL 已有 aliases: {db_aliases}")
logger.info(f"合并后 aliases: {merged_aliases}")
# 更新 end_user 表 other_name
new_name = self._resolve_other_name(end_user.other_name, current_aliases, merged_aliases)
if new_name is not None: if new_name is not None:
end_user.other_name = new_name end_user.other_name = new_name
logger.info(f"更新 end_user 表 other_name → {new_name}") logger.info(f"更新 end_user 表 other_name → {new_name}")
@@ -1394,26 +1452,27 @@ class ExtractionOrchestrator:
logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}") logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}")
# 更新或创建 end_user_info 记录 # 更新或创建 end_user_info 记录
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
if info: if info:
new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases) new_name_info = self._resolve_other_name(info.other_name, current_aliases, merged_aliases)
if new_name_info is not None: if new_name_info is not None:
info.other_name = new_name_info info.other_name = new_name_info
logger.info(f"更新 end_user_info 表 other_name → {new_name_info}") logger.info(f"更新 end_user_info 表 other_name → {new_name_info}")
if info.aliases != neo4j_aliases: if info.aliases != merged_aliases:
info.aliases = neo4j_aliases info.aliases = merged_aliases
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}") logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}")
else: else:
first_alias = current_aliases[0].strip() if current_aliases else "" first_alias = current_aliases[0].strip() if current_aliases else (
deduped_aliases[0].strip() if deduped_aliases else ""
)
# 确保 first_alias 不是占位名称 # 确保 first_alias 不是占位名称
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES: if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
db.add(EndUserInfo( db.add(EndUserInfo(
end_user_id=end_user_uuid, end_user_id=end_user_uuid,
other_name=first_alias, other_name=first_alias,
aliases=neo4j_aliases, aliases=merged_aliases,
meta_data={} meta_data={}
)) ))
logger.info(f"创建 end_user_info 记录other_name={first_alias}, aliases={neo4j_aliases}") logger.info(f"创建 end_user_info 记录other_name={first_alias}, aliases={merged_aliases}")
db.commit() db.commit()
@@ -1423,49 +1482,81 @@ class ExtractionOrchestrator:
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中 # 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
USER_PLACEHOLDER_NAMES = {'用户', '', 'User', 'I'} # 复用 deduped_and_disamb 模块级常量,避免重复维护
USER_PLACEHOLDER_NAMES = _USER_PLACEHOLDER_NAMES
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]: def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode], dialog_data_list=None) -> List[str]:
"""实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序 """用户发言的原始实体中提取本轮新增别名(绕过去重污染
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户""""User""I")。 策略:
第一个别名将被用作 other_name。 仅从 dialog_data_list 中找到 speaker="user" 的 statement
从这些 statement 的 triplet_extraction_info 中提取用户实体的 aliases。
这样拿到的是 LLM 对用户原话的提取结果,不受去重合并的影响。
注意:不再使用去重后 entity_nodes 作为兜底,因为二层去重会将 Neo4j 历史别名
合并进来,导致历史别名被误认为"本轮提取"。历史别名的同步由
_extract_deduped_entity_aliases 负责。
Args: Args:
entity_nodes: 实体节点列表 entity_nodes: 去重后的实体节点列表(未使用,保留参数兼容性)
dialog_data_list: 对话数据列表
Returns: Returns:
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称 别名列表(保持原始顺序,已过滤)
"""
if not dialog_data_list:
return []
all_user_aliases = []
seen_lower = set()
for dialog in dialog_data_list:
for chunk in dialog.chunks:
speaker = getattr(chunk, 'speaker', None)
for statement in chunk.statements:
stmt_speaker = getattr(statement, 'speaker', None) or speaker
if stmt_speaker != "user":
continue
triplet_info = getattr(statement, 'triplet_extraction_info', None)
if not triplet_info:
continue
for entity in (triplet_info.entities or []):
ent_name = getattr(entity, 'name', '').strip()
if ent_name.lower() in self.USER_PLACEHOLDER_NAMES:
for alias in (getattr(entity, 'aliases', []) or []):
a = alias.strip()
if a and a.lower() not in self.USER_PLACEHOLDER_NAMES and a.lower() not in seen_lower:
all_user_aliases.append(a)
seen_lower.add(a.lower())
if all_user_aliases:
logger.debug(f"从用户原始发言提取到别名: {all_user_aliases}")
return all_user_aliases
def _extract_deduped_entity_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
"""从去重后的用户实体中提取完整别名列表。
二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 的用户实体中,
因此这里提取到的别名包含了历史积累的所有别名,可用于同步到 PgSQL。
Args:
entity_nodes: 去重后的实体节点列表(含二层去重合并结果)
Returns:
别名列表(已过滤占位名称,去重保序)
""" """
for entity in entity_nodes: for entity in entity_nodes:
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES: if getattr(entity, 'name', '').strip().lower() in self.USER_PLACEHOLDER_NAMES:
aliases = getattr(entity, 'aliases', []) or [] aliases = getattr(entity, 'aliases', []) or []
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name filtered = [
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES] a for a in aliases
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}") if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES
return filtered ]
if filtered:
return filtered
return [] return []
async def _fetch_neo4j_assistant_aliases(self, end_user_id: str) -> set:
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]: """从 Neo4j 查询 AI 助手实体的所有别名(用于从用户别名中排除)"""
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)""" return await fetch_neo4j_assistant_aliases(self.connector, end_user_id)
cypher = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '', 'User', 'I']
RETURN e.aliases AS aliases
LIMIT 1
"""
result = await Neo4jConnector().execute_query(cypher, end_user_id=end_user_id)
if not result:
logger.debug(f"Neo4j 中未找到用户实体: end_user_id={end_user_id}")
return []
aliases = result[0].get('aliases') or []
if not aliases:
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
return []
# 过滤掉占位名称,防止历史脏数据传播
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
return filtered
def _resolve_other_name( def _resolve_other_name(
self, self,
@@ -1484,16 +1575,16 @@ class ExtractionOrchestrator:
注意:返回值不允许是占位名称("用户""""User""I" 注意:返回值不允许是占位名称("用户""""User""I"
""" """
# 当前值为空或为占位名称时,需要更新 # 当前值为空或为占位名称时,需要更新
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES: if not current or not current.strip() or current.strip().lower() in self.USER_PLACEHOLDER_NAMES:
candidate = current_aliases[0].strip() if current_aliases else None candidate = current_aliases[0].strip() if current_aliases else None
# 确保候选值不是占位名称 # 确保候选值不是占位名称
if candidate and candidate in self.USER_PLACEHOLDER_NAMES: if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
return None return None
return candidate return candidate
if current not in neo4j_aliases: if current not in neo4j_aliases:
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
# 确保候选值不是占位名称 # 确保候选值不是占位名称
if candidate and candidate in self.USER_PLACEHOLDER_NAMES: if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
return None return None
return candidate return candidate

View File

@@ -61,6 +61,7 @@ class TripletExtractor:
predicate_instructions=PREDICATE_DEFINITIONS, predicate_instructions=PREDICATE_DEFINITIONS,
language=self._get_language(), language=self._get_language(),
ontology_types=self.ontology_types, ontology_types=self.ontology_types,
speaker=getattr(statement, 'speaker', None),
) )
# Create messages for LLM # Create messages for LLM

View File

@@ -1,6 +1,6 @@
import os import os
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering
# Setup Jinja2 environment # Setup Jinja2 environment
@@ -205,6 +205,7 @@ async def render_triplet_extraction_prompt(
predicate_instructions: dict = None, predicate_instructions: dict = None,
language: str = "zh", language: str = "zh",
ontology_types: "OntologyTypeList | None" = None, ontology_types: "OntologyTypeList | None" = None,
speaker: str = None,
) -> str: ) -> str:
""" """
Renders the triplet extraction prompt using the extract_triplet.jinja2 template. Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
@@ -216,6 +217,7 @@ async def render_triplet_extraction_prompt(
predicate_instructions: Optional predicate instructions predicate_instructions: Optional predicate instructions
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English) language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification
speaker: Speaker role ("user" or "assistant") for the current statement
Returns: Returns:
Rendered prompt content as string Rendered prompt content as string
@@ -223,7 +225,7 @@ async def render_triplet_extraction_prompt(
template = prompt_env.get_template("extract_triplet.jinja2") template = prompt_env.get_template("extract_triplet.jinja2")
# 准备本体类型数据 # 准备本体类型数据
ontology_type_section = "" ontology_type_section = None
ontology_type_names = [] ontology_type_names = []
type_hierarchy_hints = [] type_hierarchy_hints = []
if ontology_types and ontology_types.types: if ontology_types and ontology_types.types:
@@ -240,6 +242,7 @@ async def render_triplet_extraction_prompt(
ontology_types=ontology_type_section, ontology_types=ontology_type_section,
ontology_type_names=ontology_type_names, ontology_type_names=ontology_type_names,
type_hierarchy_hints=type_hierarchy_hints, type_hierarchy_hints=type_hierarchy_hints,
speaker=speaker,
) )
# 记录渲染结果到提示日志(与示例日志结构一致) # 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('triplet extraction', rendered_prompt) log_prompt_rendering('triplet extraction', rendered_prompt)

View File

@@ -23,6 +23,16 @@ Extract entities and knowledge triplets from the given statement.
===Inputs=== ===Inputs===
**Chunk Content:** "{{ chunk_content }}" **Chunk Content:** "{{ chunk_content }}"
**Statement:** "{{ statement }}" **Statement:** "{{ statement }}"
{% if speaker %}
**Speaker:** {{ speaker }}
{% if speaker == "assistant" %}
{% if language == "zh" %}
⚠️ 当前陈述句来自 **AI助手的回复**。AI助手在回复中用来称呼用户的名字是**用户的别名**,不是 AI 助手的别名。但只能提取原文中逐字出现的名字,严禁推测或创造原文中不存在的别名变体。
{% else %}
⚠️ This statement is from the **AI assistant's reply**. Names the AI uses to address the user are **user's aliases**, NOT the AI assistant's aliases. But only extract names that appear VERBATIM in the text — never infer or fabricate alias variants.
{% endif %}
{% endif %}
{% endif %}
{% if ontology_types %} {% if ontology_types %}
===Ontology Type Guidance=== ===Ontology Type Guidance===
@@ -87,7 +97,17 @@ Extract entities and knowledge triplets from the given statement.
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name * "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name
* "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name * "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name
- 空值:如果没有别名,使用 `[]` - 空值:如果没有别名,使用 `[]`
- 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字 - **🚨🚨🚨 严禁幻觉:只提取对话原文中逐字出现的别名,绝对不能推测、衍生或创造任何未在原文中出现的名字。例如,看到"陈思远"不能自行添加"思远大人""远哥""小远"等变体。如果原文没有这些字,就不能出现在 aliases 中。**
- **🚨 归属区分:必须严格区分名称的归属对象。默认情况下,用户提到的名字归属用户实体。只有出现明确的第二人称命名表达(如"叫你""给你取名")时,才将名字归属 AI/助手实体。**
- **🚨 说话人视角:当 speaker 为 assistant 时AI 助手用来称呼用户的名字是用户的别名,必须归入用户实体的 aliases绝对不能归入 AI 助手实体。但同样只能提取原文中逐字出现的称呼,不能推测。**
* "我叫陈思远我给AI取名为远仔" → 用户 aliases=["陈思远"]AI助手 aliases=["远仔"]
* "我叫vv" → 用户 aliases=["vv"]没有给AI取名的表达名字归用户
* [speaker=assistant] "好的VV" → 用户 aliases=["VV"]AI 在称呼用户,原文中出现了"VV"
* [speaker=assistant] "我叫陈仔" → AI助手 aliases=["陈仔"]AI 在自我介绍,这是 AI 的别名)
* ❌ 错误:将"远仔"放入用户的 aliases"远仔"是给AI取的名字不是用户的名字
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases
* ❌ 错误AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
{% else %} {% else %}
- Include: nicknames, full names, abbreviations, alternative names - Include: nicknames, full names, abbreviations, alternative names
- Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST** - Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST**
@@ -96,7 +116,17 @@ Extract entities and knowledge triplets from the given statement.
* "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name) * "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name)
* "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name) * "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
- Empty: If no aliases, use `[]` - Empty: If no aliases, use `[]`
- Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names - **🚨🚨🚨 NO HALLUCINATION: Only extract aliases that appear VERBATIM in the original text. NEVER infer, derive, or fabricate names not present in the text. For example, seeing "John Smith" does NOT allow adding "Johnny", "Smithy", "Mr. Smith" unless those exact strings appear in the conversation.**
- **🚨 Ownership distinction: By default, all names mentioned by the user belong to the user entity. Only assign a name to the AI/assistant entity when an explicit second-person naming expression (e.g., "I'll call you", "your name is") is present.**
- **🚨 Speaker perspective: When speaker is "assistant", names the AI uses to address the user are the USER's aliases and MUST go into the user entity's aliases, NEVER into the AI assistant entity's aliases. But only extract names that appear verbatim in the text, never infer.**
* "I'm Alex, I'll call you Buddy" → User aliases=["Alex"], AI assistant aliases=["Buddy"]
* "I'm vv" → User aliases=["vv"] (no AI-naming expression, name belongs to user)
* [speaker=assistant] "Sure thing, VV" → User aliases=["VV"] (AI addressing the user, "VV" appears in text)
* [speaker=assistant] "I'm Jarvis" → AI assistant aliases=["Jarvis"] (AI self-introduction, this is AI's alias)
* ❌ Wrong: putting "Buddy" in user's aliases ("Buddy" is a name for the AI, not the user)
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
{% endif %} {% endif %}
@@ -122,7 +152,60 @@ Extract entities and knowledge triplets from the given statement.
4. **ALIASES ORDER:** 4. **AI/ASSISTANT ENTITY SPECIAL HANDLING:**
{% if language == "zh" %}
- **🚨 默认规则:如果对话中没有出现明确指向 AI/助手的命名表达,则所有名字都归属于用户实体。不要猜测或推断某个名字是给 AI 取的。**
- 只有当用户**明确**对 AI/助手进行命名时,才创建 AI/助手实体并将对应名字放入其 aliases
- AI/助手实体的 name 字段:使用 "AI助手"
- 用户给 AI 取的名字:放入 AI/助手实体的 aliases
- **🚨 禁止将用户给 AI 取的名字放入用户实体的 aliases 中**
- **必须出现以下明确的命名表达才能判定为给 AI 取名:**「给你取名」「叫你」「称呼你为」「给AI取名」「你的名字是」「以后叫你」「你就叫」「你不叫X了」「你现在叫」等**第二人称(你)或明确指向 AI 的命名句式**
- **🚨 "你不叫X了"/"你不叫X你叫Y" 句式X 和 Y 都是 AI 的名字(旧名和新名),绝对不是用户的名字。因为句子主语是"你"AI。**
- **以下情况名字归属用户,不是给 AI 取名:**「我叫」「我的名字是」「叫我」「我是」「大家叫我」「我的英文名是」「我的昵称是」等**第一人称(我)的自我介绍句式**
- **🚨 speaker=assistant 时的特殊规则:**
* AI 用来称呼用户的名字 → 归入**用户**实体的 aliases但必须是原文中逐字出现的称呼不能推测
* AI 自称的名字(如"我叫陈仔""我是你的助手")→ 归入**AI助手**实体的 aliases
* 判断依据AI 说"你叫X"或用 X 称呼用户 → X 是用户别名AI 说"我叫X"或"我是X" → X 是 AI 别名
- 示例:
* "我叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
* "我的英文名叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
* "我叫陈思远我给AI取名为远仔" → 用户实体: name="用户", aliases=["陈思远"]AI实体: name="AI助手", aliases=["远仔"]
* "叫你小助,我自己叫老王" → 用户实体: name="用户", aliases=["老王"]AI实体: name="AI助手", aliases=["小助"]
* "你不叫远仔了,你现在叫陈仔" → AI实体: name="AI助手", aliases=["陈仔"]"远仔"是AI旧名"陈仔"是AI新名都归AI。不要把"远仔"或"陈仔"放入用户的aliases
* [speaker=assistant] "好的VV今天想干点啥" → 用户实体: name="用户", aliases=["VV"]AI 在称呼用户,原文中出现了"VV"
* [speaker=assistant] "你叫陈思远,我叫陈仔" → 用户实体: name="用户", aliases=["陈思远"]AI实体: name="AI助手", aliases=["陈仔"]
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases没有任何给 AI 取名的表达)
* ❌ 错误AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
* ❌ 错误aliases=["陈思远", "远仔"]"远仔"是给AI取的名字不是用户的名字
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
{% else %}
- **🚨 Default rule: If there is NO explicit AI/assistant naming expression in the conversation, ALL names belong to the user entity. Do NOT guess or infer that a name is for the AI.**
- Only create an AI/assistant entity when the user **explicitly** names the AI/assistant
- AI/assistant entity name field: use "AI Assistant"
- Names the user gives to the AI: put in the AI/assistant entity's aliases
- **🚨 NEVER put names given to the AI into the user entity's aliases**
- **An AI-naming expression MUST be present to assign a name to the AI:** "I'll call you", "your name is", "I name you", "let me call you", "you'll be called", "you're not called X anymore", "your new name is", etc. — **second-person ("you") or explicit AI-directed naming patterns**
- **🚨 "You're not called X anymore" / "You're not X, you're Y" pattern: BOTH X and Y are AI's names (old and new). They are NOT user's names. The subject is "you" (the AI).**
- **These patterns mean the name belongs to the USER, NOT the AI:** "I'm", "my name is", "call me", "I am", "people call me", "my English name is", "my nickname is", etc. — **first-person ("I"/"me") self-introduction patterns**
- **🚨 Special rules when speaker=assistant:**
* Names the AI uses to address the user → belong to the **user** entity's aliases (but only extract names that appear verbatim in the text, never infer)
* Names the AI uses for itself (e.g., "I'm Jarvis", "I am your assistant") → belong to the **AI assistant** entity's aliases
* Rule: AI says "you are X" or calls user X → X is user's alias; AI says "I'm X" or "I am X" → X is AI's alias
- Examples:
* "I'm vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
* "My English name is vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
* "I'm Alex, I'll call you Buddy" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Buddy"]
* "Call yourself Jarvis, my name is Tony" → User entity: name="User", aliases=["Tony"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
* "You're not called Jarvis anymore, your new name is Friday" → AI entity: name="AI Assistant", aliases=["Friday"] (both "Jarvis" and "Friday" are AI names, NOT user names)
* [speaker=assistant] "Sure thing, VV" → User entity: name="User", aliases=["VV"] (AI addressing the user, "VV" appears in text)
* [speaker=assistant] "You're Alex, and I'm Jarvis" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases (no AI-naming expression exists)
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
* ❌ Wrong: aliases=["Alex", "Buddy"] ("Buddy" is a name for the AI, not the user)
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
{% endif %}
5. **ALIASES ORDER:**
{% if language == "zh" %} {% if language == "zh" %}
- 顺序优先级:按出现顺序,先出现的在前 - 顺序优先级:按出现顺序,先出现的在前
{% else %} {% else %}
@@ -202,8 +285,19 @@ Output:
{"entity_idx": 0, "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", "example": "", "aliases": ["Camera Tripod"], "is_explicit_memory": false} {"entity_idx": 0, "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", "example": "", "aliases": ["Camera Tripod"], "is_explicit_memory": false}
] ]
} }
**Example 4 (User vs AI alias distinction - English output):** "I'm Alex, and I'll call you Buddy"
Output:
{
"triplets": [
{"subject_name": "User", "subject_id": 0, "predicate": "NAMED", "object_name": "AI Assistant", "object_id": 1, "value": "Buddy"}
],
"entities": [
{"entity_idx": 0, "name": "User", "type": "Person", "description": "The user", "example": "", "aliases": ["Alex"], "is_explicit_memory": false},
{"entity_idx": 1, "name": "AI Assistant", "type": "Person", "description": "The user's AI assistant", "example": "", "aliases": ["Buddy"], "is_explicit_memory": false}
]
}
{% else %} {% else %}
**Example 1 (English input → Chinese output):** "I plan to travel to Paris next week and visit the Louvre."
Output: Output:
{ {
"triplets": [ "triplets": [
@@ -258,6 +352,39 @@ Output:
] ]
} }
**Example 6 (用户与AI别名区分 - Chinese):** "我称呼自己为陈思远我给AI取名为远仔"
Output:
{
"triplets": [
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "远仔"}
],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远"], "is_explicit_memory": false},
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["远仔"], "is_explicit_memory": false}
]
}
**Example 7 (纯用户自我介绍无AI命名 - Chinese):** "我叫vv"
Output:
{
"triplets": [],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["vv"], "is_explicit_memory": false}
]
}
**Example 8 (给AI改名 - Chinese):** "你不叫远仔了,你现在叫陈仔"
Output:
{
"triplets": [
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "陈仔"}
],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "is_explicit_memory": false},
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["陈仔"], "is_explicit_memory": false}
]
}
{% endif %} {% endif %}
===End of Examples=== ===End of Examples===

View File

@@ -25,8 +25,34 @@ class RedBearEmbeddings(Embeddings):
def _create_model(self, config: RedBearModelConfig) -> Embeddings: def _create_model(self, config: RedBearModelConfig) -> Embeddings:
"""根据配置创建 LangChain 模型""" """根据配置创建 LangChain 模型"""
embedding_class = get_provider_embedding_class(config.provider) embedding_class = get_provider_embedding_class(config.provider)
model_params = RedBearModelFactory.get_model_params(config) provider = config.provider.lower()
return embedding_class(**model_params) # Embedding models only need connection params, never LLM-specific ones
# (e.g. enable_thinking, model_kwargs) — build params directly.
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
import httpx
params = {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
"timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
"max_retries": config.max_retries,
}
elif provider == ModelProvider.DASHSCOPE:
params = {
"model": config.model_name,
"dashscope_api_key": config.api_key,
"max_retries": config.max_retries,
}
elif provider == ModelProvider.OLLAMA:
params = {
"model": config.model_name,
"base_url": config.base_url,
}
elif provider == ModelProvider.BEDROCK:
params = RedBearModelFactory.get_model_params(config)
else:
params = RedBearModelFactory.get_model_params(config)
return embedding_class(**params)
def _create_volcano_client(self, config: RedBearModelConfig): def _create_volcano_client(self, config: RedBearModelConfig):
"""创建火山引擎客户端""" """创建火山引擎客户端"""

View File

@@ -6,14 +6,28 @@ ChatOpenAI 在解析流式 SSE 时只取 delta.content会丢弃 delta.reasoni
""" """
from __future__ import annotations from __future__ import annotations
from typing import Any, Optional from typing import Any, Optional, Union
from langchain_core.outputs import ChatGenerationChunk from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
class VolcanoChatOpenAI(ChatOpenAI): class VolcanoChatOpenAI(ChatOpenAI):
"""火山引擎 Chat 模型支持深度思考内容reasoning_content的流式透传。""" """火山引擎 Chat 模型支持深度思考内容reasoning_content的流式和非流式透传。"""
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
result = super()._create_chat_result(response, generation_info)
# 将非流式响应中的 reasoning_content 补入 additional_kwargs
choices = response.choices if hasattr(response, "choices") else response.get("choices", [])
if choices:
message = choices[0].message if hasattr(choices[0], "message") else choices[0].get("message", {})
reasoning = (
getattr(message, "reasoning_content", None)
or (message.get("reasoning_content") if isinstance(message, dict) else None)
)
if reasoning and result.generations:
result.generations[0].message.additional_kwargs["reasoning_content"] = reasoning
return result
def _convert_chunk_to_generation_chunk( def _convert_chunk_to_generation_chunk(
self, self,

View File

@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING, type=ParameterType.STRING,
description="操作类型", description="操作类型",
required=True, required=True,
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"] enum=["format", "convert_timezone", "timestamp_to_datetime", "now", "datetime_to_timestamp"]
), ),
ToolParameter( ToolParameter(
name="input_value", name="input_value",

View File

@@ -32,13 +32,16 @@ from app.core.workflow.nodes.configs import (
NoteNodeConfig, NoteNodeConfig,
ParameterExtractorNodeConfig, ParameterExtractorNodeConfig,
QuestionClassifierNodeConfig, QuestionClassifierNodeConfig,
VariableAggregatorNodeConfig VariableAggregatorNodeConfig,
ListOperatorNodeConfig,
DocExtractorNodeConfig,
) )
from app.core.workflow.nodes.cycle_graph.config import ( from app.core.workflow.nodes.cycle_graph.config import (
ConditionDetail as LoopConditionDetail, ConditionDetail as LoopConditionDetail,
ConditionsConfig, ConditionsConfig,
CycleVariable CycleVariable
) )
from app.core.workflow.nodes.list_operator.config import FilterCondition
from app.core.workflow.nodes.enums import ( from app.core.workflow.nodes.enums import (
ValueInputType, ValueInputType,
ComparisonOperator, ComparisonOperator,
@@ -90,6 +93,8 @@ class DifyConverter(BaseConverter):
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config, NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
NodeType.TOOL: self.convert_tool_node_config, NodeType.TOOL: self.convert_tool_node_config,
NodeType.NOTES: self.convert_notes_config, NodeType.NOTES: self.convert_notes_config,
NodeType.LIST_OPERATOR: self.convert_list_operator_node_config,
NodeType.DOCUMENT_EXTRACTOR: self.convert_document_extractor_node_config,
NodeType.CYCLE_START: lambda x: {}, NodeType.CYCLE_START: lambda x: {},
NodeType.BREAK: lambda x: {}, NodeType.BREAK: lambda x: {},
} }
@@ -213,7 +218,9 @@ class DifyConverter(BaseConverter):
"end with": ComparisonOperator.END_WITH, "end with": ComparisonOperator.END_WITH,
"not contains": ComparisonOperator.NOT_CONTAINS, "not contains": ComparisonOperator.NOT_CONTAINS,
"exists": ComparisonOperator.NOT_EMPTY, "exists": ComparisonOperator.NOT_EMPTY,
"not exists": ComparisonOperator.EMPTY "not exists": ComparisonOperator.EMPTY,
"in": ComparisonOperator.IN,
"not in": ComparisonOperator.NOT_IN,
} }
return operator_map.get(operator, operator) return operator_map.get(operator, operator)
@@ -771,3 +778,46 @@ class DifyConverter(BaseConverter):
show_author=node_data.get("showAuthor", True) show_author=node_data.get("showAuthor", True)
).model_dump() ).model_dump()
return result return result
def convert_list_operator_node_config(self, node: dict) -> dict:
"""Dify list-operator — convert variable path array to {{ }} selector format."""
node_data = node["data"]
variable_path = node_data.get("variable", [])
input_list = self._process_list_variable_literal(variable_path) or ""
filter_by = node_data.get("filter_by", {"enabled": False, "conditions": []})
# Convert each condition's comparison_operator from Dify format to native
if filter_by.get("conditions"):
converted_conditions = []
for cond in filter_by["conditions"]:
converted_conditions.append({
**cond,
"comparison_operator": self.convert_compare_operator(
cond.get("comparison_operator", "")
)
})
filter_by = {**filter_by, "conditions": converted_conditions}
result = {
"input_list": input_list,
"filter_by": filter_by,
"order_by": node_data.get("order_by", {"enabled": False, "key": "", "value": "asc"}),
"limit": node_data.get("limit", {"enabled": False, "size": -1}),
"extract_by": node_data.get("extract_by", {"enabled": False, "serial": "1"}),
}
self.config_validate(node["id"], node["data"]["title"], ListOperatorNodeConfig, result)
return result
def convert_document_extractor_node_config(self, node: dict) -> dict:
"""Convert Dify document-extractor node to MemoryBear DocExtractorNodeConfig.
Dify document-extractor data fields:
variable_selector: list[str] - file variable path
"""
node_data = node["data"]
file_selector = self._process_list_variable_literal(
node_data.get("variable_selector", [])
) or ""
result = DocExtractorNodeConfig.model_construct(
file_selector=file_selector,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], DocExtractorNodeConfig, result)
return result

View File

@@ -45,6 +45,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"question-classifier": NodeType.QUESTION_CLASSIFIER, "question-classifier": NodeType.QUESTION_CLASSIFIER,
"variable-aggregator": NodeType.VAR_AGGREGATOR, "variable-aggregator": NodeType.VAR_AGGREGATOR,
"tool": NodeType.TOOL, "tool": NodeType.TOOL,
"list-operator": NodeType.LIST_OPERATOR,
"document-extractor": NodeType.DOCUMENT_EXTRACTOR,
"": NodeType.NOTES "": NodeType.NOTES
} }

View File

@@ -22,6 +22,8 @@ from app.core.workflow.nodes.configs import (
MemoryReadNodeConfig, MemoryReadNodeConfig,
MemoryWriteNodeConfig, MemoryWriteNodeConfig,
NoteNodeConfig, NoteNodeConfig,
ListOperatorNodeConfig,
DocExtractorNodeConfig,
) )
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
@@ -51,6 +53,8 @@ class MemoryBearConverter(BaseConverter):
NodeType.MEMORY_READ: MemoryReadNodeConfig, NodeType.MEMORY_READ: MemoryReadNodeConfig,
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig, NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
NodeType.NOTES: NoteNodeConfig, NodeType.NOTES: NoteNodeConfig,
NodeType.LIST_OPERATOR: ListOperatorNodeConfig,
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNodeConfig,
} }
@staticmethod @staticmethod

View File

@@ -59,6 +59,9 @@ class WorkflowResultBuilder:
conversation_vars = variable_pool.get_all_conversation_vars() conversation_vars = variable_pool.get_all_conversation_vars()
sys_vars = variable_pool.get_all_system_vars() sys_vars = variable_pool.get_all_system_vars()
# 汇总所有 knowledge 节点的 citations
citations = self.aggregate_citations(node_outputs)
return { return {
"status": "completed" if success else "failed", "status": "completed" if success else "failed",
"output": final_output, "output": final_output,
@@ -71,9 +74,25 @@ class WorkflowResultBuilder:
"conversation_id": execution_context.conversation_id, "conversation_id": execution_context.conversation_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"token_usage": token_usage, "token_usage": token_usage,
"citations": citations,
"error": result.get("error"), "error": result.get("error"),
} }
@staticmethod
def aggregate_citations(node_outputs: dict) -> list:
"""从所有 knowledge 节点的输出中汇总 citations去重"""
seen = set()
citations = []
for node_output in node_outputs.values():
if not isinstance(node_output, dict):
continue
for c in node_output.get("citations", []):
key = c.get("document_id")
if key and key not in seen:
seen.add(key)
citations.append(c)
return citations
@staticmethod @staticmethod
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None: def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
""" """

View File

@@ -318,7 +318,7 @@ class VariablePool:
namespace: str, namespace: str,
key: str, key: str,
value: Any, value: Any,
var_type: VariableType, var_type: VariableType | None,
mut: bool mut: bool
): ):
if self.has(f"{namespace}.{key}"): if self.has(f"{namespace}.{key}"):
@@ -493,6 +493,23 @@ class VariablePoolInitializer:
var_value = var_default var_value = var_default
else: else:
var_value = DEFAULT_VALUE(var_type) var_value = DEFAULT_VALUE(var_type)
# Convert FileInput-format dicts to full FileObject dicts
if var_type == VariableType.FILE:
if not var_value:
continue
var_value = await self._resolve_file_default(var_value)
if not var_value:
continue
elif var_type == VariableType.ARRAY_FILE:
if not var_value:
var_value = []
else:
resolved = []
for item in var_value:
f = await self._resolve_file_default(item)
if f:
resolved.append(f)
var_value = resolved
await variable_pool.new( await variable_pool.new(
namespace="conv", namespace="conv",
key=var_name, key=var_name,
@@ -501,6 +518,17 @@ class VariablePoolInitializer:
mut=True mut=True
) )
@staticmethod
async def _resolve_file_default(file_def: dict) -> dict | None:
"""Accept only already-resolved FileObject dicts (is_file=True).
FileInput-format dicts are converted at save time by WorkflowService._resolve_variables_file_defaults.
"""
if not isinstance(file_def, dict):
return None
if file_def.get("is_file"):
return file_def
return None
@staticmethod @staticmethod
async def _init_system_vars( async def _init_system_vars(
variable_pool: VariablePool, variable_pool: VariablePool,

View File

@@ -395,7 +395,8 @@ class BaseNode(ABC):
"output": output, "output": output,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"token_usage": token_usage, "token_usage": token_usage,
"error": None "error": None,
**self._extract_extra_fields(business_result),
} }
final_output = { final_output = {
"node_outputs": {self.node_id: node_output}, "node_outputs": {self.node_id: node_output},
@@ -498,6 +499,13 @@ class BaseNode(ABC):
# Default implementation returns the business result directly # Default implementation returns the business result directly
return business_result return business_result
def _extract_extra_fields(self, business_result: Any) -> dict:
"""Extracts extra fields to merge into node_output (e.g. citations).
Subclasses may override to inject additional metadata.
"""
return {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""Extracts token usage information from the business result. """Extracts token usage information from the business result.

View File

@@ -13,7 +13,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import BaseNode from app.core.workflow.nodes import BaseNode
from app.core.workflow.nodes.code.config import CodeNodeConfig from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -70,7 +70,8 @@ class CodeNode(BaseNode):
for output in self.typed_config.output_variables: for output in self.typed_config.output_variables:
value = exec_result.get(output.name) value = exec_result.get(output.name)
if value is None: if value is None:
raise RuntimeError(f"Return value {output.name} does not exist") result[output.name] = DEFAULT_VALUE(output.type)
continue
match output.type: match output.type:
case VariableType.STRING: case VariableType.STRING:
if not isinstance(value, str): if not isinstance(value, str):

View File

@@ -24,6 +24,8 @@ from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.notes.config import NoteNodeConfig from app.core.workflow.nodes.notes.config import NoteNodeConfig
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
__all__ = [ __all__ = [
# 基础类 # 基础类
@@ -49,5 +51,7 @@ __all__ = [
"MemoryReadNodeConfig", "MemoryReadNodeConfig",
"MemoryWriteNodeConfig", "MemoryWriteNodeConfig",
"CodeNodeConfig", "CodeNodeConfig",
"NoteNodeConfig" "NoteNodeConfig",
"ListOperatorNodeConfig",
"DocExtractorNodeConfig",
] ]

View File

@@ -14,12 +14,22 @@ logger = logging.getLogger(__name__)
def _file_object_to_file_input(f: FileObject) -> FileInput: def _file_object_to_file_input(f: FileObject) -> FileInput:
"""Convert workflow FileObject to multimodal FileInput.""" """Convert workflow FileObject to multimodal FileInput."""
file_type = f.origin_file_type or ""
# Prefer mime_type for more accurate type detection
if not file_type and f.mime_type:
file_type = f.mime_type
resolved_type = FileType.trans(f.type) if isinstance(f.type, str) else f.type
if resolved_type != FileType.DOCUMENT:
raise ValueError(
f"Document extractor only supports document files, got type '{f.type}' "
f"(name={f.name or f.file_id or f.url})"
)
return FileInput( return FileInput(
type=FileType.DOCUMENT, type=resolved_type,
transfer_method=TransferMethod(f.transfer_method), transfer_method=TransferMethod(f.transfer_method),
url=f.url or None, url=f.url or None,
upload_file_id=f.file_id or None, upload_file_id=f.file_id or None,
file_type=f.origin_file_type or "", file_type=file_type,
) )
@@ -81,6 +91,7 @@ class DocExtractorNode(BaseNode):
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
svc = MultimodalService(db) svc = MultimodalService(db)
for f in files: for f in files:
label = f.name or f.url or f.file_id
try: try:
file_input = _file_object_to_file_input(f) file_input = _file_object_to_file_input(f)
# Ensure URL is populated for local files # Ensure URL is populated for local files
@@ -93,7 +104,7 @@ class DocExtractorNode(BaseNode):
chunks.append(text) chunks.append(text)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}", f"Node {self.node_id}: failed to extract file '{label}': {e}",
exc_info=True, exc_info=True,
) )
chunks.append("") chunks.append("")

View File

@@ -24,6 +24,7 @@ class NodeType(StrEnum):
MEMORY_READ = "memory-read" MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write" MEMORY_WRITE = "memory-write"
DOCUMENT_EXTRACTOR = "document-extractor" DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
UNKNOWN = "unknown" UNKNOWN = "unknown"
NOTES = "notes" NOTES = "notes"
@@ -45,6 +46,8 @@ class ComparisonOperator(StrEnum):
LE = "le" LE = "le"
GT = "gt" GT = "gt"
GE = "ge" GE = "ge"
IN = "in"
NOT_IN = "not_in"
class LogicOperator(StrEnum): class LogicOperator(StrEnum):

View File

@@ -34,6 +34,20 @@ class KnowledgeRetrievalNode(BaseNode):
"output": VariableType.ARRAY_STRING "output": VariableType.ARRAY_STRING
} }
def _extract_output(self, business_result: Any) -> Any:
"""下游节点只拿 chunks 列表"""
if isinstance(business_result, dict) and "chunks" in business_result:
return business_result["chunks"]
return business_result
def _extract_citations(self, business_result: Any) -> list:
if isinstance(business_result, dict):
return business_result.get("citations", [])
return []
def _extract_extra_fields(self, business_result: Any) -> dict:
return {"citations": self._extract_citations(business_result)}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return { return {
"query": self._render_template(self.typed_config.query, variable_pool), "query": self._render_template(self.typed_config.query, variable_pool),
@@ -314,4 +328,20 @@ class KnowledgeRetrievalNode(BaseNode):
logger.info( logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
) )
return [chunk.page_content for chunk in final_rs] citations = []
seen_doc_ids = set()
for chunk in final_rs:
meta = chunk.metadata or {}
doc_id = meta.get("document_id") or meta.get("doc_id")
if doc_id and doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
citations.append({
"document_id": str(doc_id),
"file_name": meta.get("file_name", ""),
"knowledge_id": str(meta.get("knowledge_id", kb_config.kb_id)),
"score": meta.get("score", 0.0),
})
return {
"chunks": [chunk.page_content for chunk in final_rs],
"citations": citations,
}

View File

@@ -0,0 +1,3 @@
from .node import ListOperatorNode
__all__ = ["ListOperatorNode"]

View File

@@ -0,0 +1,44 @@
from typing import Any
from pydantic import BaseModel, Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator
class FilterCondition(BaseModel):
key: str = ""
comparison_operator: ComparisonOperator = ComparisonOperator.CONTAINS
value: str | list[str] | bool = ""
class FilterBy(BaseModel):
enabled: bool = False
conditions: list[FilterCondition] = Field(default_factory=list)
class OrderByConfig(BaseModel):
enabled: bool = False
key: str = ""
value: str = "asc" # "asc" | "desc"
class Limit(BaseModel):
enabled: bool = False
size: int = -1
class ExtractConfig(BaseModel):
enabled: bool = False
serial: str = "1" # 1-based index string, e.g. "1" = first
class ListOperatorNodeConfig(BaseNodeConfig):
"""
List Operator node config.
Operation order: filter -> extract -> order -> limit
"""
input_list: str = Field(..., description="Variable selector, e.g. {{ sys.files }} or {{ conv.uploaded_files }}")
filter_by: FilterBy = Field(default_factory=FilterBy)
order_by: OrderByConfig = Field(default_factory=OrderByConfig)
limit: Limit = Field(default_factory=Limit)
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)

View File

@@ -0,0 +1,143 @@
import logging
from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig, FilterCondition
from app.core.workflow.variable.base_variable import VariableType
logger = logging.getLogger(__name__)
# File object fields that hold string values
_FILE_STRING_KEYS = {"name", "extension", "mime_type", "url", "transfer_method", "origin_file_type", "file_id"}
_FILE_NUMBER_KEYS = {"size"}
class ListOperatorNode(BaseNode):
def __init__(self, node_config: dict, workflow_config: dict, down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ListOperatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
"result": VariableType.ANY,
"first_record": VariableType.ANY,
"last_record": VariableType.ANY,
}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = ListOperatorNodeConfig(**self.config)
cfg = self.typed_config
# Resolve input variable from path selector
items: list = self.get_variable(cfg.input_list, variable_pool)
if not isinstance(items, list):
raise TypeError(f"Variable '{cfg.input_list}' must be an array, got {type(items)}")
result = list(items)
# 1. Filter
if cfg.filter_by.enabled and cfg.filter_by.conditions:
for condition in cfg.filter_by.conditions:
result = [item for item in result if self._match_condition(item, condition, variable_pool)]
# 2. Extract (take single item by 1-based serial index)
if cfg.extract_by.enabled:
serial_str = self._resolve_value(cfg.extract_by.serial, variable_pool)
idx = int(serial_str) - 1
if idx < 0 or idx >= len(result):
raise ValueError(f"extract_by.serial={cfg.extract_by.serial} out of range (list length={len(result)})")
result = [result[idx]]
# 3. Order
if cfg.order_by.enabled and cfg.order_by.key:
reverse = cfg.order_by.value == "desc"
key_fn = self._make_sort_key(cfg.order_by.key)
result = sorted(result, key=key_fn, reverse=reverse)
# 4. Limit (take first N)
if cfg.limit.enabled and cfg.limit.size > 0:
result = result[:cfg.limit.size]
return {
"result": result,
"first_record": result[0] if result else None,
"last_record": result[-1] if result else None,
}
@staticmethod
def _resolve_value(value: str, variable_pool: VariablePool) -> Any:
"""If value is a {{ namespace.key }} variable selector, resolve it from the pool.
Otherwise return the raw string."""
import re
m = re.fullmatch(r"\{\{\s*(\w+\.\w+)\s*}}", value.strip())
if m:
resolved = variable_pool.get_value(value, default=value, strict=False)
return resolved
return value
@staticmethod
def _make_sort_key(key: str):
def key_fn(item):
if isinstance(item, dict):
return item.get(key) or ""
return item
return key_fn
def _match_condition(self, item: Any, cond: FilterCondition, variable_pool: VariablePool) -> bool:
op = cond.comparison_operator
value = cond.value
# Resolve value if it's a variable reference {{ namespace.key }}
if isinstance(value, str):
value = self._resolve_value(value, variable_pool)
# Resolve left value
if isinstance(item, dict):
left = item.get(cond.key)
else:
left = item # primitive array: compare element directly
# Numeric operators
if op == ComparisonOperator.EQ:
return self._safe_num(left) == self._safe_num(value)
if op == ComparisonOperator.NE:
return self._safe_num(left) != self._safe_num(value)
if op == ComparisonOperator.LT:
return self._safe_num(left) < self._safe_num(value)
if op == ComparisonOperator.LE:
return self._safe_num(left) <= self._safe_num(value)
if op == ComparisonOperator.GT:
return self._safe_num(left) > self._safe_num(value)
if op == ComparisonOperator.GE:
return self._safe_num(left) >= self._safe_num(value)
# String / sequence operators
left_str = str(left) if left is not None else ""
if op == ComparisonOperator.CONTAINS:
return str(value) in left_str
if op == ComparisonOperator.NOT_CONTAINS:
return str(value) not in left_str
if op == ComparisonOperator.START_WITH:
return left_str.startswith(str(value))
if op == ComparisonOperator.END_WITH:
return left_str.endswith(str(value))
if op == ComparisonOperator.IN:
return left_str in (value if isinstance(value, list) else [str(value)])
if op == ComparisonOperator.NOT_IN:
return left_str not in (value if isinstance(value, list) else [str(value)])
if op == ComparisonOperator.EMPTY:
return not left
if op == ComparisonOperator.NOT_EMPTY:
return bool(left)
raise ValueError(f"Unsupported operator: {op}")
@staticmethod
def _safe_num(v) -> float:
try:
return float(v)
except (TypeError, ValueError):
return 0.0

View File

@@ -135,8 +135,7 @@ class LLMNode(BaseNode):
api_key=model_info.api_key, api_key=model_info.api_key,
base_url=model_info.api_base, base_url=model_info.api_base,
extra_params=extra_params, extra_params=extra_params,
is_omni=model_info.is_omni, is_omni=model_info.is_omni
support_thinking="thinking" in (model_info.capability or []),
), ),
type=model_info.model_type type=model_info.model_type
) )
@@ -214,9 +213,10 @@ class LLMNode(BaseNode):
messages = messages[:-1] + history_message + messages[-1:] messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages self.messages = messages
else: else:
# 使用简单的 prompt 格式(向后兼容) # 使用简单的 prompt 格式(向后兼容)——包装为标准消息列表以兼容所有 provider
prompt_template = self.config.get("prompt", "") prompt_template = self.config.get("prompt", "")
self.messages = self._render_template(prompt_template, variable_pool) rendered = self._render_template(prompt_template, variable_pool)
self.messages = [{"role": "user", "content": rendered}]
return llm return llm

View File

@@ -27,6 +27,7 @@ from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode from app.core.workflow.nodes.tool import ToolNode
from app.core.workflow.nodes.document_extractor import DocExtractorNode from app.core.workflow.nodes.document_extractor import DocExtractorNode
from app.core.workflow.nodes.list_operator import ListOperatorNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -51,7 +52,8 @@ WorkflowNode = Union[
MemoryReadNode, MemoryReadNode,
MemoryWriteNode, MemoryWriteNode,
CodeNode, CodeNode,
DocExtractorNode DocExtractorNode,
ListOperatorNode
] ]
@@ -83,7 +85,8 @@ class NodeFactory:
NodeType.MEMORY_READ: MemoryReadNode, NodeType.MEMORY_READ: MemoryReadNode,
NodeType.MEMORY_WRITE: MemoryWriteNode, NodeType.MEMORY_WRITE: MemoryWriteNode,
NodeType.CODE: CodeNode, NodeType.CODE: CodeNode,
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode,
NodeType.LIST_OPERATOR: ListOperatorNode
} }
@classmethod @classmethod

View File

@@ -118,8 +118,7 @@ class ParameterExtractorNode(BaseNode):
provider=provider, provider=provider,
api_key=api_key, api_key=api_key,
base_url=api_base, base_url=api_base,
is_omni=is_omni, is_omni=is_omni
support_thinking="thinking" in (capability or []),
), ),
type=ModelType(model_type) type=ModelType(model_type)
) )

View File

@@ -71,8 +71,7 @@ class QuestionClassifierNode(BaseNode):
provider=provider, provider=provider,
api_key=api_key, api_key=api_key,
base_url=base_url, base_url=base_url,
is_omni=is_omni, is_omni=is_omni
support_thinking="thinking" in (capability or []),
), ),
type=ModelType(model_type) type=ModelType(model_type)
) )

View File

@@ -1,7 +1,10 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# Author: Eternity import mimetypes
# @Email: 1533512157@qq.com import os
# @Time : 2026/3/10 13:36 import uuid
from typing import Any
from urllib.parse import urlparse, unquote
TRANSFORM_FILE_TYPE = { TRANSFORM_FILE_TYPE = {
'text/plain': 'document/text', 'text/plain': 'document/text',
'text/markdown': 'document/markdown', 'text/markdown': 'document/markdown',
@@ -52,5 +55,143 @@ ALLOWED_FILE_TYPES = [
def mime_to_file_type(mime_type): def mime_to_file_type(mime_type):
if mime_type not in ALLOWED_FILE_TYPES: if mime_type not in ALLOWED_FILE_TYPES:
return None return None
return TRANSFORM_FILE_TYPE.get(mime_type, mime_type) return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
def build_file_object_dict_from_url(url: str, file_type: str, origin_file_type: str) -> dict[str, Any]:
"""Build a FileObject dict for a remote_url file using only URL parsing (no HTTP request).
Used as fallback when HTTP request fails.
"""
raw_path = url.split("?")[0]
name = unquote(os.path.basename(urlparse(url).path)) or None
_, ext = os.path.splitext(name or "")
extension = ext.lstrip(".").lower() if ext else None
guessed_mime = mimetypes.guess_type(url)[0]
return {
"type": file_type,
"url": url,
"transfer_method": "remote_url",
"origin_file_type": origin_file_type,
"file_id": None,
"name": name,
"size": None,
"extension": extension,
"mime_type": guessed_mime or origin_file_type,
"is_file": True,
}
async def fetch_remote_file_meta(
url: str,
file_type: str,
origin_file_type: str,
) -> dict[str, Any]:
"""Fetch remote file metadata via HEAD (fallback GET) and build a FileObject dict.
Falls back to URL-only parsing if the HTTP request fails.
"""
import httpx
name = size = mime_type = extension = None
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.head(url, follow_redirects=True)
if resp.status_code != 200:
resp = await client.get(url, follow_redirects=True)
cl = resp.headers.get("Content-Length")
size = int(cl) if cl else None
ct = resp.headers.get("Content-Type", "").split(";")[0].strip()
mime_type = ct or origin_file_type
cd = resp.headers.get("Content-Disposition", "")
if "filename=" in cd:
name = cd.split("filename=")[-1].strip('"').strip("'")
if not name:
name = unquote(os.path.basename(urlparse(url).path)) or None
if name:
_, ext = os.path.splitext(name)
extension = ext.lstrip(".").lower() if ext else None
if not extension and mime_type:
ext = mimetypes.guess_extension(mime_type)
extension = ext.lstrip(".").lower() if ext else None
except Exception:
return build_file_object_dict_from_url(url, file_type, origin_file_type)
return build_file_object_dict_from_meta(
file_type=file_type,
transfer_method="remote_url",
origin_file_type=origin_file_type,
file_id=None,
url=url,
file_name=name,
file_size=size,
file_ext=extension,
content_type=mime_type,
)
def build_file_object_dict_from_meta(
file_type: str,
transfer_method: str,
origin_file_type: str,
file_id: str,
url: str,
file_name: str | None,
file_size: int | None,
file_ext: str | None,
content_type: str | None,
) -> dict[str, Any]:
"""Build a FileObject dict from already-fetched FileMetadata fields."""
ext = (file_ext or "").lstrip(".")
return {
"type": file_type,
"url": url,
"transfer_method": transfer_method,
"origin_file_type": content_type or origin_file_type,
"file_id": file_id,
"name": file_name,
"size": file_size,
"extension": ext.lower() if ext else None,
"mime_type": content_type,
"is_file": True,
}
def resolve_local_file_object_dict(
db,
upload_file_id: str | uuid.UUID,
file_type: str,
origin_file_type: str,
) -> dict[str, Any] | None:
"""Query FileMetadata and build a FileObject dict for a local_file.
Returns None if the file is not found or not completed.
"""
from app.models.file_metadata_model import FileMetadata
from app.core.config import settings
try:
fid = uuid.UUID(str(upload_file_id))
except ValueError:
return None
meta = db.query(FileMetadata).filter(
FileMetadata.id == fid,
FileMetadata.status == "completed"
).first()
if not meta:
return None
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{fid}"
return build_file_object_dict_from_meta(
file_type=file_type,
transfer_method="local_file",
origin_file_type=origin_file_type,
file_id=str(fid),
url=url,
file_name=meta.file_name,
file_size=meta.file_size,
file_ext=meta.file_ext,
content_type=meta.content_type,
)

View File

@@ -301,7 +301,7 @@ class WorkflowValidator:
for node in nodes: for node in nodes:
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"): if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
errors.append( errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)" f"节点 {node.get('name')} 缺少名称(发布时必须提供)"
) )
# 2. 验证所有非 start/end 节点都有配置 # 2. 验证所有非 start/end 节点都有配置
@@ -311,7 +311,7 @@ class WorkflowValidator:
config = node.get("config") config = node.get("config")
if not config or not isinstance(config, dict): if not config or not isinstance(config, dict):
errors.append( errors.append(
f"节点 {node.get('id')} 缺少配置(发布时必须提供)" f"节点 {node.get('name')} 缺少配置(发布时必须提供)"
) )
# 3. 验证必填变量 # 3. 验证必填变量

View File

@@ -91,7 +91,7 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
case VariableType.OBJECT: case VariableType.OBJECT:
return {} return {}
case VariableType.FILE: case VariableType.FILE:
return None return {}
case VariableType.ARRAY_STRING: case VariableType.ARRAY_STRING:
return [] return []
case VariableType.ARRAY_NUMBER: case VariableType.ARRAY_NUMBER:
@@ -113,6 +113,12 @@ class FileObject(BaseModel):
origin_file_type: str origin_file_type: str
file_id: str | None file_id: str | None
# Extended file metadata
name: str | None = None
size: int | None = None
extension: str | None = None
mime_type: str | None = None
content_cache: dict = Field(default_factory=dict) content_cache: dict = Field(default_factory=dict)
is_file: bool is_file: bool

View File

@@ -66,20 +66,10 @@ class FileVariable(BaseVariable):
type = 'file' type = 'file'
def valid_value(self, value) -> FileObject: def valid_value(self, value) -> FileObject:
if isinstance(value, dict): if isinstance(value, dict):
if not value.get("is_file"): if not value.get("is_file"):
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
return FileObject( return FileObject(**value)
**{
"type": str(value.get('type')),
"transfer_method": value.get("transfer_method"),
"url": value.get('url'),
"file_id": value.get("file_id"),
"origin_file_type": value.get("origin_file_type"),
"is_file": True
}
)
if isinstance(value, FileObject): if isinstance(value, FileObject):
return value return value
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
@@ -88,7 +78,7 @@ class FileVariable(BaseVariable):
return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})' return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})'
def get_value(self) -> Any: def get_value(self) -> Any:
return self.value.model_dump() return self.value.model_dump(exclude={"content_cache"})
async def get_content(self): async def get_content(self):
total_bytes = 0 total_bytes = 0
@@ -186,6 +176,8 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
return BooleanVariable(value) return BooleanVariable(value)
case VariableType.OBJECT: case VariableType.OBJECT:
return DictVariable(value) return DictVariable(value)
case VariableType.FILE:
return FileVariable(value)
case VariableType.ARRAY_STRING: case VariableType.ARRAY_STRING:
return make_array(StringVariable, value) return make_array(StringVariable, value)
case VariableType.ARRAY_NUMBER: case VariableType.ARRAY_NUMBER:

View File

@@ -62,6 +62,7 @@ async def lifespan(app: FastAPI):
else: else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
await create_all_indexes() await create_all_indexes()
logger.info("All neo4j indexes and constraints created successfully!")
logger.info("应用程序启动完成") logger.info("应用程序启动完成")

View File

@@ -1,11 +1,11 @@
import asyncio
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def create_fulltext_indexes(): async def create_fulltext_indexes():
"""Create full-text indexes for keyword search with BM25 scoring.""" """Create full-text indexes for keyword search with BM25 scoring."""
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
# 创建 Statements 索引 # 创建 Statements 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement] CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
@@ -40,8 +40,16 @@ async def create_fulltext_indexes():
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
# 创建 Perceptual 感知记忆索引
await connector.execute_query("""
CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
finally: finally:
await connector.close() await connector.close()
async def create_vector_indexes(): async def create_vector_indexes():
"""Create vector indexes for fast embedding similarity search. """Create vector indexes for fast embedding similarity search.
@@ -51,7 +59,6 @@ async def create_vector_indexes():
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
# Statement embedding index # Statement embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
@@ -63,7 +70,6 @@ async def create_vector_indexes():
}} }}
""") """)
# Chunk embedding index # Chunk embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
@@ -75,7 +81,6 @@ async def create_vector_indexes():
}} }}
""") """)
# Entity name embedding index # Entity name embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
@@ -87,7 +92,6 @@ async def create_vector_indexes():
}} }}
""") """)
# Memory summary embedding index # Memory summary embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
@@ -121,8 +125,20 @@ async def create_vector_indexes():
}} }}
""") """)
# Perceptual summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS
FOR (p:Perceptual)
ON p.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
finally: finally:
await connector.close() await connector.close()
async def create_unique_constraints(): async def create_unique_constraints():
"""Create uniqueness constraints for core node identifiers. """Create uniqueness constraints for core node identifiers.
Ensures concurrent MERGE operations remain safe and prevents duplicates. Ensures concurrent MERGE operations remain safe and prevents duplicates.
@@ -155,10 +171,10 @@ async def create_unique_constraints():
finally: finally:
await connector.close() await connector.close()
async def create_all_indexes(): async def create_all_indexes():
"""Create all indexes and constraints in one go.""" """Create all indexes and constraints in one go."""
await create_fulltext_indexes() await create_fulltext_indexes()
await create_vector_indexes() await create_vector_indexes()
await create_unique_constraints() await create_unique_constraints()
print("✓ All indexes and constraints created successfully!")

View File

@@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
r.created_at = edge.created_at r.created_at = edge.created_at
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
SEARCH_PERCEPTUAL_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
WHERE p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""
PERCEPTUAL_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS p, score
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""

View File

@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS, EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
PERCEPTUAL_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD, SEARCH_COMMUNITIES_BY_KEYWORD,
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
SEARCH_ENTITIES_BY_NAME_OR_ALIAS, SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
SEARCH_PERCEPTUAL_BY_KEYWORD,
SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_KEYWORD, SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
async def _update_activation_values_batch( async def _update_activation_values_batch(
connector: Neo4jConnector, connector: Neo4jConnector,
nodes: List[Dict[str, Any]], nodes: List[Dict[str, Any]],
node_label: str, node_label: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
max_retries: int = 3 max_retries: int = 3
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
批量更新节点的激活值 批量更新节点的激活值
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
async def _update_search_results_activation( async def _update_search_results_activation(
connector: Neo4jConnector, connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]], results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
更新搜索结果中所有知识节点的激活值 更新搜索结果中所有知识节点的激活值
@@ -196,7 +198,7 @@ async def _update_search_results_activation(
'importance_score', 'importance_score',
'version', 'version',
'statement', # Statement 节点的内容字段 'statement', # Statement 节点的内容字段
'content' # MemorySummary 节点的内容字段 'content' # MemorySummary 节点的内容字段
} }
# 只更新激活值相关字段,保留原始节点的其他字段 # 只更新激活值相关字段,保留原始节点的其他字段
@@ -220,11 +222,11 @@ async def _update_search_results_activation(
async def search_graph( async def search_graph(
connector: Neo4jConnector, connector: Neo4jConnector,
q: str, q: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = None, include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Search across Statements, Entities, Chunks, and Summaries using a free-text query. Search across Statements, Entities, Chunks, and Summaries using a free-text query.
@@ -257,6 +259,7 @@ async def search_graph(
if "statements" in include: if "statements" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD, SEARCH_STATEMENTS_BY_KEYWORD,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -266,6 +269,7 @@ async def search_graph(
if "entities" in include: if "entities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME_OR_ALIAS, SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -275,6 +279,7 @@ async def search_graph(
if "chunks" in include: if "chunks" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -284,6 +289,7 @@ async def search_graph(
if "summaries" in include: if "summaries" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -293,6 +299,7 @@ async def search_graph(
if "communities" in include: if "communities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD, SEARCH_COMMUNITIES_BY_KEYWORD,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -336,12 +343,12 @@ async def search_graph(
async def search_graph_by_embedding( async def search_graph_by_embedding(
connector: Neo4jConnector, connector: Neo4jConnector,
embedder_client, embedder_client,
query_text: str, query_text: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"], include: List[str] = ["statements", "chunks", "entities", "summaries"],
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Embedding-based semantic search across Statements, Chunks, and Entities. Embedding-based semantic search across Statements, Chunks, and Entities.
@@ -360,7 +367,7 @@ async def search_graph_by_embedding(
embed_start = time.time() embed_start = time.time()
embeddings = await embedder_client.response([query_text]) embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s") logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]: if not embeddings or not embeddings[0]:
logger.warning( logger.warning(
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
if "statements" in include: if "statements" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH, STATEMENT_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
if "chunks" in include: if "chunks" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH, CHUNK_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
if "entities" in include: if "entities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
if "summaries" in include: if "summaries" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
if "communities" in include: if "communities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH, COMMUNITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -428,7 +440,7 @@ async def search_graph_by_embedding(
query_start = time.time() query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary # Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = { results: Dict[str, List[Dict[str, Any]]] = {
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
logger.info(f"[PERF] Skipping activation updates (only summaries)") logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: str, end_user_id: str,
entities: List[Dict[str, Any]], entities: List[Dict[str, Any]],
use_contains_fallback: bool = True, use_contains_fallback: bool = True,
batch_size: int = 500, batch_size: int = 500,
max_concurrency: int = 5, max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries 为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal( async def search_graph_by_keyword_temporal(
connector: Neo4jConnector, connector: Neo4jConnector,
query_text: str, query_text: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 50, limit: int = 50,
) -> Dict[str, List[Any]]: ) -> Dict[str, List[Any]]:
""" """
Temporal keyword search across Statements. Temporal keyword search across Statements.
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
if not query_text: if not query_text:
print(f"query_text不能为空") logger.warning(f"query_text不能为空")
return {"statements": []} return {"statements": []}
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date, invalid_date=invalid_date,
limit=limit, limit=limit,
) )
print(f"查询结果为:\n{statements}") logger.debug(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal( async def search_graph_by_temporal(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 10, limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -648,10 +658,10 @@ async def search_graph_by_temporal(
async def search_graph_by_dialog_id( async def search_graph_by_dialog_id(
connector: Neo4jConnector, connector: Neo4jConnector,
dialog_id: str, dialog_id: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Dialogues. Temporal search across Dialogues.
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues - Returns up to 'limit' dialogues
""" """
if not dialog_id: if not dialog_id:
print(f"dialog_id不能为空") logger.warning(f"dialog_id不能为空")
return {"dialogues": []} return {"dialogues": []}
dialogues = await connector.execute_query( dialogues = await connector.execute_query(
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id( async def search_graph_by_chunk_id(
connector: Neo4jConnector, connector: Neo4jConnector,
chunk_id : str, chunk_id: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id: if not chunk_id:
print(f"chunk_id不能为空") logger.warning(f"chunk_id不能为空")
return {"chunks": []} return {"chunks": []}
chunks = await connector.execute_query( chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
async def search_graph_community_expand( async def search_graph_community_expand(
connector: Neo4jConnector, connector: Neo4jConnector,
community_ids: List[str], community_ids: List[str],
end_user_id: str, end_user_id: str,
limit: int = 10, limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
三期:社区展开检索 —— 主题 → 细节两级检索。 三期:社区展开检索 —— 主题 → 细节两级检索。
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
async def search_graph_by_created_at( async def search_graph_by_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
created_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -768,15 +777,10 @@ async def search_graph_by_created_at(
SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_CREATED_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -787,13 +791,13 @@ async def search_graph_by_created_at(
return results return results
async def search_graph_by_valid_at( async def search_graph_by_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
valid_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -808,15 +812,10 @@ async def search_graph_by_valid_at(
SEARCH_STATEMENTS_BY_VALID_AT, SEARCH_STATEMENTS_BY_VALID_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -827,13 +826,13 @@ async def search_graph_by_valid_at(
return results return results
async def search_graph_g_created_at( async def search_graph_g_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
created_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -848,15 +847,10 @@ async def search_graph_g_created_at(
SEARCH_STATEMENTS_G_CREATED_AT, SEARCH_STATEMENTS_G_CREATED_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -867,13 +861,13 @@ async def search_graph_g_created_at(
return results return results
async def search_graph_g_valid_at( async def search_graph_g_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
valid_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT, SEARCH_STATEMENTS_G_VALID_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -907,13 +895,13 @@ async def search_graph_g_valid_at(
return results return results
async def search_graph_l_created_at( async def search_graph_l_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
created_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -928,15 +916,10 @@ async def search_graph_l_created_at(
SEARCH_STATEMENTS_L_CREATED_AT, SEARCH_STATEMENTS_L_CREATED_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -947,13 +930,13 @@ async def search_graph_l_created_at(
return results return results
async def search_graph_l_valid_at( async def search_graph_l_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
valid_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -968,15 +951,10 @@ async def search_graph_l_valid_at(
SEARCH_STATEMENTS_L_VALID_AT, SEARCH_STATEMENTS_L_VALID_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -986,3 +964,87 @@ async def search_graph_l_valid_at(
) )
return results return results
async def search_perceptual(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using fulltext keyword search.
Matches against summary, topic, and domain fields via the perceptualFulltext index.
Args:
connector: Neo4j connector
q: Query text
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
try:
perceptuals = await connector.execute_query(
SEARCH_PERCEPTUAL_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual: keyword search failed: {e}")
perceptuals = []
# Deduplicate
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}
async def search_perceptual_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using embedding-based semantic search.
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
Args:
connector: Neo4j connector
embedder_client: Embedding client with async response() method
query_text: Query text to embed
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
embeddings = await embedder_client.response([query_text])
if not embeddings or not embeddings[0]:
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
return {"perceptuals": []}
embedding = embeddings[0]
try:
perceptuals = await connector.execute_query(
PERCEPTUAL_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
perceptuals = []
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}

View File

@@ -11,10 +11,28 @@ Classes:
from typing import Any, List, Dict from typing import Any, List, Dict
from neo4j import AsyncGraphDatabase, basic_auth from neo4j import AsyncGraphDatabase, basic_auth
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration
from app.core.config import settings from app.core.config import settings
def _convert_neo4j_types(value: Any) -> Any:
"""递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。"""
if isinstance(value, Neo4jDateTime):
return value.to_native().isoformat() if value.tzinfo else value.iso_format()
if isinstance(value, Neo4jDate):
return value.iso_format()
if isinstance(value, Neo4jTime):
return value.iso_format()
if isinstance(value, Neo4jDuration):
return str(value)
if isinstance(value, dict):
return {k: _convert_neo4j_types(v) for k, v in value.items()}
if isinstance(value, list):
return [_convert_neo4j_types(item) for item in value]
return value
class Neo4jConnector: class Neo4jConnector:
"""Neo4j数据库连接器 """Neo4j数据库连接器
@@ -59,11 +77,12 @@ class Neo4jConnector:
""" """
await self.driver.close() await self.driver.close()
async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
"""执行Cypher查询 """执行Cypher查询
Args: Args:
query: Cypher查询语句 query: Cypher查询语句
json_format: json格式化
**kwargs: 查询参数将作为参数传递给Cypher查询 **kwargs: 查询参数将作为参数传递给Cypher查询
Returns: Returns:
@@ -78,7 +97,10 @@ class Neo4jConnector:
**kwargs **kwargs
) )
records, summary, keys = result records, summary, keys = result
return [record.data() for record in records] if json_format:
return [_convert_neo4j_types(record.data()) for record in records]
else:
return [record.data() for record in records]
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any: async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
"""在写事务中执行操作 """在写事务中执行操作

View File

@@ -4,10 +4,6 @@ from typing import Optional, Any, List, Dict, Union
from enum import Enum, StrEnum from enum import Enum, StrEnum
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
from app.schemas.workflow_schema import WorkflowConfigCreate
# ---------- Multimodal File Support ---------- # ---------- Multimodal File Support ----------
class FileType(StrEnum): class FileType(StrEnum):
@@ -317,7 +313,7 @@ class AppCreate(BaseModel):
# only for type=multi_agent # only for type=multi_agent
multi_agent_config: Optional[Dict[str, Any]] = None multi_agent_config: Optional[Dict[str, Any]] = None
workflow_config: Optional[WorkflowConfigCreate] = None workflow_config: Optional[Dict[str, Any]] = None
class AppUpdate(BaseModel): class AppUpdate(BaseModel):
@@ -644,6 +640,7 @@ class CitationSource(BaseModel):
class DraftRunResponse(BaseModel): class DraftRunResponse(BaseModel):
"""试运行响应(非流式)""" """试运行响应(非流式)"""
message: str = Field(..., description="AI 回复消息") message: str = Field(..., description="AI 回复消息")
reasoning_content: Optional[str] = Field(default=None, description="深度思考内容")
conversation_id: Optional[str] = Field(default=None, description="会话ID用于多轮对话") conversation_id: Optional[str] = Field(default=None, description="会话ID用于多轮对话")
usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况") usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)") elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
@@ -651,6 +648,12 @@ class DraftRunResponse(BaseModel):
citations: List[CitationSource] = Field(default_factory=list, description="引用来源") citations: List[CitationSource] = Field(default_factory=list, description="引用来源")
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL") audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
def model_dump(self, **kwargs):
data = super().model_dump(**kwargs)
if not data.get("reasoning_content"):
data.pop("reasoning_content", None)
return data
class OpeningResponse(BaseModel): class OpeningResponse(BaseModel):
"""应用开场白响应""" """应用开场白响应"""

View File

@@ -401,7 +401,7 @@ class AppService:
def _create_workflow_config( def _create_workflow_config(
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
data: app_schema.WorkflowConfigCreate, data,
now: datetime.datetime now: datetime.datetime
): ):
workflow_cfg = WorkflowConfig( workflow_cfg = WorkflowConfig(
@@ -678,7 +678,9 @@ class AppService:
self._create_multi_agent_config(app.id, data.multi_agent_config, now) self._create_multi_agent_config(app.id, data.multi_agent_config, now)
if app.type == "workflow" and data.workflow_config: if app.type == "workflow" and data.workflow_config:
self._create_workflow_config(app.id, data.workflow_config, now) from app.schemas.workflow_schema import WorkflowConfigCreate
wf_data = WorkflowConfigCreate(**data.workflow_config) if isinstance(data.workflow_config, dict) else data.workflow_config
self._create_workflow_config(app.id, wf_data, now)
self.db.commit() self.db.commit()
self.db.refresh(app) self.db.refresh(app)

View File

@@ -462,11 +462,6 @@ class MemoryAgentService:
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
# 导入审计日志记录器
config_load_start = time.time() config_load_start = time.time()
try: try:
# Use a separate database session to avoid transaction failures # Use a separate database session to avoid transaction failures
@@ -507,10 +502,13 @@ class MemoryAgentService:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, initial_state = {
"end_user_id": end_user_id "messages": [HumanMessage(content=message)],
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, "search_switch": search_switch,
"memory_config": memory_config} "end_user_id": end_user_id
, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息 # 获取节点更新信息
_intermediate_outputs = [] _intermediate_outputs = []
summary = '' summary = ''
@@ -522,7 +520,7 @@ class MemoryAgentService:
for node_name, node_data in update_event.items(): for node_name, node_data in update_event.items():
# if 'save_neo4j' == node_name: # if 'save_neo4j' == node_name:
# massages = node_data # massages = node_data
print(f"处理节点: {node_name}") logger.info(f"处理节点: {node_name}")
# 处理不同Summary节点的返回结构 # 处理不同Summary节点的返回结构
if 'Summary' in node_name: if 'Summary' in node_name:
@@ -549,6 +547,11 @@ class MemoryAgentService:
if retrieve_node and retrieve_node != [] and retrieve_node != {}: if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node) _intermediate_outputs.extend(retrieve_node)
# Perceptual_Retrieve 节点
perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None)
if perceptual_node and perceptual_node != [] and perceptual_node != {}:
_intermediate_outputs.append(perceptual_node)
# Verify 节点 # Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None) verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}: if verify_n and verify_n != [] and verify_n != {}:

View File

@@ -16,7 +16,6 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.executor import execute_workflow, execute_workflow_stream from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config from app.core.workflow.validator import validate_workflow_config
from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db from app.db import get_db
from app.models import App from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.models.workflow_model import WorkflowConfig, WorkflowExecution
@@ -453,22 +452,70 @@ class WorkflowService:
"success_rate": completed / total if total > 0 else 0 "success_rate": completed / total if total > 0 else 0
} }
async def _resolve_variables_file_defaults(
self,
variables: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Convert FileInput-format defaults in workflow variables to full FileObject dicts."""
from app.core.workflow.utils.file_processor import (
resolve_local_file_object_dict,
fetch_remote_file_meta,
)
async def _resolve_one(item: dict) -> dict | None:
if not isinstance(item, dict) or item.get("is_file"):
return item
transfer_method = item.get("transfer_method", "remote_url")
file_type = item.get("type", "document")
origin_file_type = item.get("file_type") or file_type
if transfer_method == "remote_url":
url = item.get("url", "")
return await fetch_remote_file_meta(url, file_type, origin_file_type) if url else None
else:
return resolve_local_file_object_dict(self.db, item.get("upload_file_id"), file_type, origin_file_type)
result = []
for var_def in variables:
var_type = var_def.get("type", "")
default = var_def.get("default")
if var_type == "file" and isinstance(default, dict) and not default.get("is_file"):
var_def = {**var_def, "default": await _resolve_one(default)}
elif var_type == "array[file]" and isinstance(default, list):
resolved = []
for item in default:
r = await _resolve_one(item)
if r is not None:
resolved.append(r)
var_def = {**var_def, "default": resolved}
result.append(var_def)
return result
async def _handle_file_input(self, files: list[FileInput]): async def _handle_file_input(self, files: list[FileInput]):
if not files: if not files:
return [] return []
from app.core.workflow.utils.file_processor import (
resolve_local_file_object_dict,
build_file_object_dict_from_meta,
fetch_remote_file_meta,
)
files_struct = [] files_struct = []
for file in files: for file in files:
files_struct.append( url = await self.multimodal_service.get_file_url(file)
FileObject( file_type = str(file.type)
type=file.type, origin_file_type = file.file_type or file_type
url=await self.multimodal_service.get_file_url(file),
transfer_method=file.transfer_method, if file.transfer_method.value == "local_file" and file.upload_file_id:
file_id=str(file.upload_file_id) if file.upload_file_id else None, fo = resolve_local_file_object_dict(self.db, file.upload_file_id, file_type, origin_file_type)
origin_file_type=file.file_type, files_struct.append(fo or build_file_object_dict_from_meta(
is_file=True file_type=file_type, transfer_method="local_file",
).model_dump() origin_file_type=origin_file_type,
) file_id=str(file.upload_file_id), url=url,
file_name=None, file_size=None, file_ext=None, content_type=None,
))
else:
files_struct.append(await fetch_remote_file_meta(url, file_type, origin_file_type))
return files_struct return files_struct
@staticmethod @staticmethod
@@ -545,6 +592,12 @@ class WorkflowService:
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]: def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id) storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
user_rag_memory_id = "" user_rag_memory_id = ""
# 如果 storage_type 为 None使用默认值 'neo4j'
if not storage_type:
storage_type = 'neo4j'
logger.warning(
f"Storage type not set for workspace {workspace_id}, using default: neo4j"
)
if storage_type == "rag": if storage_type == "rag":
knowledge = knowledge_repository.get_knowledge_by_name( knowledge = knowledge_repository.get_knowledge_by_name(
db=self.db, db=self.db,
@@ -659,6 +712,26 @@ class WorkflowService:
input_data["conv_messages"] = conv_messages input_data["conv_messages"] = conv_messages
init_message_length = len(input_data.get("conv_messages", [])) init_message_length = len(input_data.get("conv_messages", []))
# 新会话时写入开场白
is_new_conversation = init_message_length == 0
if is_new_conversation:
opening_cfg = feature_configs.get("opening_statement", {})
if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"):
statement = opening_cfg["statement"]
suggested_questions = opening_cfg.get("suggested_questions", [])
if payload.variables:
for var_name, var_value in payload.variables.items():
statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value))
self.conversation_service.add_message(
conversation_id=conversation_id_uuid,
role="assistant",
content=statement,
meta_data={"suggested_questions": suggested_questions}
)
# 注入到 conv_messages让 LLM 感知开场白
input_data["conv_messages"] = [{"role": "assistant", "content": statement}]
init_message_length = 1
result = await execute_workflow( result = await execute_workflow(
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
@@ -696,12 +769,21 @@ class WorkflowService:
content=human_message, content=human_message,
meta_data=human_meta meta_data=human_meta
) )
# 过滤 citations
citations = result.get("citations", [])
citation_cfg = feature_configs.get("citation", {})
filtered_citations = (
citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else []
)
assistant_meta = {"usage": token_usage, "audio_url": None}
if filtered_citations:
assistant_meta["citations"] = filtered_citations
self.conversation_service.add_message( self.conversation_service.add_message(
message_id=message_id, message_id=message_id,
conversation_id=conversation_id_uuid, conversation_id=conversation_id_uuid,
role="assistant", role="assistant",
content=assistant_message, content=assistant_message,
meta_data={"usage": token_usage, "audio_url": None} meta_data=assistant_meta
) )
self.update_execution_status( self.update_execution_status(
execution.execution_id, execution.execution_id,
@@ -720,6 +802,7 @@ class WorkflowService:
) )
logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id}," logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id},"
f" error: {result.get('error')}") f" error: {result.get('error')}")
filtered_citations = []
# 返回增强的响应结构 # 返回增强的响应结构
return { return {
@@ -734,7 +817,8 @@ class WorkflowService:
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID "conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"), "error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"), "elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage") "token_usage": result.get("token_usage"),
"citations": filtered_citations,
} }
except Exception as e: except Exception as e:
@@ -825,6 +909,27 @@ class WorkflowService:
input_data["conv_messages"] = conv_messages input_data["conv_messages"] = conv_messages
init_message_length = len(input_data.get("conv_messages", [])) init_message_length = len(input_data.get("conv_messages", []))
message_id = uuid.uuid4() message_id = uuid.uuid4()
# 新会话时写入开场白
is_new_conversation = init_message_length == 0
if is_new_conversation:
opening_cfg = feature_configs.get("opening_statement", {})
if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"):
statement = opening_cfg["statement"]
suggested_questions = opening_cfg.get("suggested_questions", [])
if payload.variables:
for var_name, var_value in payload.variables.items():
statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value))
self.conversation_service.add_message(
conversation_id=conversation_id_uuid,
role="assistant",
content=statement,
meta_data={"suggested_questions": suggested_questions}
)
# 注入到 conv_messages让 LLM 感知开场白
input_data["conv_messages"] = [{"role": "assistant", "content": statement}]
init_message_length = 1
async for event in execute_workflow_stream( async for event in execute_workflow_stream(
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
@@ -862,12 +967,21 @@ class WorkflowService:
content=human_message, content=human_message,
meta_data=human_meta meta_data=human_meta
) )
# 过滤 citations
citations = event.get("data", {}).get("citations", [])
citation_cfg = feature_configs.get("citation", {})
filtered_citations = (
citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else []
)
assistant_meta = {"usage": token_usage, "audio_url": None}
if filtered_citations:
assistant_meta["citations"] = filtered_citations
self.conversation_service.add_message( self.conversation_service.add_message(
message_id=message_id, message_id=message_id,
conversation_id=conversation_id_uuid, conversation_id=conversation_id_uuid,
role="assistant", role="assistant",
content=assistant_message, content=assistant_message,
meta_data={"usage": token_usage, "audio_url": None} meta_data=assistant_meta
) )
self.update_execution_status( self.update_execution_status(
execution.execution_id, execution.execution_id,
@@ -875,6 +989,7 @@ class WorkflowService:
output_data=event.get("data"), output_data=event.get("data"),
token_usage=token_usage.get("total_tokens", None) token_usage=token_usage.get("total_tokens", None)
) )
event.setdefault("data", {})["citations"] = filtered_citations
logger.info(f"Workflow Run Success, " logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
elif status == "failed": elif status == "failed":

View File

@@ -480,21 +480,21 @@ def create_workspace_invite(
try: try:
# 检查权限 # 检查权限
_check_workspace_admin_permission(db, workspace_id, user) _check_workspace_admin_permission(db, workspace_id, user)
if settings.ENABLE_SINGLE_WORKSPACE: # if settings.ENABLE_SINGLE_WORKSPACE:
# 检查被邀请用户是否已经在工作空间中 # 检查被邀请用户是否已经在工作空间中
from app.repositories import user_repository from app.repositories import user_repository
invited_user = user_repository.get_user_by_email(db, invite_data.email) invited_user = user_repository.get_user_by_email(db, invite_data.email)
if invited_user: if invited_user:
# 用户存在,检查是否已经是工作空间成员 # 用户存在,检查是否已经是工作空间成员
existing_member = workspace_repository.get_member_in_workspace( existing_member = workspace_repository.get_member_in_workspace(
db=db, db=db,
user_id=invited_user.id, user_id=invited_user.id,
workspace_id=workspace_id workspace_id=workspace_id
) )
if existing_member: if existing_member:
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员") business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS) raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
# 检查是否已有待处理的邀请 # 检查是否已有待处理的邀请
invite_repo = WorkspaceInviteRepository(db) invite_repo = WorkspaceInviteRepository(db)

View File

@@ -153,7 +153,8 @@ def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig:
edges=config_dict.get("edges", []), edges=config_dict.get("edges", []),
variables=config_dict.get("variables", []), variables=config_dict.get("variables", []),
execution_config=config_dict.get("execution_config", {}), execution_config=config_dict.get("execution_config", {}),
triggers=config_dict.get("triggers", []) triggers=config_dict.get("triggers", []),
features=config_dict.get("features", {})
) )
return config return config

View File

@@ -11,13 +11,13 @@ import { ContentEditable } from '@lexical/react/LexicalContentEditable';
import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin'; import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin';
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'; import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary';
import AutocompletePlugin, { type Suggestion } from './plugin/AutocompletePlugin'; import { type Suggestion } from './plugin/AutocompletePlugin';
import CharacterCountPlugin from './plugin/CharacterCountPlugin'; import CharacterCountPlugin from './plugin/CharacterCountPlugin';
import InitialValuePlugin from './plugin/InitialValuePlugin'; import Jinja2InitialValuePlugin from './plugin/Jinja2InitialValuePlugin';
import CommandPlugin from './plugin/CommandPlugin'; import Jinja2AutocompletePlugin from './plugin/Jinja2AutocompletePlugin';
import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin'; import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin';
import Jinja2BlurPlugin from './plugin/Jinja2BlurPlugin';
import LineNumberPlugin from './plugin/LineNumberPlugin'; import LineNumberPlugin from './plugin/LineNumberPlugin';
import BlurPlugin from './plugin/BlurPlugin';
const jinja2Theme = { const jinja2Theme = {
paragraph: 'editor-paragraph', paragraph: 'editor-paragraph',
@@ -171,13 +171,12 @@ const Jinja2Editor: FC<Jinja2EditorProps> = ({
ErrorBoundary={LexicalErrorBoundary} ErrorBoundary={LexicalErrorBoundary}
/> />
<HistoryPlugin /> <HistoryPlugin />
<CommandPlugin />
<Jinja2HighlightPlugin /> <Jinja2HighlightPlugin />
<LineNumberPlugin /> <LineNumberPlugin />
<AutocompletePlugin options={options} enableJinja2 /> <Jinja2AutocompletePlugin options={options} />
<CharacterCountPlugin setCount={() => {}} onChange={onChange} /> <CharacterCountPlugin setCount={() => {}} onChange={onChange} waitForInit />
<InitialValuePlugin value={value} options={options} enableLineNumbers /> <Jinja2InitialValuePlugin value={value} />
<BlurPlugin enableJinja2 /> <Jinja2BlurPlugin />
</div> </div>
</LexicalComposer> </LexicalComposer>
); );

View File

@@ -33,6 +33,7 @@ export interface LexicalEditorProps {
type?: 'input' | 'textarea'; type?: 'input' | 'textarea';
language?: 'string' | 'jinja2'; language?: 'string' | 'jinja2';
className?: string; className?: string;
waitForInit?: boolean;
} }
// Default theme for editor // Default theme for editor
@@ -55,8 +56,10 @@ const Editor: FC<LexicalEditorProps> =({
type = 'textarea', type = 'textarea',
language = 'string', language = 'string',
height, height,
className className,
waitForInit = false,
}) => { }) => {
console.log('Editor value', value)
const [_count, setCount] = useState(0); const [_count, setCount] = useState(0);
if (language === 'jinja2') { if (language === 'jinja2') {

View File

@@ -2,11 +2,11 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2025-12-23 16:22:51 * @Date: 2025-12-23 16:22:51
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-25 16:13:37 * @Last Modified time: 2026-04-02 17:12:41
*/ */
import { useEffect, useLayoutEffect, useState, useRef, type FC } from 'react'; import { useEffect, useLayoutEffect, useState, useRef, type FC } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { $getSelection, $isRangeSelection, $isTextNode, COMMAND_PRIORITY_HIGH, KEY_ENTER_COMMAND, KEY_ARROW_DOWN_COMMAND, KEY_ARROW_UP_COMMAND, KEY_ESCAPE_COMMAND } from 'lexical'; import { $getSelection, $isRangeSelection, COMMAND_PRIORITY_HIGH, KEY_ENTER_COMMAND, KEY_ARROW_DOWN_COMMAND, KEY_ARROW_UP_COMMAND, KEY_ESCAPE_COMMAND } from 'lexical';
import { Space, Flex } from 'antd'; import { Space, Flex } from 'antd';
import { INSERT_VARIABLE_COMMAND, CLOSE_AUTOCOMPLETE_COMMAND } from '../commands'; import { INSERT_VARIABLE_COMMAND, CLOSE_AUTOCOMPLETE_COMMAND } from '../commands';
@@ -28,7 +28,7 @@ export interface Suggestion {
} }
// Autocomplete plugin for variable suggestions triggered by '/' character // Autocomplete plugin for variable suggestions triggered by '/' character
const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> = ({ options, enableJinja2 = false }) => { const AutocompletePlugin: FC<{ options: Suggestion[] }> = ({ options }) => {
const [editor] = useLexicalComposerContext(); const [editor] = useLexicalComposerContext();
const [showSuggestions, setShowSuggestions] = useState(false); const [showSuggestions, setShowSuggestions] = useState(false);
const [selectedIndex, setSelectedIndex] = useState(0); const [selectedIndex, setSelectedIndex] = useState(0);
@@ -159,34 +159,7 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }>
// Insert selected suggestion into editor // Insert selected suggestion into editor
const insertMention = (suggestion: Suggestion) => { const insertMention = (suggestion: Suggestion) => {
if (enableJinja2) { editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion });
// In Jinja2 mode, insert {{variable}} format text
editor.update(() => {
const selection = $getSelection();
if ($isRangeSelection(selection)) {
const anchorNode = selection.anchor.getNode();
const anchorOffset = selection.anchor.offset;
const nodeText = anchorNode.getTextContent();
// Remove trigger character '/'
const textBefore = nodeText.substring(0, anchorOffset - 1);
const textAfter = nodeText.substring(anchorOffset);
const newText = textBefore + `{{${suggestion.value}}}` + textAfter;
if ($isTextNode(anchorNode)) {
anchorNode.setTextContent(newText);
}
// Set cursor position after inserted text
const newOffset = textBefore.length + `{{${suggestion.value}}}`.length;
selection.anchor.offset = newOffset;
selection.focus.offset = newOffset;
}
});
} else {
// In normal mode, use VariableNode
editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion });
}
setShowSuggestions(false); setShowSuggestions(false);
setExpandedParent(null); setExpandedParent(null);
setChildPanelTop(0); setChildPanelTop(0);

View File

@@ -1,64 +1,33 @@
/* /*
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-01-20 10:42:13 * @Date: 2026-01-20 10:42:13
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-03 10:12:10 * @Last Modified time: 2026-04-02 17:13:08
*/ */
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { useEffect } from 'react'; import { useEffect } from 'react';
import { $setSelection } from 'lexical';
import { CLOSE_AUTOCOMPLETE_COMMAND } from '../commands'; import { CLOSE_AUTOCOMPLETE_COMMAND } from '../commands';
// Plugin to handle blur events and close autocomplete when clicking outside // Plugin to handle blur events and close autocomplete when clicking outside
export default function BlurPlugin({ enableJinja2 }: { enableJinja2: boolean }) { export default function BlurPlugin() {
const [editor] = useLexicalComposerContext(); const [editor] = useLexicalComposerContext();
useEffect(() => { useEffect(() => {
// Close autocomplete when clicking outside the popup
const handleClickOutside = (e: MouseEvent) => { const handleClickOutside = (e: MouseEvent) => {
const target = e.target as HTMLElement; if ((e.target as HTMLElement)?.closest('[data-autocomplete-popup="true"]')) return;
if (target?.closest('[data-autocomplete-popup="true"]')) {
return;
}
editor.dispatchCommand(CLOSE_AUTOCOMPLETE_COMMAND, undefined); editor.dispatchCommand(CLOSE_AUTOCOMPLETE_COMMAND, undefined);
}; };
document.addEventListener('mousedown', handleClickOutside); document.addEventListener('mousedown', handleClickOutside);
return editor.registerRootListener((rootElement) => { return editor.registerRootListener((rootElement) => {
if (rootElement) { if (rootElement) {
const handleBlur = (e: FocusEvent) => {
if (enableJinja2) {
// Check if autocomplete popup was clicked
const target = e.target as HTMLElement;
if (target?.closest('[data-autocomplete-popup="true"]')) {
return;
}
// Check if blur was caused by paste operation
const relatedTarget = e.relatedTarget as HTMLElement;
if (!relatedTarget || relatedTarget === document.body) {
return;
}
// Clear selection on blur
editor.update(() => {
$setSelection(null);
});
}
};
rootElement.addEventListener('blur', handleBlur);
return () => { return () => {
document.removeEventListener('mousedown', handleClickOutside); document.removeEventListener('mousedown', handleClickOutside);
rootElement.removeEventListener('blur', handleBlur);
}; };
} }
return () => { return () => { document.removeEventListener('mousedown', handleClickOutside); };
document.removeEventListener('mousedown', handleClickOutside);
};
}); });
}, [editor, enableJinja2]); }, [editor]);
return null; return null;
} }

View File

@@ -1,3 +1,9 @@
/*
* @Author: ZhaoYing
* @Date: 2025-12-23 16:22:51
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-02 17:14:15
*/
import { useEffect, useRef } from 'react'; import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical'; import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical';
@@ -8,19 +14,17 @@ import { type Suggestion } from '../plugin/AutocompletePlugin'
interface InitialValuePluginProps { interface InitialValuePluginProps {
value: string; value: string;
options?: Suggestion[]; options?: Suggestion[];
enableLineNumbers?: boolean;
} }
const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [], enableLineNumbers = false }) => { const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [] }) => {
const [editor] = useLexicalComposerContext(); const [editor] = useLexicalComposerContext();
const prevValueRef = useRef<string>(''); const prevValueRef = useRef<string>('');
const prevEnableLineNumbersRef = useRef<boolean>(enableLineNumbers);
const isUserInputRef = useRef(false); const isUserInputRef = useRef(false);
const optionsRef = useRef(options); const optionsRef = useRef(options);
optionsRef.current = options; optionsRef.current = options;
useEffect(() => { useEffect(() => {
const removeListener = editor.registerUpdateListener(({ editorState, tags }) => { return editor.registerUpdateListener(({ editorState, tags }) => {
if (tags.has('programmatic')) return; if (tags.has('programmatic')) return;
editorState.read(() => { editorState.read(() => {
const root = $getRoot(); const root = $getRoot();
@@ -31,21 +35,16 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
} }
}); });
}); });
return removeListener;
}, [editor]); }, [editor]);
useEffect(() => { useEffect(() => {
if (value !== prevValueRef.current || enableLineNumbers !== prevEnableLineNumbersRef.current) { if (value !== prevValueRef.current) {
// Skip reset if the change was triggered by user input (avoid cursor jump) if (isUserInputRef.current) {
if (isUserInputRef.current && enableLineNumbers === prevEnableLineNumbersRef.current) {
prevValueRef.current = value; prevValueRef.current = value;
isUserInputRef.current = false; isUserInputRef.current = false;
return; return;
} }
// Update refs BEFORE editor.update to prevent re-entry
prevValueRef.current = value; prevValueRef.current = value;
prevEnableLineNumbersRef.current = enableLineNumbers;
isUserInputRef.current = false; isUserInputRef.current = false;
queueMicrotask(() => { queueMicrotask(() => {
@@ -54,16 +53,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
root.clear(); root.clear();
const parts = value.split(/(\{\{[^}]+\}\}|\n)/); const parts = value.split(/(\{\{[^}]+\}\}|\n)/);
let paragraph = $createParagraphNode();
if (enableLineNumbers) {
const lines = value.split('\n');
lines.forEach((line) => {
const paragraph = $createParagraphNode();
paragraph.append($createTextNode(line));
root.append(paragraph);
});
} else {
let paragraph = $createParagraphNode();
parts.forEach(part => { parts.forEach(part => {
if (part === '\n') { if (part === '\n') {
@@ -129,15 +119,10 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
} }
}); });
root.append(paragraph); root.append(paragraph);
}
}, { tag: 'programmatic' }); }, { tag: 'programmatic' });
}); });
} else {
prevValueRef.current = value;
prevEnableLineNumbersRef.current = enableLineNumbers;
isUserInputRef.current = false;
} }
}, [value, editor, enableLineNumbers]); }, [value, editor]);
return null; return null;
}; };

View File

@@ -0,0 +1,199 @@
/*
* @Author: ZhaoYing
* @Date: 2026-04-02 17:10:59
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-02 17:10:59
*/
import { useEffect, useState, useRef, type FC } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import {
$getSelection, $isRangeSelection, $isTextNode,
COMMAND_PRIORITY_HIGH, KEY_ENTER_COMMAND, KEY_ARROW_DOWN_COMMAND,
KEY_ARROW_UP_COMMAND, KEY_ESCAPE_COMMAND,
} from 'lexical';
import { Space, Flex } from 'antd';
import { CLOSE_AUTOCOMPLETE_COMMAND } from '../commands';
import type { Suggestion } from './AutocompletePlugin';
const Jinja2AutocompletePlugin: FC<{ options: Suggestion[] }> = ({ options }) => {
const [editor] = useLexicalComposerContext();
const [showSuggestions, setShowSuggestions] = useState(false);
const [selectedIndex, setSelectedIndex] = useState(0);
const [popupPosition, setPopupPosition] = useState({ top: 0, left: 0 });
const popupRef = useRef<HTMLDivElement>(null);
const scrollSelectedIntoView = () => {
if (!popupRef.current) return;
const selectedElement = popupRef.current.querySelector('[data-selected="true"]');
if (!selectedElement) return;
const container = popupRef.current;
const element = selectedElement as HTMLElement;
const containerRect = container.getBoundingClientRect();
const elementRect = element.getBoundingClientRect();
if (elementRect.bottom > containerRect.bottom) {
container.scrollTop += elementRect.bottom - containerRect.bottom;
} else if (elementRect.top < containerRect.top) {
container.scrollTop -= containerRect.top - elementRect.top;
}
};
useEffect(() => {
return editor.registerUpdateListener(({ editorState }) => {
editorState.read(() => {
const selection = $getSelection();
if (!selection || !$isRangeSelection(selection)) {
setShowSuggestions(false);
return;
}
const anchorNode = selection.anchor.getNode();
const anchorOffset = selection.anchor.offset;
const textBeforeCursor = anchorNode.getTextContent().substring(0, anchorOffset);
const shouldShow = textBeforeCursor.endsWith('/');
setShowSuggestions(shouldShow);
if (!shouldShow) { setSelectedIndex(0); return; }
const domSelection = window.getSelection();
if (domSelection && domSelection.rangeCount > 0) {
const rect = domSelection.getRangeAt(0).getBoundingClientRect();
const popupWidth = 280, popupHeight = 200;
const vw = window.innerWidth, vh = window.innerHeight;
let left = Math.min(Math.max(rect.left, 10), vw - popupWidth - 10);
let top = rect.top - 10;
if (top - popupHeight < 10) {
top = Math.min(rect.bottom + 10, vh - popupHeight - 10);
}
setPopupPosition({ top, left });
}
});
});
}, [editor]);
useEffect(() => {
return editor.registerCommand(
CLOSE_AUTOCOMPLETE_COMMAND,
() => { setShowSuggestions(false); return true; },
COMMAND_PRIORITY_HIGH,
);
}, [editor]);
const insertMention = (suggestion: Suggestion) => {
editor.update(() => {
const selection = $getSelection();
if (!$isRangeSelection(selection)) return;
const anchorNode = selection.anchor.getNode();
const anchorOffset = selection.anchor.offset;
const nodeText = anchorNode.getTextContent();
const textBefore = nodeText.substring(0, anchorOffset - 1);
const textAfter = nodeText.substring(anchorOffset);
const inserted = `{{${suggestion.value}}}`;
if ($isTextNode(anchorNode)) {
anchorNode.setTextContent(textBefore + inserted + textAfter);
const newOffset = textBefore.length + inserted.length;
selection.anchor.offset = newOffset;
selection.focus.offset = newOffset;
}
});
setShowSuggestions(false);
};
const groupedSuggestions = options.reduce((groups: Record<string, Suggestion[]>, s) => {
const id = s.nodeData.id as string;
if (!groups[id]) groups[id] = [];
groups[id].push(s);
return groups;
}, {});
const allOptions = Object.values(groupedSuggestions).flat();
useEffect(() => {
if (!showSuggestions) return;
return editor.registerCommand(
KEY_ENTER_COMMAND,
(event) => {
const opt = allOptions[selectedIndex];
if (opt && !opt.disabled) { event?.preventDefault(); insertMention(opt); return true; }
return false;
},
COMMAND_PRIORITY_HIGH,
);
}, [showSuggestions, selectedIndex, allOptions]);
useEffect(() => {
if (!showSuggestions) return;
const down = editor.registerCommand(KEY_ARROW_DOWN_COMMAND, (e) => {
e?.preventDefault();
setSelectedIndex(prev => {
let next = prev + 1;
while (next < allOptions.length && allOptions[next].disabled) next++;
setTimeout(scrollSelectedIntoView, 0);
return next >= allOptions.length ? prev : next;
});
return true;
}, COMMAND_PRIORITY_HIGH);
const up = editor.registerCommand(KEY_ARROW_UP_COMMAND, (e) => {
e?.preventDefault();
setSelectedIndex(prev => {
let p = prev - 1;
while (p >= 0 && allOptions[p].disabled) p--;
setTimeout(scrollSelectedIntoView, 0);
return p < 0 ? prev : p;
});
return true;
}, COMMAND_PRIORITY_HIGH);
const esc = editor.registerCommand(KEY_ESCAPE_COMMAND, (e) => {
e?.preventDefault(); setShowSuggestions(false); return true;
}, COMMAND_PRIORITY_HIGH);
return () => { down(); up(); esc(); };
}, [showSuggestions, selectedIndex, allOptions, editor]);
if (!showSuggestions || Object.keys(groupedSuggestions).length === 0) return null;
return (
<div
ref={popupRef}
data-autocomplete-popup="true"
onMouseDown={(e) => e.preventDefault()}
className="rb:fixed rb:z-1000 rb:py-1 rb:bg-white rb:rounded-xl rb:min-w-70 rb:max-h-50 rb:overflow-y-auto rb:transform-[translateY(-100%)] rb:shadow-[0px_2px_12px_0px_rgba(23,23,25,0.12)]"
style={{ top: popupPosition.top, left: popupPosition.left }}
>
<Flex vertical gap={12}>
{Object.entries(groupedSuggestions).map(([nodeId, nodeOptions]) => (
<div key={nodeId}>
<Flex align="center" gap={4} className="rb:px-3! rb:text-[12px] rb:py-1.25! rb:font-medium rb:text-[#5B6167]">
{nodeOptions[0]?.nodeData?.icon && <img src={nodeOptions[0].nodeData.icon} className="rb:size-3" alt="" />}
{nodeOptions[0]?.nodeData?.name || nodeId}
</Flex>
{nodeOptions.map((option) => {
const globalIndex = allOptions.indexOf(option);
return (
<Flex
key={option.key}
data-selected={selectedIndex === globalIndex}
className="rb:pl-6! rb:pr-3! rb:py-2!"
align="center"
justify="space-between"
style={{
cursor: option.disabled ? 'not-allowed' : 'pointer',
background: selectedIndex === globalIndex ? '#f0f8ff' : 'white',
opacity: option.disabled ? 0.5 : 1,
}}
onClick={() => !option.disabled && insertMention(option)}
onMouseEnter={() => setSelectedIndex(globalIndex)}
>
<Space size={4}>
<span className="rb:text-[#155EEF]">{option.isContext ? '📄' : '{x}'}</span>
<span>{option.label}</span>
</Space>
{option.dataType && <span className="rb:text-[#5B6167]">{option.dataType}</span>}
</Flex>
);
})}
</div>
))}
</Flex>
</div>
);
};
export default Jinja2AutocompletePlugin;

View File

@@ -0,0 +1,41 @@
/*
* @Author: ZhaoYing
* @Date: 2026-04-02 17:11:04
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-02 17:11:04
*/
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { useEffect } from 'react';
import { $setSelection } from 'lexical';
import { CLOSE_AUTOCOMPLETE_COMMAND } from '../commands';
export default function Jinja2BlurPlugin() {
const [editor] = useLexicalComposerContext();
useEffect(() => {
const handleClickOutside = (e: MouseEvent) => {
if ((e.target as HTMLElement)?.closest('[data-autocomplete-popup="true"]')) return;
editor.dispatchCommand(CLOSE_AUTOCOMPLETE_COMMAND, undefined);
};
document.addEventListener('mousedown', handleClickOutside);
return editor.registerRootListener((rootElement) => {
if (rootElement) {
const handleBlur = (e: FocusEvent) => {
if ((e.target as HTMLElement)?.closest('[data-autocomplete-popup="true"]')) return;
const relatedTarget = e.relatedTarget as HTMLElement;
if (!relatedTarget || relatedTarget === document.body) return;
editor.update(() => { $setSelection(null); });
};
rootElement.addEventListener('blur', handleBlur);
return () => {
document.removeEventListener('mousedown', handleClickOutside);
rootElement.removeEventListener('blur', handleBlur);
};
}
return () => { document.removeEventListener('mousedown', handleClickOutside); };
});
}, [editor]);
return null;
}

View File

@@ -0,0 +1,61 @@
/*
* @Author: ZhaoYing
* @Date: 2026-04-02 17:11:07
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-02 17:11:07
*/
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical';
interface Jinja2InitialValuePluginProps {
value: string;
}
const Jinja2InitialValuePlugin: React.FC<Jinja2InitialValuePluginProps> = ({ value }) => {
const [editor] = useLexicalComposerContext();
const prevValueRef = useRef<string>('');
const isUserInputRef = useRef(false);
useEffect(() => {
return editor.registerUpdateListener(({ editorState, tags }) => {
if (tags.has('programmatic')) return;
editorState.read(() => {
const textContent = $getRoot().getTextContent();
if (textContent !== prevValueRef.current) {
isUserInputRef.current = true;
prevValueRef.current = textContent;
}
});
});
}, [editor]);
useEffect(() => {
if (value === prevValueRef.current) return;
if (isUserInputRef.current) {
prevValueRef.current = value;
isUserInputRef.current = false;
return;
}
prevValueRef.current = value;
isUserInputRef.current = false;
queueMicrotask(() => {
editor.update(() => {
const root = $getRoot();
root.clear();
value.split('\n').forEach((line) => {
const paragraph = $createParagraphNode();
paragraph.append($createTextNode(line));
root.append(paragraph);
});
}, { tag: 'programmatic' });
});
}, [value, editor]);
return null;
};
export default Jinja2InitialValuePlugin;

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-09 18:35:43 * @Date: 2026-02-09 18:35:43
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-20 11:32:44 * @Last Modified time: 2026-04-02 17:17:06
*/ */
import { type FC, useRef, useState } from "react"; import { type FC, useRef, useState } from "react";
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
@@ -114,6 +114,7 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an
<Col span={16}> <Col span={16}>
<Form.Item name="url"> <Form.Item name="url">
<Editor <Editor
key="url"
options={options.filter(vo => vo.dataType === 'string' || vo.dataType === 'number')} options={options.filter(vo => vo.dataType === 'string' || vo.dataType === 'number')}
variant="outlined" variant="outlined"
type="input" type="input"
@@ -212,13 +213,15 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an
} }
{values?.body?.content_type === 'binary' && {values?.body?.content_type === 'binary' &&
<Form.Item name={['body', 'data']} <Form.Item name={['body', 'data']}
className="rb:bg-[#F6F6F6] rb:border-[#F6F6F6]! rb:hover:bg-white rb:hover:border-[#171719]! rb:border rb:rounded-lg rb:px-2! rb:py-1.5! rb:mb-0!" className="rb:bg-[#F6F6F6] rb:border-[#F6F6F6]! rb:hover:bg-white rb:hover:border-[#171719]! rb:border rb:rounded-lg rb:mb-0!"
> >
<Editor <Editor
key={['body', 'data'].join('_')}
placeholder={t('common.pleaseSelect')} placeholder={t('common.pleaseSelect')}
options={options.filter(vo => vo.dataType.includes('file'))} options={options.filter(vo => vo.dataType.includes('file'))}
type="input" type="input"
size="small" size="small"
height={28}
/> />
</Form.Item> </Form.Item>
} }