Merge branch 'develop' into feature/ui_upgrade_zy
@@ -292,10 +292,19 @@ def get_opening(
):
    """Return the opening statement text and preset questions shown when the frontend chat UI initializes."""
    workspace_id = current_user.current_workspace_id
    cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
    features = cfg.features or {}
    if hasattr(features, "model_dump"):
        features = features.model_dump()

    # Fetch features according to the app type
    from app.models.app_model import App as AppModel
    app = db.get(AppModel, app_id)
    if app and app.type == "workflow":
        cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
        features = cfg.features or {}
    else:
        cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
        features = cfg.features or {}
    if hasattr(features, "model_dump"):
        features = features.model_dump()

    opening = features.get("opening_statement", {})
    return success(data=app_schema.OpeningResponse(
        enabled=opening.get("enabled", False),
@@ -1070,6 +1079,14 @@ async def update_workflow_config(
    current_user: Annotated[User, Depends(get_current_user)]
):
    workspace_id = current_user.current_workspace_id
    if payload.variables:
        from app.services.workflow_service import WorkflowService
        resolved = await WorkflowService(db)._resolve_variables_file_defaults(
            [v.model_dump() for v in payload.variables]
        )
        # Patch default values back into VariableDefinition objects
        for var_def, resolved_def in zip(payload.variables, resolved):
            var_def.default = resolved_def.get("default", var_def.default)
    cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
    return success(data=WorkflowConfigSchema.model_validate(cfg))

@@ -53,22 +53,24 @@ async def login_for_access_token(
        user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
        auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
        if form_data.invite:
            auth_service.bind_workspace_with_invite(db=db,
                                                    user=user,
                                                    invite_token=form_data.invite,
                                                    workspace_id=invite_info.workspace_id)
            auth_service.bind_workspace_with_invite(
                db=db,
                user=user,
                invite_token=form_data.invite,
                workspace_id=invite_info.workspace_id
            )
    except BusinessException as e:
        # The user does not exist but an invite code is present: try to register
        if e.code == BizCode.USER_NOT_FOUND:
            auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
            user = auth_service.register_user_with_invite(
                db=db,
                email=form_data.email,
                username=form_data.username,
                password=form_data.password,
                invite_token=form_data.invite,
                workspace_id=invite_info.workspace_id
            )
                db=db,
                email=form_data.email,
                username=form_data.username,
                password=form_data.password,
                invite_token=form_data.invite,
                workspace_id=invite_info.workspace_id
            )
        elif e.code == BizCode.PASSWORD_ERROR:
            # The user exists but the password is wrong
            auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")

@@ -314,8 +314,10 @@ async def parse_documents(
    )

    # 4. Check if the file exists
    api_logger.debug(f"Constructed file path: {file_path}")
    api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
    if not os.path.exists(file_path):
        api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
        api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found (possibly deleted)"

@@ -475,7 +475,7 @@ class LangChainAgent:
        history: Optional[List[Dict[str, str]]] = None,
        context: Optional[str] = None,
        files: Optional[List[Dict[str, Any]]] = None
    ) -> AsyncGenerator[str | int, None]:
    ) -> AsyncGenerator[str | int | dict[str, str], None]:
        """Run a streaming conversation

        Args:

@@ -0,0 +1,408 @@
"""
Perceptual Memory Retrieval Node & Service

Provides PerceptualSearchService for searching perceptual memories (vision, audio,
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
with BM25+embedding fusion reranking.

Also provides the perceptual_retrieve_node for use as a LangGraph node.
"""
import asyncio
import math
from typing import List, Dict, Any, Optional

from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.repositories.neo4j.graph_search import (
    search_perceptual,
    search_perceptual_by_embedding,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector

logger = get_agent_logger(__name__)


class PerceptualSearchService:
    """
    Perceptual memory retrieval service.

    Wraps the full pipeline of keyword fulltext search + vector semantic search +
    BM25/embedding fusion ranking. Callers only supply query / keywords, end_user_id
    and memory_config, and get back a formatted, ranked list of perceptual memories
    plus the concatenated text.

    Usage:
        service = PerceptualSearchService(end_user_id=..., memory_config=...)
        results = await service.search(query="...", keywords=[...], limit=10)
        # results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
    """

    DEFAULT_ALPHA = 0.6
    DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5

    def __init__(
        self,
        end_user_id: str,
        memory_config: Any,
        alpha: float = DEFAULT_ALPHA,
        content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
    ):
        self.end_user_id = end_user_id
        self.memory_config = memory_config
        self.alpha = alpha
        self.content_score_threshold = content_score_threshold

    async def search(
        self,
        query: str,
        keywords: Optional[List[str]] = None,
        limit: int = 10,
    ) -> Dict[str, Any]:
        """
        Run perceptual memory retrieval (keyword + vector in parallel) and return fusion-ranked results.

        For results hit by embedding search but not by keyword search, backfill a BM25 score
        from the fulltext index so every result carries both a BM25 and an embedding score.

        Args:
            query: Original user query (used for vector search and the BM25 backfill)
            keywords: Keyword list (used for fulltext search); [query] is used when None
            limit: Maximum number of results

        Returns:
            {
                "memories": [formatted memory dicts, ...],
                "content": "concatenated plain-text summary",
                "keyword_raw": int,
                "embedding_raw": int,
            }
        """
        if keywords is None:
            keywords = [query] if query else []

        connector = Neo4jConnector()
        try:
            kw_task = self._keyword_search(connector, keywords, limit)
            emb_task = self._embedding_search(connector, query, limit)

            kw_results, emb_results = await asyncio.gather(
                kw_task, emb_task, return_exceptions=True
            )
            if isinstance(kw_results, Exception):
                logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
                kw_results = []
            if isinstance(emb_results, Exception):
                logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
                emb_results = []

            # BM25 backfill: find ids hit by embedding search but not by keyword search,
            # then query the fulltext index with the original query to get BM25 scores for them
            kw_ids = {r.get("id") for r in kw_results if r.get("id")}
            emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}

            if emb_only_ids and query:
                backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
                # Inject the backfilled BM25 scores into the embedding results
                backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
                for r in emb_results:
                    rid = r.get("id", "")
                    if rid in backfill_map:
                        r["bm25_backfill_score"] = backfill_map[rid]
                logger.info(
                    f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
                    f"{len(backfill_map)} got BM25 scores"
                )

            reranked = self._rerank(kw_results, emb_results, limit)

            memories = []
            content_parts = []
            for record in reranked:
                fmt = self._format_result(record)
                fmt["score"] = round(record.get("content_score", 0), 4)
                memories.append(fmt)
                content_parts.append(self._build_content_text(fmt))

            logger.info(
                f"[PerceptualSearch] {len(memories)} results after rerank "
                f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
            )
            return {
                "memories": memories,
                "content": "\n\n".join(content_parts),
                "keyword_raw": len(kw_results),
                "embedding_raw": len(emb_results),
            }
        finally:
            await connector.close()

    async def _bm25_backfill(
        self,
        connector: Neo4jConnector,
        query: str,
        target_ids: set,
        limit: int,
    ) -> List[dict]:
        """
        Backfill fulltext-index BM25 scores for the given id set.

        Queries the fulltext index with the original query and keeps only results whose id is in target_ids.
        """
        escaped = escape_lucene_query(query)
        if not escaped.strip():
            return []
        try:
            r = await search_perceptual(
                connector=connector, q=escaped,
                end_user_id=self.end_user_id,
                limit=limit * 5,  # over-fetch to improve the hit rate
            )
            all_hits = r.get("perceptuals", [])
            return [h for h in all_hits if h.get("id") in target_ids]
        except Exception as e:
            logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
            return []

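    # Illustration only, not part of this commit: with toy ids, keyword search hits
    # {"a", "b"} and embedding search hits {"b", "c"}, so emb_only_ids == {"c"}.
    # _bm25_backfill re-runs the fulltext query, keeps only id "c", and its raw score
    # is attached to that embedding hit as "bm25_backfill_score", which lets _rerank
    # normalize every BM25 score on one scale.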
    async def _keyword_search(
        self,
        connector: Neo4jConnector,
        keywords: List[str],
        limit: int,
    ) -> List[dict]:
        """Run fulltext search for each keyword concurrently, deduplicate, and return the top N raw results by descending score."""
        seen_ids: set = set()
        all_results: List[dict] = []

        async def _one(kw: str):
            escaped = escape_lucene_query(kw)
            if not escaped.strip():
                return []
            r = await search_perceptual(
                connector=connector, q=escaped,
                end_user_id=self.end_user_id, limit=limit,
            )
            return r.get("perceptuals", [])

        tasks = [_one(kw) for kw in keywords[:10]]
        batch = await asyncio.gather(*tasks, return_exceptions=True)

        for result in batch:
            if isinstance(result, Exception):
                logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
                continue
            for rec in result:
                rid = rec.get("id", "")
                if rid and rid not in seen_ids:
                    seen_ids.add(rid)
                    all_results.append(rec)

        all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
        return all_results[:limit]

    async def _embedding_search(
        self,
        connector: Neo4jConnector,
        query_text: str,
        limit: int,
    ) -> List[dict]:
        """Vector semantic search; returns raw results without threshold filtering."""
        try:
            from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
            from app.core.models.base import RedBearModelConfig
            from app.db import get_db_context
            from app.services.memory_config_service import MemoryConfigService

            with get_db_context() as db:
                cfg = MemoryConfigService(db).get_embedder_config(
                    str(self.memory_config.embedding_model_id)
                )
            client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))

            r = await search_perceptual_by_embedding(
                connector=connector, embedder_client=client,
                query_text=query_text, end_user_id=self.end_user_id,
                limit=limit,
            )
            return r.get("perceptuals", [])
        except Exception as e:
            logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
            return []

    def _rerank(
        self,
        keyword_results: List[dict],
        embedding_results: List[dict],
        limit: int,
    ) -> List[dict]:
        """BM25 + embedding fusion ranking.

        Embedding results that carry a bm25_backfill_score are merged with the keyword
        results and normalized together, so all BM25 scores share one scale.
        """
        # Merge the backfilled BM25 scores into keyword_results for joint normalization
        emb_backfill_items = []
        for item in embedding_results:
            backfill_score = item.get("bm25_backfill_score")
            if backfill_score is not None and item.get("id"):
                emb_backfill_items.append({"id": item["id"], "score": backfill_score})

        # Normalize all BM25 scores together after merging
        all_bm25_items = keyword_results + emb_backfill_items
        all_bm25_items = self._normalize_scores(all_bm25_items)

        # Build an id -> normalized BM25 score map
        bm25_norm_map: Dict[str, float] = {}
        for item in all_bm25_items:
            item_id = item.get("id", "")
            if item_id:
                bm25_norm_map[item_id] = float(item.get("normalized_score", 0))

        # Normalize embedding scores
        embedding_results = self._normalize_scores(embedding_results)

        # Merge
        combined: Dict[str, dict] = {}
        for item in keyword_results:
            item_id = item.get("id", "")
            if not item_id:
                continue
            combined[item_id] = item.copy()
            combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
            combined[item_id]["embedding_score"] = 0.0

        for item in embedding_results:
            item_id = item.get("id", "")
            if not item_id:
                continue
            if item_id in combined:
                combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
            else:
                combined[item_id] = item.copy()
                combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
                combined[item_id]["embedding_score"] = item.get("normalized_score", 0)

        for item in combined.values():
            bm25 = float(item.get("bm25_score", 0) or 0)
            emb = float(item.get("embedding_score", 0) or 0)
            item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb

        results = list(combined.values())
        before = len(results)
        results = [r for r in results if r["content_score"] >= self.content_score_threshold]
        results.sort(key=lambda x: x["content_score"], reverse=True)
        results = results[:limit]

        logger.info(
            f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
            f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
        )
        return results

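    # Illustration only, not part of this commit: with the default alpha = 0.6 and
    # threshold = 0.5, a hit with normalized bm25 = 0.9 and embedding = 0.4 scores
    # 0.6 * 0.9 + 0.4 * 0.4 = 0.70 and is kept, while bm25 = 0.2 / embedding = 0.6
    # scores 0.6 * 0.2 + 0.4 * 0.6 = 0.36 and is filtered out.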
    @staticmethod
    def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
        """Z-score + sigmoid normalization."""
        if not items:
            return items
        scores = [float(it.get(field, 0) or 0) for it in items]
        if len(scores) <= 1:
            for it in items:
                it[f"normalized_{field}"] = 1.0
            return items
        mean = sum(scores) / len(scores)
        var = sum((s - mean) ** 2 for s in scores) / len(scores)
        std = math.sqrt(var)
        if std == 0:
            for it in items:
                it[f"normalized_{field}"] = 1.0
        else:
            for it, s in zip(items, scores):
                z = (s - mean) / std
                it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
        return items

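    # Illustration only, not part of this commit: for raw scores [1.0, 2.0, 3.0] the
    # mean is 2.0 and the population std is about 0.816, so the z-scores are roughly
    # [-1.22, 0.0, 1.22] and the sigmoid maps them to about [0.23, 0.50, 0.77].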
    @staticmethod
    def _format_result(record: dict) -> dict:
        return {
            "id": record.get("id", ""),
            "perceptual_type": record.get("perceptual_type", ""),
            "file_name": record.get("file_name", ""),
            "file_path": record.get("file_path", ""),
            "summary": record.get("summary", ""),
            "topic": record.get("topic", ""),
            "domain": record.get("domain", ""),
            "keywords": record.get("keywords", []),
            "created_at": str(record.get("created_at", "")),
            "file_type": record.get("file_type", ""),
            "score": record.get("score", 0),
        }

    @staticmethod
    def _build_content_text(formatted: dict) -> str:
        parts = []
        if formatted["summary"]:
            parts.append(formatted["summary"])
        if formatted["topic"]:
            parts.append(f"[主题: {formatted['topic']}]")
        if formatted["keywords"]:
            kw_list = formatted["keywords"]
            if isinstance(kw_list, list):
                parts.append(f"[关键词: {', '.join(kw_list)}]")
        if formatted["file_name"]:
            parts.append(f"[文件: {formatted['file_name']}]")
        return " ".join(parts)


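# Illustration only, not part of this commit: for a record with summary "Quarterly
# sales report overview", topic "sales", keywords ["Q3", "revenue"] and file_name
# "report.pdf" (invented values), _build_content_text returns
# 'Quarterly sales report overview [主题: sales] [关键词: Q3, revenue] [文件: report.pdf]'.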
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
    """Extract search keywords from problem extension results."""
    keywords = []
    context = problem_extension.get("context", {})
    if isinstance(context, dict):
        for original_q, extended_qs in context.items():
            keywords.append(original_q)
            if isinstance(extended_qs, list):
                keywords.extend(extended_qs)
    return keywords


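# Illustration only, not part of this commit: given the hypothetical input
# problem_extension = {"context": {"What did I upload yesterday?": ["files uploaded yesterday", "recent documents"]}}
# this returns ["What did I upload yesterday?", "files uploaded yesterday", "recent documents"].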
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
    """
    LangGraph node: perceptual memory retrieval.

    Uses PerceptualSearchService to run keyword + embedding search with
    BM25 fusion reranking, then writes results to state['perceptual_data'].
    """
    end_user_id = state.get("end_user_id", "")
    problem_extension = state.get("problem_extension", {})
    original_query = state.get("data", "")
    memory_config = state.get("memory_config", None)

    logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")

    keywords = _extract_keywords_from_problems(problem_extension)
    if not keywords:
        keywords = [original_query] if original_query else []

    logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")

    service = PerceptualSearchService(
        end_user_id=end_user_id,
        memory_config=memory_config,
    )
    search_result = await service.search(
        query=original_query,
        keywords=keywords,
        limit=10,
    )

    result = {
        "memories": search_result["memories"],
        "content": search_result["content"],
        "_intermediate": {
            "type": "perceptual_retrieve",
            "title": "感知记忆检索",
            "data": search_result["memories"],
            "query": original_query,
            "result_count": len(search_result["memories"]),
        },
    }
    return {"perceptual_data": result}
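A minimal, self-contained sketch of the fusion scoring implemented by PerceptualSearchService above (illustration, not code from this commit); normalize and fuse are hypothetical stand-ins for _normalize_scores and _rerank, and the toy scores are invented:

import math

def normalize(items, field="score"):
    # Z-score followed by a sigmoid, mirroring _normalize_scores above
    scores = [float(it.get(field, 0) or 0) for it in items]
    if len(scores) <= 1:
        return [{**it, "normalized_score": 1.0} for it in items]
    mean = sum(scores) / len(scores)
    std = math.sqrt(sum((s - mean) ** 2 for s in scores) / len(scores)) or 1.0
    return [{**it, "normalized_score": 1 / (1 + math.exp(-(s - mean) / std))}
            for it, s in zip(items, scores)]

def fuse(keyword_hits, embedding_hits, alpha=0.6, threshold=0.5, limit=10):
    bm25 = {it["id"]: it["normalized_score"] for it in normalize(keyword_hits)}
    emb = {it["id"]: it["normalized_score"] for it in normalize(embedding_hits)}
    fused = [{"id": i, "content_score": alpha * bm25.get(i, 0) + (1 - alpha) * emb.get(i, 0)}
             for i in set(bm25) | set(emb)]
    kept = [r for r in fused if r["content_score"] >= threshold]
    return sorted(kept, key=lambda r: r["content_score"], reverse=True)[:limit]

# An id hit by both channels ("a") survives the 0.5 threshold; single-channel hits do not.
print(fuse([{"id": "a", "score": 2.1}, {"id": "b", "score": 1.4}],
           [{"id": "a", "score": 0.9}, {"id": "c", "score": 0.4}]))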
@@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
print(time.time() - start)
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": data,
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||
PerceptualSearchService,
|
||||
)
|
||||
from app.core.memory.agent.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
@@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
if storage_type != "rag":
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
|
||||
|
||||
async def _perceptual_search():
|
||||
service = PerceptualSearchService(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
return await service.search(query=data, limit=5)
|
||||
|
||||
hybrid_task = SearchService().execute_hybrid_search(
|
||||
**search_params,
|
||||
memory_config=memory_config,
|
||||
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
|
||||
expand_communities=False,
|
||||
)
|
||||
perceptual_task = _perceptual_search()
|
||||
|
||||
gather_results = await asyncio.gather(
|
||||
hybrid_task, perceptual_task, return_exceptions=True
|
||||
)
|
||||
hybrid_result = gather_results[0]
|
||||
perceptual_results = gather_results[1]
|
||||
|
||||
# 处理 hybrid search 异常
|
||||
if isinstance(hybrid_result, Exception):
|
||||
raise hybrid_result
|
||||
retrieve_info, question, raw_results = hybrid_result
|
||||
|
||||
# 处理感知记忆结果
|
||||
if isinstance(perceptual_results, Exception):
|
||||
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
|
||||
perceptual_results = []
|
||||
|
||||
# 拼接感知记忆内容到 retrieve_info
|
||||
if perceptual_results and isinstance(perceptual_results, dict):
|
||||
perceptual_content = perceptual_results.get("content", "")
|
||||
if perceptual_content:
|
||||
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
|
||||
count = len(perceptual_results.get("memories", []))
|
||||
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
|
||||
|
||||
# 调试:打印 community 检索结果数量
|
||||
if raw_results and isinstance(raw_results, dict):
|
||||
reranked = raw_results.get('reranked_results', {})
|
||||
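A minimal, self-contained sketch of the gather-and-degrade pattern used in Input_Summary above (illustration, not code from this commit): a perceptual-search failure only costs the extra context, while a hybrid-search failure is re-raised. Both coroutines here are hypothetical stand-ins.

import asyncio

async def main():
    async def hybrid():            # stand-in for SearchService().execute_hybrid_search(...)
        return "graph facts"
    async def perceptual():        # stand-in for PerceptualSearchService.search(...)
        raise RuntimeError("Neo4j unavailable")

    hybrid_result, perceptual_result = await asyncio.gather(
        hybrid(), perceptual(), return_exceptions=True
    )
    if isinstance(hybrid_result, Exception):
        raise hybrid_result                      # hybrid search failure is fatal
    retrieve_info = hybrid_result
    if isinstance(perceptual_result, dict) and perceptual_result.get("content"):
        retrieve_info += f"\n\n<history-files>\n{perceptual_result['content']}"
    print(retrieve_info)                         # perceptual failure only degrades the context

asyncio.run(main())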
@@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"error": str(e)
|
||||
}
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
duration = end - start
|
||||
log_time('检索', duration)
|
||||
return {"summary": summary}
|
||||
|
||||
@@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
retrieve_info_str = list(set(retrieve_info_str))
|
||||
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages = await summary_llm(state, history, retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
aimessages = await summary_llm(
|
||||
state,
|
||||
history,
|
||||
retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2',
|
||||
'retrieve_summary', RetrieveSummaryResponse,
|
||||
"1"
|
||||
)
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
|
||||
retrieve_info_str += i + '\n'
|
||||
history = await summary_history(state)
|
||||
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
@@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str += i + '\n'
|
||||
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
|
||||
@@ -17,6 +17,9 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||
retrieve_nodes,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||
perceptual_retrieve_node,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||
Input_Summary,
|
||||
Retrieve_Summary,
|
||||
@@ -48,13 +51,14 @@ async def make_read_graph():
|
||||
"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow.add_node("content_input", content_input_node)
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension)
|
||||
workflow.add_node("Input_Summary", Input_Summary)
|
||||
workflow.add_node("Retrieve", retrieve_nodes)
|
||||
# workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
|
||||
workflow.add_node("Verify", Verify)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||
workflow.add_node("Summary", Summary)
|
||||
@@ -65,14 +69,15 @@ async def make_read_graph():
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
|
||||
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
|
||||
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# Compile workflow
|
||||
@@ -80,7 +85,5 @@ async def make_read_graph():
|
||||
yield graph
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建工作流失败: {e}")
|
||||
logger.error(f"创建工作流失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
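A schematic, self-contained sketch of the new node ordering wired above (illustration, not code from this commit); the langgraph import path, the trivial state schema, and the no-op node functions are assumptions for illustration only.

from typing import TypedDict
from langgraph.graph import StateGraph, END

class State(TypedDict, total=False):
    data: str
    perceptual_data: dict

async def problem_extension(state: State) -> State: return {}
async def perceptual_retrieve(state: State) -> State: return {"perceptual_data": {}}
async def retrieve(state: State) -> State: return {}

g = StateGraph(State)
g.add_node("Problem_Extension", problem_extension)
g.add_node("Perceptual_Retrieve", perceptual_retrieve)
g.add_node("Retrieve", retrieve)
g.set_entry_point("Problem_Extension")
# New ordering from this change: perceptual memory is fetched before the main Retrieve.
g.add_edge("Problem_Extension", "Perceptual_Retrieve")
g.add_edge("Perceptual_Retrieve", "Retrieve")
g.add_edge("Retrieve", END)
graph = g.compile()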
@@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||
@@ -31,10 +30,10 @@ def _clean_expand_fields(obj):
|
||||
|
||||
|
||||
async def expand_communities_to_statements(
|
||||
community_results: List[dict],
|
||||
end_user_id: str,
|
||||
existing_content: str = "",
|
||||
limit: int = 10,
|
||||
community_results: List[dict],
|
||||
end_user_id: str,
|
||||
existing_content: str = "",
|
||||
limit: int = 10,
|
||||
) -> Tuple[List[dict], List[str]]:
|
||||
"""
|
||||
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||
@@ -76,17 +75,18 @@ async def expand_communities_to_statements(
|
||||
if s.get("statement") and s["statement"] not in existing_lines
|
||||
]
|
||||
cleaned = _clean_expand_fields(expanded_stmts)
|
||||
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||
logger.info(
|
||||
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||
return cleaned, new_texts
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
|
||||
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
||||
"""
|
||||
Extract only meaningful content from search results, dropping all metadata.
|
||||
@@ -107,19 +107,19 @@ class SearchService:
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return str(result)
|
||||
|
||||
|
||||
content_parts = []
|
||||
|
||||
|
||||
# Statements: extract statement field
|
||||
if 'statement' in result and result['statement']:
|
||||
content_parts.append(result['statement'])
|
||||
|
||||
|
||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||
is_community = (
|
||||
node_type == "community"
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
node_type == "community"
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
)
|
||||
if is_community:
|
||||
name = result.get('name', '')
|
||||
@@ -130,16 +130,16 @@ class SearchService:
|
||||
elif 'content' in result and result['content']:
|
||||
# Summaries / Chunks
|
||||
content_parts.append(result['content'])
|
||||
|
||||
|
||||
# Entities: extract name and fact_summary (commented out in original)
|
||||
# if 'name' in result and result['name']:
|
||||
# content_parts.append(result['name'])
|
||||
# if result.get('fact_summary'):
|
||||
# content_parts.append(result['fact_summary'])
|
||||
|
||||
|
||||
# Return concatenated content or empty string
|
||||
return '\n'.join(content_parts) if content_parts else ""
|
||||
|
||||
|
||||
def clean_query(self, query: str) -> str:
|
||||
"""
|
||||
Clean and escape query text for Lucene.
|
||||
@@ -155,33 +155,33 @@ class SearchService:
|
||||
Cleaned and escaped query string
|
||||
"""
|
||||
q = str(query).strip()
|
||||
|
||||
|
||||
# Remove wrapping quotes
|
||||
if (q.startswith("'") and q.endswith("'")) or (
|
||||
q.startswith('"') and q.endswith('"')
|
||||
q.startswith('"') and q.endswith('"')
|
||||
):
|
||||
q = q[1:-1]
|
||||
|
||||
|
||||
# Remove newlines and carriage returns
|
||||
q = q.replace('\r', ' ').replace('\n', ' ').strip()
|
||||
|
||||
|
||||
# Apply Lucene escaping
|
||||
q = escape_lucene_query(q)
|
||||
|
||||
|
||||
return q
|
||||
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config = None,
|
||||
expand_communities: bool = True,
|
||||
self,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config=None,
|
||||
expand_communities: bool = True,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
@@ -205,10 +205,10 @@ class SearchService:
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries", "communities"]
|
||||
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
|
||||
try:
|
||||
# Execute search
|
||||
answer = await run_hybrid_search(
|
||||
@@ -221,18 +221,18 @@ class SearchService:
|
||||
memory_config=memory_config,
|
||||
rerank_alpha=rerank_alpha
|
||||
)
|
||||
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
# Prioritize summaries as they contain synthesized contextual information
|
||||
answer_list = []
|
||||
|
||||
|
||||
# For hybrid search, use reranked_results
|
||||
if search_type == "hybrid":
|
||||
reranked_results = answer.get('reranked_results', {})
|
||||
|
||||
|
||||
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in reranked_results:
|
||||
category_results = reranked_results[category]
|
||||
@@ -242,7 +242,7 @@ class SearchService:
|
||||
# For keyword or embedding search, results are directly in answer dict
|
||||
# Apply same priority order
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in answer:
|
||||
category_results = answer[category]
|
||||
@@ -261,7 +261,7 @@ class SearchService:
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
answer_list.extend(cleaned_stmts)
|
||||
|
||||
|
||||
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||
content_list = []
|
||||
for ans in answer_list:
|
||||
@@ -269,19 +269,18 @@ class SearchService:
|
||||
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||
|
||||
|
||||
# Filter out empty strings and join with newlines
|
||||
clean_content = '\n'.join([c for c in content_list if c])
|
||||
|
||||
|
||||
# Log first 200 chars
|
||||
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
|
||||
|
||||
|
||||
# Return raw results if requested
|
||||
if return_raw_results:
|
||||
return clean_content, cleaned_query, answer
|
||||
else:
|
||||
return clean_content, cleaned_query, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Annotated, TypedDict
|
||||
@@ -52,6 +51,7 @@ class ReadState(TypedDict):
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve: dict
|
||||
perceptual_data: dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
|
||||
@@ -152,6 +152,24 @@ async def write(
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
|
||||
# Neo4j 写入前:清洗用户/AI助手实体之间的别名交叉污染
|
||||
# 从 Neo4j 查询已有的 AI 助手别名,与本轮实体中的 AI 助手别名合并,
|
||||
# 确保用户实体的 aliases 不包含 AI 助手的名字
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
clean_cross_role_aliases,
|
||||
fetch_neo4j_assistant_aliases,
|
||||
)
|
||||
neo4j_assistant_aliases = set()
|
||||
if all_entity_nodes:
|
||||
_eu_id = all_entity_nodes[0].end_user_id
|
||||
if _eu_id:
|
||||
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, _eu_id)
|
||||
clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases)
|
||||
logger.info(f"Neo4j 写入前别名清洗完成,AI助手别名排除集大小: {len(neo4j_assistant_aliases)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Neo4j 写入前别名清洗失败(不影响主流程): {e}")
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 秒
|
||||
|
||||
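A generic sketch of the deadlock retry mechanism referenced above (illustration, not code from this commit); write_once and the string match on the error message are hypothetical stand-ins for the actual Neo4j write call and its deadlock detection.

import asyncio

async def write_with_retry(write_once, max_retries: int = 3, retry_delay: float = 1.0):
    """Re-run the write when a transient deadlock is reported, up to max_retries times."""
    for attempt in range(1, max_retries + 1):
        try:
            return await write_once()
        except Exception as e:                           # in practice, match the driver's deadlock error
            if "deadlock" not in str(e).lower() or attempt == max_retries:
                raise
            await asyncio.sleep(retry_delay * attempt)   # simple linear backoff between attempts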
@@ -43,6 +43,7 @@ load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
def _parse_datetime(value: Any) -> Optional[datetime]:
|
||||
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
|
||||
if value is None:
|
||||
@@ -75,7 +76,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
if score_field == "activation_value" and score is None:
|
||||
scores.append(None) # 保持 None,稍后特殊处理
|
||||
continue
|
||||
|
||||
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
scores.append(float(score))
|
||||
else:
|
||||
@@ -83,10 +84,10 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
|
||||
if not scores:
|
||||
return results
|
||||
|
||||
|
||||
# 过滤掉 None 值,只对有效分数进行归一化
|
||||
valid_scores = [s for s in scores if s is not None]
|
||||
|
||||
|
||||
if not valid_scores:
|
||||
# 所有分数都是 None,不进行归一化
|
||||
for item in results:
|
||||
@@ -94,7 +95,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
item[f"normalized_{score_field}"] = None
|
||||
return results
|
||||
|
||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||
for item, score in zip(results, scores):
|
||||
if score_field in item or score_field == "activation_value":
|
||||
if score is None:
|
||||
@@ -132,7 +133,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
return results
|
||||
|
||||
|
||||
|
||||
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Remove duplicate items from search results based on content.
|
||||
@@ -150,52 +150,53 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
seen_ids = set()
|
||||
seen_content = set()
|
||||
deduplicated = []
|
||||
|
||||
|
||||
for item in items:
|
||||
# Try multiple ID fields to identify unique items
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
|
||||
|
||||
# Extract content from various possible fields
|
||||
content = (
|
||||
item.get("text") or
|
||||
item.get("content") or
|
||||
item.get("statement") or
|
||||
item.get("name") or
|
||||
""
|
||||
item.get("text") or
|
||||
item.get("content") or
|
||||
item.get("statement") or
|
||||
item.get("name") or
|
||||
""
|
||||
)
|
||||
|
||||
|
||||
# Normalize content for comparison (strip whitespace and lowercase)
|
||||
normalized_content = str(content).strip().lower() if content else ""
|
||||
|
||||
|
||||
# Check if we've seen this ID or content before
|
||||
is_duplicate = False
|
||||
|
||||
|
||||
if item_id and item_id in seen_ids:
|
||||
is_duplicate = True
|
||||
elif normalized_content and normalized_content in seen_content:
|
||||
# Only check content duplication if content is not empty
|
||||
is_duplicate = True
|
||||
|
||||
|
||||
if not is_duplicate:
|
||||
# Mark as seen
|
||||
if item_id:
|
||||
seen_ids.add(item_id)
|
||||
if normalized_content: # Only track non-empty content
|
||||
seen_content.add(normalized_content)
|
||||
|
||||
|
||||
deduplicated.append(item)
|
||||
|
||||
|
||||
return deduplicated
|
||||
|
||||
|
||||
def rerank_with_activation(
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
activation_boost_factor: float = 0.8,
|
||||
now: datetime | None = None,
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
activation_boost_factor: float = 0.8,
|
||||
now: datetime | None = None,
|
||||
content_score_threshold: float = 0.5,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
两阶段排序:先按内容相关性筛选,再按激活值排序。
|
||||
@@ -222,6 +223,8 @@ def rerank_with_activation(
|
||||
forgetting_config: 遗忘引擎配置(当前未使用)
|
||||
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
|
||||
now: 当前时间(用于遗忘计算)
|
||||
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score),
|
||||
低于此阈值的结果会被过滤。默认 0.5。
|
||||
|
||||
返回:
|
||||
带评分元数据的重排序结果,按 final_score 排序
|
||||
@@ -229,26 +232,26 @@ def rerank_with_activation(
|
||||
# 验证权重范围
|
||||
if not (0 <= alpha <= 1):
|
||||
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
|
||||
|
||||
|
||||
# 初始化遗忘引擎(如果需要)
|
||||
engine = None
|
||||
if forgetting_config:
|
||||
engine = ForgettingEngine(forgetting_config)
|
||||
now_dt = now or datetime.now()
|
||||
|
||||
|
||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
|
||||
# 步骤 1: 归一化分数
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
|
||||
# 步骤 2: 按 ID 合并结果(去重)
|
||||
combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
# 添加关键词结果
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
@@ -257,7 +260,7 @@ def rerank_with_activation(
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
combined_items[item_id]["embedding_score"] = 0 # 默认值
|
||||
|
||||
|
||||
# 添加或更新向量嵌入结果
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
@@ -271,18 +274,18 @@ def rerank_with_activation(
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0 # 默认值
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
|
||||
# 步骤 3: 归一化激活度分数
|
||||
# 为所有项准备激活度值列表
|
||||
items_list = list(combined_items.values())
|
||||
items_list = normalize_scores(items_list, "activation_value")
|
||||
|
||||
|
||||
# 更新 combined_items 中的归一化激活度分数
|
||||
for item in items_list:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id and item_id in combined_items:
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
|
||||
|
||||
|
||||
# 步骤 4: 计算基础分数和最终分数
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||
@@ -290,45 +293,45 @@ def rerank_with_activation(
|
||||
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
|
||||
raw_act_norm = item.get("normalized_activation_value")
|
||||
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
|
||||
|
||||
|
||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||
base_score = content_score # 第一阶段用内容分数
|
||||
|
||||
|
||||
# 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序)
|
||||
item["activation_score"] = act_norm # 可能为 None
|
||||
item["content_score"] = content_score
|
||||
item["base_score"] = base_score
|
||||
|
||||
|
||||
# 步骤 5: 应用遗忘曲线(可选)
|
||||
if engine:
|
||||
# 计算受激活度影响的记忆强度
|
||||
importance = float(item.get("importance_score", 0.5) or 0.5)
|
||||
|
||||
|
||||
# 获取 activation_value
|
||||
activation_val = item.get("activation_value")
|
||||
|
||||
|
||||
# 只对有激活值的节点应用遗忘曲线
|
||||
if activation_val is not None and isinstance(activation_val, (int, float)):
|
||||
activation_val = float(activation_val)
|
||||
|
||||
|
||||
# 计算记忆强度:importance_score × (1 + activation_value × boost_factor)
|
||||
memory_strength = importance * (1 + activation_val * activation_boost_factor)
|
||||
|
||||
|
||||
# 计算经过的时间(天数)
|
||||
dt = _parse_datetime(item.get("created_at"))
|
||||
if dt is None:
|
||||
time_elapsed_days = 0.0
|
||||
else:
|
||||
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
|
||||
# 获取遗忘权重
|
||||
forgetting_weight = engine.calculate_weight(
|
||||
time_elapsed=time_elapsed_days,
|
||||
memory_strength=memory_strength
|
||||
)
|
||||
|
||||
|
||||
# 应用到基础分数
|
||||
item["forgetting_weight"] = forgetting_weight
|
||||
item["final_score"] = base_score * forgetting_weight
|
||||
@@ -338,7 +341,7 @@ def rerank_with_activation(
|
||||
else:
|
||||
# 不使用遗忘曲线
|
||||
item["final_score"] = base_score
|
||||
|
||||
|
||||
# 步骤 6: 两阶段排序和限制
|
||||
# 第一阶段:按内容相关性(base_score)排序,取 Top-K
|
||||
first_stage_limit = limit * 3 # 可配置,取3倍候选
|
||||
@@ -347,11 +350,11 @@ def rerank_with_activation(
|
||||
key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序
|
||||
reverse=True
|
||||
)[:first_stage_limit]
|
||||
|
||||
|
||||
# 第二阶段:分离有激活值和无激活值的节点
|
||||
items_with_activation = []
|
||||
items_without_activation = []
|
||||
|
||||
|
||||
for item in first_stage_sorted:
|
||||
activation_score = item.get("activation_score")
|
||||
# 检查是否有有效的激活值(不是 None)
|
||||
@@ -359,14 +362,14 @@ def rerank_with_activation(
|
||||
items_with_activation.append(item)
|
||||
else:
|
||||
items_without_activation.append(item)
|
||||
|
||||
|
||||
# 优先按激活值排序有激活值的节点
|
||||
sorted_with_activation = sorted(
|
||||
items_with_activation,
|
||||
key=lambda x: float(x.get("activation_score", 0) or 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
|
||||
# 如果有激活值的节点不足 limit,用无激活值的节点补充
|
||||
if len(sorted_with_activation) < limit:
|
||||
needed = limit - len(sorted_with_activation)
|
||||
@@ -374,7 +377,7 @@ def rerank_with_activation(
|
||||
sorted_items = sorted_with_activation + items_without_activation[:needed]
|
||||
else:
|
||||
sorted_items = sorted_with_activation[:limit]
|
||||
|
||||
|
||||
# 两阶段排序完成,更新 final_score 以反映实际排序依据
|
||||
# Stage 1: 按 content_score 筛选候选(已完成)
|
||||
# Stage 2: 按 activation_score 排序(已完成)
|
||||
@@ -390,16 +393,29 @@ def rerank_with_activation(
|
||||
else:
|
||||
# 无激活值:使用内容相关性分数
|
||||
item["final_score"] = item.get("base_score", 0)
|
||||
|
||||
# 最终去重确保没有重复项
|
||||
|
||||
if content_score_threshold > 0:
|
||||
before_count = len(sorted_items)
|
||||
sorted_items = [
|
||||
item for item in sorted_items
|
||||
if float(item.get("content_score", 0) or 0) >= content_score_threshold
|
||||
]
|
||||
filtered_count = before_count - len(sorted_items)
|
||||
if filtered_count > 0:
|
||||
logger.info(
|
||||
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
|
||||
f"items below content_score_threshold={content_score_threshold}"
|
||||
)
|
||||
|
||||
sorted_items = _deduplicate_results(sorted_items)
|
||||
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
|
||||
return reranked
|
||||
|
||||
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
|
||||
log_file: str = None):
|
||||
"""Log search query information using the logger.
|
||||
|
||||
Args:
|
||||
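A self-contained sketch of the two-stage ordering plus the new content_score_threshold filter described above (illustration, not code from this commit); field names follow the diff, and the toy items are invented.

def two_stage_rerank(items, limit=10, content_score_threshold=0.5):
    # Stage 1: candidate selection by content relevance (BM25/embedding fusion score).
    candidates = sorted(items, key=lambda x: x.get("content_score", 0), reverse=True)[:limit * 3]
    # Stage 2: activation-aware ordering; items without an activation value only fill remaining slots.
    with_act = [i for i in candidates if i.get("activation_score") is not None]
    without_act = [i for i in candidates if i.get("activation_score") is None]
    ordered = sorted(with_act, key=lambda x: x["activation_score"], reverse=True)
    if len(ordered) < limit:
        ordered += without_act[:limit - len(ordered)]
    else:
        ordered = ordered[:limit]
    # Final content-relevance floor, matching the content_score_threshold added in this change.
    return [i for i in ordered if i.get("content_score", 0) >= content_score_threshold]

print(two_stage_rerank([
    {"id": "s1", "content_score": 0.9, "activation_score": 0.2},
    {"id": "s2", "content_score": 0.7, "activation_score": 0.8},
    {"id": "s3", "content_score": 0.3, "activation_score": None},
], limit=2))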
@@ -412,7 +428,7 @@ def log_search_query(query_text: str, search_type: str, end_user_id: str | None,
|
||||
"""
|
||||
# Ensure the query text is plain and clean before logging
|
||||
cleaned_query = extract_plain_query(query_text)
|
||||
|
||||
|
||||
# Log using the standard logger
|
||||
logger.info(
|
||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||
@@ -439,8 +455,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
|
||||
|
||||
|
||||
def apply_reranker_placeholder(
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Placeholder for a cross-encoder reranker.
|
||||
@@ -483,7 +499,7 @@ def apply_reranker_placeholder(
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Apply LLM-based reranking to search results.
|
||||
|
||||
|
||||
# Args:
|
||||
# results: Search results organized by category
|
||||
# query_text: Original search query
|
||||
@@ -491,7 +507,7 @@ def apply_reranker_placeholder(
|
||||
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
# top_k: Maximum number of items to rerank per category
|
||||
# batch_size: Number of items to process concurrently
|
||||
|
||||
|
||||
# Returns:
|
||||
# Reranked results with final_score and reranker_model fields
|
||||
# """
|
||||
@@ -501,18 +517,18 @@ def apply_reranker_placeholder(
|
||||
# # except Exception as e:
|
||||
# # logger.debug(f"Failed to load reranker config: {e}")
|
||||
# # rc = {}
|
||||
|
||||
|
||||
# # Check if reranking is enabled
|
||||
# enabled = rc.get("enabled", False)
|
||||
# if not enabled:
|
||||
# logger.debug("LLM reranking is disabled in configuration")
|
||||
# return results
|
||||
|
||||
|
||||
# # Load configuration parameters with defaults
|
||||
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
# top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
|
||||
|
||||
# # Initialize reranker client if not provided
|
||||
# if reranker_client is None:
|
||||
# try:
|
||||
@@ -520,10 +536,10 @@ def apply_reranker_placeholder(
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
# return results
|
||||
|
||||
|
||||
# # Get model name for metadata
|
||||
# model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
|
||||
|
||||
# # Process each category
|
||||
# reranked_results = {}
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
@@ -531,38 +547,38 @@ def apply_reranker_placeholder(
|
||||
# if not items:
|
||||
# reranked_results[category] = []
|
||||
# continue
|
||||
|
||||
|
||||
# # Select top K items by combined_score for reranking
|
||||
# sorted_items = sorted(
|
||||
# items,
|
||||
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
# reverse=True
|
||||
# )
|
||||
|
||||
|
||||
# top_items = sorted_items[:top_k]
|
||||
# remaining_items = sorted_items[top_k:]
|
||||
|
||||
|
||||
# # Extract text content from each item
|
||||
# def extract_text(item: Dict[str, Any]) -> str:
|
||||
# """Extract text content from a result item."""
|
||||
# # Try different text fields based on category
|
||||
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
# return str(text).strip()
|
||||
|
||||
|
||||
# # Batch items for concurrent processing
|
||||
# batches = []
|
||||
# for i in range(0, len(top_items), batch_size):
|
||||
# batch = top_items[i:i + batch_size]
|
||||
# batches.append(batch)
|
||||
|
||||
|
||||
# # Process batches concurrently
|
||||
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
# """Process a batch of items with LLM relevance scoring."""
|
||||
# scored_batch = []
|
||||
|
||||
|
||||
# for item in batch:
|
||||
# item_text = extract_text(item)
|
||||
|
||||
|
||||
# # Skip items with no text
|
||||
# if not item_text:
|
||||
# item_copy = item.copy()
|
||||
@@ -572,7 +588,7 @@ def apply_reranker_placeholder(
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
# continue
|
||||
|
||||
|
||||
# # Create relevance scoring prompt
|
||||
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
|
||||
@@ -585,15 +601,15 @@ def apply_reranker_placeholder(
|
||||
# - 1.0 means perfectly relevant
|
||||
|
||||
# Relevance score:"""
|
||||
|
||||
|
||||
# # Send request to LLM
|
||||
# try:
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
# response = await reranker_client.chat(messages)
|
||||
|
||||
|
||||
# # Parse LLM response to extract relevance score
|
||||
# response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
|
||||
|
||||
# # Try to extract a float from the response
|
||||
# try:
|
||||
# # Remove any non-numeric characters except decimal point
|
||||
@@ -608,11 +624,11 @@ def apply_reranker_placeholder(
|
||||
# except (ValueError, AttributeError) as e:
|
||||
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
# llm_score = None
|
||||
|
||||
|
||||
# # Calculate final score
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
|
||||
|
||||
# if llm_score is not None:
|
||||
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
# item_copy["llm_relevance_score"] = llm_score
|
||||
@@ -620,7 +636,7 @@ def apply_reranker_placeholder(
|
||||
# # Use combined_score as fallback
|
||||
# final_score = combined_score
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
|
||||
|
||||
# item_copy["final_score"] = final_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
@@ -632,14 +648,14 @@ def apply_reranker_placeholder(
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
|
||||
|
||||
# return scored_batch
|
||||
|
||||
|
||||
# # Process all batches concurrently
|
||||
# try:
|
||||
# batch_tasks = [process_batch(batch) for batch in batches]
|
||||
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# # Merge batch results
|
||||
# scored_items = []
|
||||
# for result in batch_results:
|
||||
@@ -647,7 +663,7 @@ def apply_reranker_placeholder(
|
||||
# logger.warning(f"Batch processing failed: {result}")
|
||||
# continue
|
||||
# scored_items.extend(result)
|
||||
|
||||
|
||||
# # Add remaining items (not in top K) with their combined_score as final_score
|
||||
# for item in remaining_items:
|
||||
# item_copy = item.copy()
|
||||
@@ -655,11 +671,11 @@ def apply_reranker_placeholder(
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_items.append(item_copy)
|
||||
|
||||
|
||||
# # Sort all items by final_score in descending order
|
||||
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
# reranked_results[category] = scored_items
|
||||
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# # Return original items with combined_score as final_score
|
||||
@@ -668,22 +684,22 @@ def apply_reranker_placeholder(
|
||||
# item["final_score"] = combined_score
|
||||
# item["reranker_model"] = model_name
|
||||
# reranked_results[category] = items
|
||||
|
||||
|
||||
# return reranked_results
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -699,7 +715,7 @@ async def run_hybrid_search(
|
||||
|
||||
# Clean and normalize the incoming query before use/logging
|
||||
query_text = extract_plain_query(query_text)
|
||||
|
||||
|
||||
# Validate query is not empty after cleaning
|
||||
if not query_text or not query_text.strip():
|
||||
logger.warning("Empty query after cleaning, returning empty results")
|
||||
@@ -716,7 +732,7 @@ async def run_hybrid_search(
|
||||
"error": "Empty query"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Log the search query
|
||||
log_search_query(query_text, search_type, end_user_id, limit, include)
|
||||
|
||||
@@ -747,7 +763,7 @@ async def run_hybrid_search(
|
||||
# Embedding-based search
|
||||
logger.info("[PERF] Starting embedding search...")
|
||||
embedding_start = time.time()
|
||||
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
@@ -768,7 +784,7 @@ async def run_hybrid_search(
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
connector=connector,
|
||||
@@ -788,7 +804,7 @@ async def run_hybrid_search(
|
||||
|
||||
if keyword_task:
|
||||
keyword_results = await keyword_task
|
||||
keyword_latency = time.time() - keyword_start
|
||||
keyword_latency = time.time() - search_start_time
|
||||
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
|
||||
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
|
||||
if search_type == "keyword":
|
||||
@@ -798,7 +814,7 @@ async def run_hybrid_search(
|
||||
|
||||
if embedding_task:
|
||||
embedding_results = await embedding_task
|
||||
embedding_latency = time.time() - embedding_start
|
||||
embedding_latency = time.time() - search_start_time
|
||||
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
|
||||
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
|
||||
if search_type == "embedding":
|
||||
@@ -810,7 +826,8 @@ async def run_hybrid_search(
|
||||
if search_type == "hybrid":
|
||||
results["combined_summary"] = {
|
||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_embedding_results": sum(
|
||||
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
@@ -818,7 +835,7 @@ async def run_hybrid_search(
|
||||
# Apply two-stage reranking with ACTR activation calculation
|
||||
rerank_start = time.time()
|
||||
logger.info("[PERF] Using two-stage reranking with ACTR activation")
|
||||
|
||||
|
||||
# 加载遗忘引擎配置
|
||||
config_start = time.time()
|
||||
try:
|
||||
@@ -829,7 +846,7 @@ async def run_hybrid_search(
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
config_time = time.time() - config_start
|
||||
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
|
||||
|
||||
|
||||
# 统一使用激活度重排序(两阶段:检索 + ACTR计算)
|
||||
rerank_compute_start = time.time()
|
||||
reranked_results = rerank_with_activation(
|
||||
@@ -842,14 +859,14 @@ async def run_hybrid_search(
|
||||
)
|
||||
rerank_compute_time = time.time() - rerank_compute_start
|
||||
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
|
||||
|
||||
|
||||
rerank_latency = time.time() - rerank_start
|
||||
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
|
||||
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
|
||||
|
||||
|
||||
# Optional: apply reranker placeholder if enabled via config
|
||||
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
|
||||
|
||||
|
||||
# Apply LLM reranking if enabled
|
||||
llm_rerank_applied = False
|
||||
# if use_llm_rerank:
|
||||
@@ -862,11 +879,12 @@ async def run_hybrid_search(
|
||||
# logger.info("LLM reranking applied successfully")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
|
||||
|
||||
results["reranked_results"] = reranked_results
|
||||
results["combined_summary"] = {
|
||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_embedding_results": sum(
|
||||
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat(),
|
||||
@@ -879,13 +897,13 @@ async def run_hybrid_search(
|
||||
# Calculate total latency
|
||||
total_latency = time.time() - search_start_time
|
||||
latency_metrics["total_latency"] = round(total_latency, 4)
|
||||
|
||||
|
||||
# Add latency metrics to results
|
||||
if "combined_summary" in results:
|
||||
results["combined_summary"]["latency_metrics"] = latency_metrics
|
||||
else:
|
||||
results["latency_metrics"] = latency_metrics
|
||||
|
||||
|
||||
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
|
||||
@@ -908,8 +926,10 @@ async def run_hybrid_search(
|
||||
# Log search completion with result count
|
||||
if search_type == "hybrid":
|
||||
result_counts = {
|
||||
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
|
||||
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
|
||||
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
|
||||
keyword_results.items()},
|
||||
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
|
||||
embedding_results.items()}
|
||||
}
|
||||
else:
|
||||
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
|
||||
@@ -927,12 +947,12 @@ async def run_hybrid_search(
async def search_by_temporal(
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
):
"""
Temporal search across Statements.
@@ -968,13 +988,13 @@ async def search_by_temporal(
async def search_by_keyword_temporal(
query_text: str,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
):
"""
Temporal keyword search across Statements.
@@ -1011,9 +1031,9 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id(
chunk_id: str,
end_user_id: Optional[str] = "test",
limit: int = 1,
):
"""
Search for Chunks by chunk_id.
@@ -1026,4 +1046,3 @@ async def search_chunk_by_chunk_id(
limit=limit
)
return {"chunks": chunks}
@@ -4,6 +4,7 @@
import asyncio
import difflib # 提供字符串相似度计算工具
import importlib
import logging
import os
import re
from datetime import datetime
@@ -16,6 +17,8 @@ from app.core.memory.models.graph_models import (
)
from app.core.memory.models.variate_config import DedupConfig
logger = logging.getLogger(__name__)
# 模块级类型统一工具函数
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
@@ -198,6 +201,161 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
except Exception:
pass
# 用户和AI助手的占位名称集合(用于名称标准化)
_USER_PLACEHOLDER_NAMES = {"用户", "我", "user", "i"}
_ASSISTANT_PLACEHOLDER_NAMES = {"ai助手", "助手", "人工智能助手", "智能助手", "智能体", "ai assistant", "assistant"}

# 标准化后的规范名称和类型
_CANONICAL_USER_NAME = "用户"
_CANONICAL_USER_TYPE = "用户"
_CANONICAL_ASSISTANT_NAME = "AI助手"
_CANONICAL_ASSISTANT_TYPE = "Agent"

# 用户和AI助手的所有可能名称(用于判断实体是否为特殊角色实体)
_ALL_USER_NAMES = _USER_PLACEHOLDER_NAMES
_ALL_ASSISTANT_NAMES = _ASSISTANT_PLACEHOLDER_NAMES


def _is_user_entity(ent: ExtractedEntityNode) -> bool:
    """判断实体是否为用户实体(name 或 entity_type 匹配)"""
    name = (getattr(ent, "name", "") or "").strip().lower()
    etype = (getattr(ent, "entity_type", "") or "").strip()
    return name in _ALL_USER_NAMES or etype == _CANONICAL_USER_TYPE


def _is_assistant_entity(ent: ExtractedEntityNode) -> bool:
    """判断实体是否为AI助手实体(name 或 entity_type 匹配)"""
    name = (getattr(ent, "name", "") or "").strip().lower()
    etype = (getattr(ent, "entity_type", "") or "").strip()
    return name in _ALL_ASSISTANT_NAMES or etype == _CANONICAL_ASSISTANT_TYPE


def _would_merge_cross_role(a: ExtractedEntityNode, b: ExtractedEntityNode) -> bool:
    """判断两个实体的合并是否会跨越用户/AI助手角色边界。

    用户实体和AI助手实体永远不应该被合并在一起。
    如果一方是用户实体、另一方是AI助手实体,返回 True(阻止合并)。
    """
    return (
        (_is_user_entity(a) and _is_assistant_entity(b))
        or (_is_assistant_entity(a) and _is_user_entity(b))
    )
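# --- Editor's illustrative sketch (not part of the committed diff) ---
# A minimal example of how the cross-role guard is meant to behave, assuming
# ExtractedEntityNode exposes name / entity_type / aliases attributes. The
# SimpleNamespace stand-ins are placeholders, not the real model class.
from types import SimpleNamespace

_user = SimpleNamespace(name="用户", entity_type="用户", aliases=["陈思远"])
_assistant = SimpleNamespace(name="AI助手", entity_type="Agent", aliases=["远仔"])
_named_person = SimpleNamespace(name="陈思远", entity_type="Person", aliases=[])

# User and assistant entities must never merge, regardless of alias overlap.
assert _would_merge_cross_role(_user, _assistant) is True
# Two user-side entities are still allowed to merge.
assert _would_merge_cross_role(_user, _named_person) is False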
def _normalize_special_entity_names(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
) -> None:
|
||||
"""标准化用户和AI助手实体的名称和类型。
|
||||
|
||||
多轮对话中,LLM 对同一角色可能使用不同的名称变体(如"用户"/"我"/"User",
|
||||
"AI助手"/"助手"/"Assistant"),导致精确匹配无法合并。
|
||||
此函数在去重前将这些变体统一为规范名称,并强制绑定 entity_type,确保:
|
||||
- name="用户" 的实体 entity_type 一定为 "用户"
|
||||
- name="AI助手" 的实体 entity_type 一定为 "Agent"
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表(原地修改)
|
||||
"""
|
||||
for ent in entity_nodes:
|
||||
name = (getattr(ent, "name", "") or "").strip()
|
||||
name_lower = name.lower()
|
||||
|
||||
if name_lower in _USER_PLACEHOLDER_NAMES:
|
||||
ent.name = _CANONICAL_USER_NAME
|
||||
ent.entity_type = _CANONICAL_USER_TYPE
|
||||
elif name_lower in _ASSISTANT_PLACEHOLDER_NAMES:
|
||||
ent.name = _CANONICAL_ASSISTANT_NAME
|
||||
ent.entity_type = _CANONICAL_ASSISTANT_TYPE
|
||||
|
||||
# 第二步:清洗用户/AI助手之间的别名交叉污染(复用 clean_cross_role_aliases)
|
||||
clean_cross_role_aliases(entity_nodes)
|
||||
|
||||
|
||||
async def fetch_neo4j_assistant_aliases(neo4j_connector, end_user_id: str) -> set:
|
||||
"""从 Neo4j 查询 AI 助手实体的所有别名(小写归一化)。
|
||||
|
||||
这是助手别名查询的唯一入口,供 write_tools 和 extraction_orchestrator 共用,
|
||||
避免多处维护相同的 Cypher 和名称列表。
|
||||
|
||||
Args:
|
||||
neo4j_connector: Neo4j 连接器实例(需提供 execute_query 方法)
|
||||
end_user_id: 终端用户 ID
|
||||
|
||||
Returns:
|
||||
小写归一化后的助手别名集合
|
||||
"""
|
||||
# 查询名称列表:规范名称 + 常见变体(与 _normalize_special_entity_names 标准化后一致)
|
||||
query_names = [_CANONICAL_ASSISTANT_NAME, *_ASSISTANT_PLACEHOLDER_NAMES]
|
||||
# 去重保序
|
||||
query_names = list(dict.fromkeys(query_names))
|
||||
|
||||
cypher = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND e.name IN $names
|
||||
RETURN e.aliases AS aliases
|
||||
"""
|
||||
try:
|
||||
result = await neo4j_connector.execute_query(
|
||||
cypher, end_user_id=end_user_id, names=query_names
|
||||
)
|
||||
assistant_aliases: set = set()
|
||||
for record in (result or []):
|
||||
for alias in (record.get("aliases") or []):
|
||||
assistant_aliases.add(alias.strip().lower())
|
||||
if assistant_aliases:
|
||||
logger.debug(f"Neo4j 中 AI 助手别名: {assistant_aliases}")
|
||||
return assistant_aliases
|
||||
except Exception as e:
|
||||
logger.warning(f"查询 Neo4j AI 助手别名失败: {e}")
|
||||
return set()
|
||||
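# --- Editor's illustrative sketch (not part of the committed diff) ---
# How a caller might use fetch_neo4j_assistant_aliases to keep assistant aliases
# off the user entity. The connector and end_user_id arguments are hypothetical
# placeholders; the filtering mirrors what the orchestrator does later on.
async def _example_filter_user_aliases(neo4j_connector, end_user_id: str, user_aliases: list[str]) -> list[str]:
    # Aliases already attributed to the AI assistant should never stay on the user.
    assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, end_user_id)
    return [a for a in user_aliases if a.strip().lower() not in assistant_aliases]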
|
||||
|
||||
def clean_cross_role_aliases(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
external_assistant_aliases: set = None,
|
||||
) -> None:
|
||||
"""清洗用户实体和AI助手实体之间的别名交叉污染。
|
||||
|
||||
在 Neo4j 写入前调用,确保:
|
||||
- 用户实体的 aliases 不包含 AI 助手的别名
|
||||
- AI 助手实体的 aliases 不包含用户的别名
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表(原地修改)
|
||||
external_assistant_aliases: 外部传入的 AI 助手别名集合(如从 Neo4j 查询),
|
||||
与本轮实体中的 AI 助手别名合并使用
|
||||
"""
|
||||
# 收集本轮 AI 助手实体的所有别名
|
||||
assistant_aliases = set(external_assistant_aliases or set())
|
||||
user_aliases = set()
|
||||
|
||||
for ent in entity_nodes:
|
||||
if _is_assistant_entity(ent):
|
||||
for alias in (getattr(ent, "aliases", []) or []):
|
||||
assistant_aliases.add(alias.strip().lower())
|
||||
elif _is_user_entity(ent):
|
||||
for alias in (getattr(ent, "aliases", []) or []):
|
||||
user_aliases.add(alias.strip().lower())
|
||||
|
||||
# 从用户实体的 aliases 中移除 AI 助手别名
|
||||
if assistant_aliases:
|
||||
for ent in entity_nodes:
|
||||
if _is_user_entity(ent):
|
||||
original = getattr(ent, "aliases", []) or []
|
||||
cleaned = [a for a in original if a.strip().lower() not in assistant_aliases]
|
||||
if len(cleaned) < len(original):
|
||||
ent.aliases = cleaned
|
||||
|
||||
# 从 AI 助手实体的 aliases 中移除用户别名
|
||||
if user_aliases:
|
||||
for ent in entity_nodes:
|
||||
if _is_assistant_entity(ent):
|
||||
original = getattr(ent, "aliases", []) or []
|
||||
cleaned = [a for a in original if a.strip().lower() not in user_aliases]
|
||||
if len(cleaned) < len(original):
|
||||
ent.aliases = cleaned
|
||||
|
||||
|
||||
def accurate_match(
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||
@@ -261,6 +419,10 @@ def accurate_match(
|
||||
canonical = alias_index.get((ent_uid, ent_name))
|
||||
# 确保不是自身
|
||||
if canonical is not None and canonical.id != ent.id:
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(canonical, ent):
|
||||
i += 1
|
||||
continue
|
||||
_merge_attribute(canonical, ent)
|
||||
id_redirect[ent.id] = canonical.id
|
||||
for k, v in list(id_redirect.items()):
|
||||
@@ -704,6 +866,11 @@ def fuzzy_match(
|
||||
# 条件A(快速通道):alias_match_merge = True
|
||||
# 条件B(标准通道):s_name ≥ tn AND s_type ≥ type_threshold AND overall ≥ tover
|
||||
if alias_match_merge or (s_name >= tn and s_type >= type_threshold and overall >= tover):
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(a, b):
|
||||
j += 1
|
||||
continue
|
||||
|
||||
# ========== 第六步:执行实体合并 ==========
|
||||
|
||||
# 6.1 合并别名
|
||||
@@ -813,6 +980,12 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
b = entity_by_id.get(losing_id)
|
||||
if not a or not b: # 若不存在 a 或 b,可能已在精确或模糊阶段合并,在之前阶段合并之后,不会再处理但是处于审计的目的会记录
|
||||
continue
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(a, b):
|
||||
llm_records.append(
|
||||
f"[LLM阻断] 跨角色合并被阻止: {a.id} ({a.name}) 与 {b.id} ({b.name})"
|
||||
)
|
||||
continue
|
||||
_merge_attribute(a, b)
|
||||
# ID 重定向
|
||||
try:
|
||||
@@ -934,6 +1107,9 @@ async def deduplicate_entities_and_edges(
|
||||
返回:去重后的实体、语句→实体边、实体↔实体边。
|
||||
"""
|
||||
local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化,保留为了之后对于LLM决策追溯
|
||||
# 0) 标准化用户和AI助手实体名称(确保多轮对话中的变体名称统一)
|
||||
_normalize_special_entity_names(entity_nodes)
|
||||
|
||||
# 1) 精确匹配
|
||||
deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
clean_cross_role_aliases,
)
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
second_layer_dedup_and_merge_with_neo4j,
@@ -100,6 +101,10 @@ async def dedup_layers_and_merge_and_return(
except Exception as e:
print(f"Second-layer dedup failed: {e}")
# 第二层去重后,清洗用户/AI助手之间的别名交叉污染
# 第二层从 Neo4j 合并了旧实体,可能带入历史脏数据
clean_cross_role_aliases(fused_entity_nodes)
return (
dialogue_nodes,
chunk_nodes,
@@ -44,6 +44,10 @@ from app.core.memory.models.variate_config import (
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return,
)
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
_USER_PLACEHOLDER_NAMES,
fetch_neo4j_assistant_aliases,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation,
generate_entity_embeddings_from_triplets,
@@ -1341,14 +1345,20 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list: List[DialogData]
|
||||
) -> None:
|
||||
"""
|
||||
从 Neo4j 读取用户实体的最终 aliases,同步到 end_user 和 end_user_info 表
|
||||
将本轮提取的用户别名同步到 end_user 和 end_user_info 表。
|
||||
|
||||
注意:
|
||||
1. other_name 使用本次对话提取的第一个别名(保持时间顺序)
|
||||
2. aliases 从 Neo4j 读取(保持完整性)
|
||||
注意:此方法在 Neo4j 写入之前调用,因此不能依赖 Neo4j 作为别名的权威数据源。
|
||||
改为直接使用内存中去重后的 entity_nodes 的 aliases,与 PgSQL 已有的 aliases 合并。
|
||||
|
||||
策略:
|
||||
1. 从内存中的 entity_nodes 提取本轮用户别名(current_aliases)
|
||||
2. 从去重后的 entity_nodes 中提取完整别名(含 Neo4j 二层去重合并的历史别名)
|
||||
3. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases)
|
||||
4. 合并 db_aliases + deduped_aliases + current_aliases,去重保序
|
||||
5. 写回 PgSQL
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表
|
||||
entity_nodes: 去重后的实体节点列表(内存中,含二层去重合并结果)
|
||||
dialog_data_list: 对话数据列表
|
||||
"""
|
||||
try:
|
||||
@@ -1361,23 +1371,40 @@ class ExtractionOrchestrator:
|
||||
logger.warning("end_user_id 为空,跳过用户别名同步")
|
||||
return
|
||||
|
||||
# 1. 提取本次对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
||||
current_aliases = self._extract_current_aliases(entity_nodes)
|
||||
# 1. 提取本轮对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
||||
current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list)
|
||||
|
||||
# 2. 从 Neo4j 获取完整 aliases(权威数据源)
|
||||
neo4j_aliases = await self._fetch_neo4j_user_aliases(end_user_id)
|
||||
# 1.5 从去重后的 entity_nodes 中提取完整别名
|
||||
# 二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 中,
|
||||
# 这里提取出来确保 PgSQL 与 Neo4j 的别名保持同步
|
||||
deduped_aliases = self._extract_deduped_entity_aliases(entity_nodes)
|
||||
|
||||
if not neo4j_aliases:
|
||||
# Neo4j 中没有别名,使用本次对话提取的别名
|
||||
neo4j_aliases = current_aliases
|
||||
if not neo4j_aliases:
|
||||
logger.debug(f"aliases 为空,跳过同步: end_user_id={end_user_id}")
|
||||
return
|
||||
# 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源
|
||||
# (防止 LLM 未提取出 AI 助手实体时,AI 别名泄漏到用户别名中)
|
||||
neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id)
|
||||
if neo4j_assistant_aliases:
|
||||
before_count = len(current_aliases)
|
||||
current_aliases = [
|
||||
a for a in current_aliases
|
||||
if a.strip().lower() not in neo4j_assistant_aliases
|
||||
]
|
||||
if len(current_aliases) < before_count:
|
||||
logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名")
|
||||
# 同样过滤 deduped_aliases
|
||||
deduped_aliases = [
|
||||
a for a in deduped_aliases
|
||||
if a.strip().lower() not in neo4j_assistant_aliases
|
||||
]
|
||||
|
||||
logger.info(f"本次对话提取的 aliases: {current_aliases}")
|
||||
logger.info(f"Neo4j 中的完整 aliases: {neo4j_aliases}")
|
||||
if not current_aliases and not deduped_aliases:
|
||||
logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}")
|
||||
return
|
||||
|
||||
# 3. 同步到数据库
|
||||
logger.info(f"本轮对话提取的 aliases: {current_aliases}")
|
||||
if deduped_aliases:
|
||||
logger.info(f"去重后实体的完整 aliases(含历史): {deduped_aliases}")
|
||||
|
||||
# 2. 同步到数据库
|
||||
end_user_uuid = uuid.UUID(end_user_id)
|
||||
with get_db_context() as db:
|
||||
# 更新 end_user 表
|
||||
@@ -1386,7 +1413,38 @@ class ExtractionOrchestrator:
|
||||
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
|
||||
return
|
||||
|
||||
new_name = self._resolve_other_name(end_user.other_name, current_aliases, neo4j_aliases)
|
||||
# 3. 从 PgSQL 读取已有 aliases 并与本轮合并
|
||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
db_aliases = (info.aliases if info and info.aliases else [])
|
||||
# 过滤掉占位名称
|
||||
db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES]
|
||||
|
||||
# 合并:已有 + 去重后完整别名 + 本轮新增,去重保序
|
||||
merged_aliases = list(db_aliases)
|
||||
seen_lower = {a.strip().lower() for a in merged_aliases}
|
||||
# 先合并去重后实体的完整别名(含 Neo4j 历史别名)
|
||||
for alias in deduped_aliases:
|
||||
if alias.strip().lower() not in seen_lower:
|
||||
merged_aliases.append(alias)
|
||||
seen_lower.add(alias.strip().lower())
|
||||
# 再合并本轮新提取的别名
|
||||
for alias in current_aliases:
|
||||
if alias.strip().lower() not in seen_lower:
|
||||
merged_aliases.append(alias)
|
||||
seen_lower.add(alias.strip().lower())
|
||||
|
||||
# 最终过滤:从合并结果中排除 AI 助手别名(清理历史脏数据)
|
||||
if neo4j_assistant_aliases:
|
||||
merged_aliases = [
|
||||
a for a in merged_aliases
|
||||
if a.strip().lower() not in neo4j_assistant_aliases
|
||||
]
|
||||
|
||||
logger.info(f"PgSQL 已有 aliases: {db_aliases}")
|
||||
logger.info(f"合并后 aliases: {merged_aliases}")
|
||||
|
||||
# 更新 end_user 表 other_name
|
||||
new_name = self._resolve_other_name(end_user.other_name, current_aliases, merged_aliases)
|
||||
if new_name is not None:
|
||||
end_user.other_name = new_name
|
||||
logger.info(f"更新 end_user 表 other_name → {new_name}")
|
||||
@@ -1394,26 +1452,27 @@ class ExtractionOrchestrator:
|
||||
logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}")
|
||||
|
||||
# 更新或创建 end_user_info 记录
|
||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
if info:
|
||||
new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases)
|
||||
new_name_info = self._resolve_other_name(info.other_name, current_aliases, merged_aliases)
|
||||
if new_name_info is not None:
|
||||
info.other_name = new_name_info
|
||||
logger.info(f"更新 end_user_info 表 other_name → {new_name_info}")
|
||||
if info.aliases != neo4j_aliases:
|
||||
info.aliases = neo4j_aliases
|
||||
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
|
||||
if info.aliases != merged_aliases:
|
||||
info.aliases = merged_aliases
|
||||
logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}")
|
||||
else:
|
||||
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||
first_alias = current_aliases[0].strip() if current_aliases else (
|
||||
deduped_aliases[0].strip() if deduped_aliases else ""
|
||||
)
|
||||
# 确保 first_alias 不是占位名称
|
||||
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
|
||||
if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES:
|
||||
db.add(EndUserInfo(
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=first_alias,
|
||||
aliases=neo4j_aliases,
|
||||
aliases=merged_aliases,
|
||||
meta_data={}
|
||||
))
|
||||
logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={neo4j_aliases}")
|
||||
logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={merged_aliases}")
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -1423,49 +1482,81 @@ class ExtractionOrchestrator:
|
||||
|
||||
|
||||
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
||||
USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'}
|
||||
# 复用 deduped_and_disamb 模块级常量,避免重复维护
|
||||
USER_PLACEHOLDER_NAMES = _USER_PLACEHOLDER_NAMES
|
||||
|
||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
|
||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode], dialog_data_list=None) -> List[str]:
|
||||
"""从用户发言的原始实体中提取本轮新增别名(绕过去重污染)
|
||||
|
||||
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。
|
||||
第一个别名将被用作 other_name。
|
||||
策略:
|
||||
仅从 dialog_data_list 中找到 speaker="user" 的 statement,
|
||||
从这些 statement 的 triplet_extraction_info 中提取用户实体的 aliases。
|
||||
这样拿到的是 LLM 对用户原话的提取结果,不受去重合并的影响。
|
||||
|
||||
注意:不再使用去重后 entity_nodes 作为兜底,因为二层去重会将 Neo4j 历史别名
|
||||
合并进来,导致历史别名被误认为"本轮提取"。历史别名的同步由
|
||||
_extract_deduped_entity_aliases 负责。
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表
|
||||
entity_nodes: 去重后的实体节点列表(未使用,保留参数兼容性)
|
||||
dialog_data_list: 对话数据列表
|
||||
|
||||
Returns:
|
||||
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称)
|
||||
别名列表(保持原始顺序,已过滤)
|
||||
"""
|
||||
if not dialog_data_list:
|
||||
return []
|
||||
|
||||
all_user_aliases = []
|
||||
seen_lower = set()
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
speaker = getattr(chunk, 'speaker', None)
|
||||
for statement in chunk.statements:
|
||||
stmt_speaker = getattr(statement, 'speaker', None) or speaker
|
||||
if stmt_speaker != "user":
|
||||
continue
|
||||
triplet_info = getattr(statement, 'triplet_extraction_info', None)
|
||||
if not triplet_info:
|
||||
continue
|
||||
for entity in (triplet_info.entities or []):
|
||||
ent_name = getattr(entity, 'name', '').strip()
|
||||
if ent_name.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
for alias in (getattr(entity, 'aliases', []) or []):
|
||||
a = alias.strip()
|
||||
if a and a.lower() not in self.USER_PLACEHOLDER_NAMES and a.lower() not in seen_lower:
|
||||
all_user_aliases.append(a)
|
||||
seen_lower.add(a.lower())
|
||||
if all_user_aliases:
|
||||
logger.debug(f"从用户原始发言提取到别名: {all_user_aliases}")
|
||||
return all_user_aliases
|
||||
|
||||
def _extract_deduped_entity_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||
"""从去重后的用户实体中提取完整别名列表。
|
||||
|
||||
二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 的用户实体中,
|
||||
因此这里提取到的别名包含了历史积累的所有别名,可用于同步到 PgSQL。
|
||||
|
||||
Args:
|
||||
entity_nodes: 去重后的实体节点列表(含二层去重合并结果)
|
||||
|
||||
Returns:
|
||||
别名列表(已过滤占位名称,去重保序)
|
||||
"""
|
||||
for entity in entity_nodes:
|
||||
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
if getattr(entity, 'name', '').strip().lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
aliases = getattr(entity, 'aliases', []) or []
|
||||
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
|
||||
return filtered
|
||||
filtered = [
|
||||
a for a in aliases
|
||||
if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES
|
||||
]
|
||||
if filtered:
|
||||
return filtered
|
||||
return []
|
||||
|
||||
|
||||
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
|
||||
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
|
||||
cypher = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I']
|
||||
RETURN e.aliases AS aliases
|
||||
LIMIT 1
|
||||
"""
|
||||
result = await Neo4jConnector().execute_query(cypher, end_user_id=end_user_id)
|
||||
if not result:
|
||||
logger.debug(f"Neo4j 中未找到用户实体: end_user_id={end_user_id}")
|
||||
return []
|
||||
aliases = result[0].get('aliases') or []
|
||||
if not aliases:
|
||||
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
||||
return []
|
||||
# 过滤掉占位名称,防止历史脏数据传播
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
return filtered
|
||||
async def _fetch_neo4j_assistant_aliases(self, end_user_id: str) -> set:
|
||||
"""从 Neo4j 查询 AI 助手实体的所有别名(用于从用户别名中排除)"""
|
||||
return await fetch_neo4j_assistant_aliases(self.connector, end_user_id)
|
||||
|
||||
def _resolve_other_name(
|
||||
self,
|
||||
@@ -1484,16 +1575,16 @@ class ExtractionOrchestrator:
|
||||
注意:返回值不允许是占位名称("用户"、"我"、"User"、"I")
|
||||
"""
|
||||
# 当前值为空或为占位名称时,需要更新
|
||||
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
if not current or not current.strip() or current.strip().lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
candidate = current_aliases[0].strip() if current_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
if current not in neo4j_aliases:
|
||||
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ class TripletExtractor:
|
||||
predicate_instructions=PREDICATE_DEFINITIONS,
|
||||
language=self._get_language(),
|
||||
ontology_types=self.ontology_types,
|
||||
speaker=getattr(statement, 'speaker', None),
|
||||
)
|
||||
|
||||
# Create messages for LLM
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering
|
||||
|
||||
# Setup Jinja2 environment
|
||||
@@ -205,6 +205,7 @@ async def render_triplet_extraction_prompt(
|
||||
predicate_instructions: dict = None,
|
||||
language: str = "zh",
|
||||
ontology_types: "OntologyTypeList | None" = None,
|
||||
speaker: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
||||
@@ -216,6 +217,7 @@ async def render_triplet_extraction_prompt(
|
||||
predicate_instructions: Optional predicate instructions
|
||||
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
|
||||
ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification
|
||||
speaker: Speaker role ("user" or "assistant") for the current statement
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -223,7 +225,7 @@ async def render_triplet_extraction_prompt(
|
||||
template = prompt_env.get_template("extract_triplet.jinja2")
|
||||
|
||||
# 准备本体类型数据
|
||||
ontology_type_section = ""
|
||||
ontology_type_section = None
|
||||
ontology_type_names = []
|
||||
type_hierarchy_hints = []
|
||||
if ontology_types and ontology_types.types:
|
||||
@@ -240,6 +242,7 @@ async def render_triplet_extraction_prompt(
|
||||
ontology_types=ontology_type_section,
|
||||
ontology_type_names=ontology_type_names,
|
||||
type_hierarchy_hints=type_hierarchy_hints,
|
||||
speaker=speaker,
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('triplet extraction', rendered_prompt)
|
||||
|
||||
@@ -23,6 +23,16 @@ Extract entities and knowledge triplets from the given statement.
|
||||
===Inputs===
|
||||
**Chunk Content:** "{{ chunk_content }}"
|
||||
**Statement:** "{{ statement }}"
|
||||
{% if speaker %}
|
||||
**Speaker:** {{ speaker }}
|
||||
{% if speaker == "assistant" %}
|
||||
{% if language == "zh" %}
|
||||
⚠️ 当前陈述句来自 **AI助手的回复**。AI助手在回复中用来称呼用户的名字是**用户的别名**,不是 AI 助手的别名。但只能提取原文中逐字出现的名字,严禁推测或创造原文中不存在的别名变体。
|
||||
{% else %}
|
||||
⚠️ This statement is from the **AI assistant's reply**. Names the AI uses to address the user are **user's aliases**, NOT the AI assistant's aliases. But only extract names that appear VERBATIM in the text — never infer or fabricate alias variants.
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% if ontology_types %}
|
||||
===Ontology Type Guidance===
|
||||
@@ -87,7 +97,17 @@ Extract entities and knowledge triplets from the given statement.
|
||||
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name)
|
||||
* "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name)
|
||||
- 空值:如果没有别名,使用 `[]`
|
||||
- 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字
|
||||
- **🚨🚨🚨 严禁幻觉:只提取对话原文中逐字出现的别名,绝对不能推测、衍生或创造任何未在原文中出现的名字。例如,看到"陈思远"不能自行添加"思远大人""远哥""小远"等变体。如果原文没有这些字,就不能出现在 aliases 中。**
|
||||
- **🚨 归属区分:必须严格区分名称的归属对象。默认情况下,用户提到的名字归属用户实体。只有出现明确的第二人称命名表达(如"叫你""给你取名")时,才将名字归属 AI/助手实体。**
|
||||
- **🚨 说话人视角:当 speaker 为 assistant 时,AI 助手用来称呼用户的名字是用户的别名,必须归入用户实体的 aliases,绝对不能归入 AI 助手实体。但同样只能提取原文中逐字出现的称呼,不能推测。**
|
||||
* "我叫陈思远,我给AI取名为远仔" → 用户 aliases=["陈思远"],AI助手 aliases=["远仔"]
|
||||
* "我叫vv" → 用户 aliases=["vv"](没有给AI取名的表达,名字归用户)
|
||||
* [speaker=assistant] "好的,VV" → 用户 aliases=["VV"](AI 在称呼用户,原文中出现了"VV")
|
||||
* [speaker=assistant] "我叫陈仔" → AI助手 aliases=["陈仔"](AI 在自我介绍,这是 AI 的别名)
|
||||
* ❌ 错误:将"远仔"放入用户的 aliases("远仔"是给AI取的名字,不是用户的名字)
|
||||
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases
|
||||
* ❌ 错误:AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
|
||||
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
|
||||
{% else %}
|
||||
- Include: nicknames, full names, abbreviations, alternative names
|
||||
- Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST**
|
||||
@@ -96,7 +116,17 @@ Extract entities and knowledge triplets from the given statement.
|
||||
* "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name)
|
||||
* "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
|
||||
- Empty: If no aliases, use `[]`
|
||||
- Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names
|
||||
- **🚨🚨🚨 NO HALLUCINATION: Only extract aliases that appear VERBATIM in the original text. NEVER infer, derive, or fabricate names not present in the text. For example, seeing "John Smith" does NOT allow adding "Johnny", "Smithy", "Mr. Smith" unless those exact strings appear in the conversation.**
|
||||
- **🚨 Ownership distinction: By default, all names mentioned by the user belong to the user entity. Only assign a name to the AI/assistant entity when an explicit second-person naming expression (e.g., "I'll call you", "your name is") is present.**
|
||||
- **🚨 Speaker perspective: When speaker is "assistant", names the AI uses to address the user are the USER's aliases and MUST go into the user entity's aliases, NEVER into the AI assistant entity's aliases. But only extract names that appear verbatim in the text, never infer.**
|
||||
* "I'm Alex, I'll call you Buddy" → User aliases=["Alex"], AI assistant aliases=["Buddy"]
|
||||
* "I'm vv" → User aliases=["vv"] (no AI-naming expression, name belongs to user)
|
||||
* [speaker=assistant] "Sure thing, VV" → User aliases=["VV"] (AI addressing the user, "VV" appears in text)
|
||||
* [speaker=assistant] "I'm Jarvis" → AI assistant aliases=["Jarvis"] (AI self-introduction, this is AI's alias)
|
||||
* ❌ Wrong: putting "Buddy" in user's aliases ("Buddy" is a name for the AI, not the user)
|
||||
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases
|
||||
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
|
||||
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
|
||||
{% endif %}
|
||||
|
||||
|
||||
@@ -122,7 +152,60 @@ Extract entities and knowledge triplets from the given statement.
|
||||
|
||||
|
||||
|
||||
4. **ALIASES ORDER:**
|
||||
4. **AI/ASSISTANT ENTITY SPECIAL HANDLING:**
|
||||
{% if language == "zh" %}
|
||||
- **🚨 默认规则:如果对话中没有出现明确指向 AI/助手的命名表达,则所有名字都归属于用户实体。不要猜测或推断某个名字是给 AI 取的。**
|
||||
- 只有当用户**明确**对 AI/助手进行命名时,才创建 AI/助手实体并将对应名字放入其 aliases
|
||||
- AI/助手实体的 name 字段:使用 "AI助手"
|
||||
- 用户给 AI 取的名字:放入 AI/助手实体的 aliases
|
||||
- **🚨 禁止将用户给 AI 取的名字放入用户实体的 aliases 中**
|
||||
- **必须出现以下明确的命名表达才能判定为给 AI 取名:**「给你取名」「叫你」「称呼你为」「给AI取名」「你的名字是」「以后叫你」「你就叫」「你不叫X了」「你现在叫」等**第二人称(你)或明确指向 AI 的命名句式**
|
||||
- **🚨 "你不叫X了"/"你不叫X,你叫Y" 句式:X 和 Y 都是 AI 的名字(旧名和新名),绝对不是用户的名字。因为句子主语是"你"(AI)。**
|
||||
- **以下情况名字归属用户,不是给 AI 取名:**「我叫」「我的名字是」「叫我」「我是」「大家叫我」「我的英文名是」「我的昵称是」等**第一人称(我)的自我介绍句式**
|
||||
- **🚨 speaker=assistant 时的特殊规则:**
|
||||
* AI 用来称呼用户的名字 → 归入**用户**实体的 aliases(但必须是原文中逐字出现的称呼,不能推测)
|
||||
* AI 自称的名字(如"我叫陈仔""我是你的助手")→ 归入**AI助手**实体的 aliases
|
||||
* 判断依据:AI 说"你叫X"或用 X 称呼用户 → X 是用户别名;AI 说"我叫X"或"我是X" → X 是 AI 别名
|
||||
- 示例:
|
||||
* "我叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
|
||||
* "我的英文名叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
|
||||
* "我叫陈思远,我给AI取名为远仔" → 用户实体: name="用户", aliases=["陈思远"];AI实体: name="AI助手", aliases=["远仔"]
|
||||
* "叫你小助,我自己叫老王" → 用户实体: name="用户", aliases=["老王"];AI实体: name="AI助手", aliases=["小助"]
|
||||
* "你不叫远仔了,你现在叫陈仔" → AI实体: name="AI助手", aliases=["陈仔"]("远仔"是AI旧名,"陈仔"是AI新名,都归AI。不要把"远仔"或"陈仔"放入用户的aliases)
|
||||
* [speaker=assistant] "好的VV,今天想干点啥?" → 用户实体: name="用户", aliases=["VV"](AI 在称呼用户,原文中出现了"VV")
|
||||
* [speaker=assistant] "你叫陈思远,我叫陈仔" → 用户实体: name="用户", aliases=["陈思远"];AI实体: name="AI助手", aliases=["陈仔"]
|
||||
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases(没有任何给 AI 取名的表达)
|
||||
* ❌ 错误:AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
|
||||
* ❌ 错误:aliases=["陈思远", "远仔"]("远仔"是给AI取的名字,不是用户的名字)
|
||||
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
|
||||
{% else %}
|
||||
- **🚨 Default rule: If there is NO explicit AI/assistant naming expression in the conversation, ALL names belong to the user entity. Do NOT guess or infer that a name is for the AI.**
|
||||
- Only create an AI/assistant entity when the user **explicitly** names the AI/assistant
|
||||
- AI/assistant entity name field: use "AI Assistant"
|
||||
- Names the user gives to the AI: put in the AI/assistant entity's aliases
|
||||
- **🚨 NEVER put names given to the AI into the user entity's aliases**
|
||||
- **An AI-naming expression MUST be present to assign a name to the AI:** "I'll call you", "your name is", "I name you", "let me call you", "you'll be called", "you're not called X anymore", "your new name is", etc. — **second-person ("you") or explicit AI-directed naming patterns**
|
||||
- **🚨 "You're not called X anymore" / "You're not X, you're Y" pattern: BOTH X and Y are AI's names (old and new). They are NOT user's names. The subject is "you" (the AI).**
|
||||
- **These patterns mean the name belongs to the USER, NOT the AI:** "I'm", "my name is", "call me", "I am", "people call me", "my English name is", "my nickname is", etc. — **first-person ("I"/"me") self-introduction patterns**
|
||||
- **🚨 Special rules when speaker=assistant:**
|
||||
* Names the AI uses to address the user → belong to the **user** entity's aliases (but only extract names that appear verbatim in the text, never infer)
|
||||
* Names the AI uses for itself (e.g., "I'm Jarvis", "I am your assistant") → belong to the **AI assistant** entity's aliases
|
||||
* Rule: AI says "you are X" or calls user X → X is user's alias; AI says "I'm X" or "I am X" → X is AI's alias
|
||||
- Examples:
|
||||
* "I'm vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
|
||||
* "My English name is vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
|
||||
* "I'm Alex, I'll call you Buddy" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Buddy"]
|
||||
* "Call yourself Jarvis, my name is Tony" → User entity: name="User", aliases=["Tony"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
|
||||
* "You're not called Jarvis anymore, your new name is Friday" → AI entity: name="AI Assistant", aliases=["Friday"] (both "Jarvis" and "Friday" are AI names, NOT user names)
|
||||
* [speaker=assistant] "Sure thing, VV" → User entity: name="User", aliases=["VV"] (AI addressing the user, "VV" appears in text)
|
||||
* [speaker=assistant] "You're Alex, and I'm Jarvis" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
|
||||
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases (no AI-naming expression exists)
|
||||
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
|
||||
* ❌ Wrong: aliases=["Alex", "Buddy"] ("Buddy" is a name for the AI, not the user)
|
||||
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
|
||||
{% endif %}
|
||||
|
||||
5. **ALIASES ORDER:**
|
||||
{% if language == "zh" %}
|
||||
- 顺序优先级:按出现顺序,先出现的在前
|
||||
{% else %}
|
||||
@@ -202,8 +285,19 @@ Output:
|
||||
{"entity_idx": 0, "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", "example": "", "aliases": ["Camera Tripod"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 4 (User vs AI alias distinction - English output):** "I'm Alex, and I'll call you Buddy"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "User", "subject_id": 0, "predicate": "NAMED", "object_name": "AI Assistant", "object_id": 1, "value": "Buddy"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "User", "type": "Person", "description": "The user", "example": "", "aliases": ["Alex"], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI Assistant", "type": "Person", "description": "The user's AI assistant", "example": "", "aliases": ["Buddy"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
{% else %}
|
||||
**Example 1 (English input → Chinese output):** "I plan to travel to Paris next week and visit the Louvre."
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
@@ -258,6 +352,39 @@ Output:
|
||||
]
|
||||
}
|
||||
|
||||
**Example 6 (用户与AI别名区分 - Chinese):** "我称呼自己为陈思远,我给AI取名为远仔"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "远仔"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远"], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["远仔"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 7 (纯用户自我介绍,无AI命名 - Chinese):** "我叫vv"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["vv"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 8 (给AI改名 - Chinese):** "你不叫远仔了,你现在叫陈仔"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "陈仔"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["陈仔"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
{% endif %}
|
||||
===End of Examples===
|
||||
|
||||
@@ -25,8 +25,34 @@ class RedBearEmbeddings(Embeddings):
|
||||
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
|
||||
"""根据配置创建 LangChain 模型"""
|
||||
embedding_class = get_provider_embedding_class(config.provider)
|
||||
model_params = RedBearModelFactory.get_model_params(config)
|
||||
return embedding_class(**model_params)
|
||||
provider = config.provider.lower()
|
||||
# Embedding models only need connection params, never LLM-specific ones
|
||||
# (e.g. enable_thinking, model_kwargs) — build params directly.
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
import httpx
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
"timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
|
||||
"max_retries": config.max_retries,
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
"max_retries": config.max_retries,
|
||||
}
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
}
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
params = RedBearModelFactory.get_model_params(config)
|
||||
else:
|
||||
params = RedBearModelFactory.get_model_params(config)
|
||||
return embedding_class(**params)
|
||||
|
||||
def _create_volcano_client(self, config: RedBearModelConfig):
|
||||
"""创建火山引擎客户端"""
|
||||
|
||||
@@ -6,14 +6,28 @@ ChatOpenAI 在解析流式 SSE 时只取 delta.content,会丢弃 delta.reasoni
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class VolcanoChatOpenAI(ChatOpenAI):
|
||||
"""火山引擎 Chat 模型,支持深度思考内容(reasoning_content)的流式透传。"""
|
||||
"""火山引擎 Chat 模型,支持深度思考内容(reasoning_content)的流式和非流式透传。"""
|
||||
|
||||
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
|
||||
result = super()._create_chat_result(response, generation_info)
|
||||
# 将非流式响应中的 reasoning_content 补入 additional_kwargs
|
||||
choices = response.choices if hasattr(response, "choices") else response.get("choices", [])
|
||||
if choices:
|
||||
message = choices[0].message if hasattr(choices[0], "message") else choices[0].get("message", {})
|
||||
reasoning = (
|
||||
getattr(message, "reasoning_content", None)
|
||||
or (message.get("reasoning_content") if isinstance(message, dict) else None)
|
||||
)
|
||||
if reasoning and result.generations:
|
||||
result.generations[0].message.additional_kwargs["reasoning_content"] = reasoning
|
||||
return result
|
||||
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
self,
|
||||
|
||||
@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
enum=["format", "convert_timezone", "timestamp_to_datetime", "now", "datetime_to_timestamp"]
),
ToolParameter(
name="input_value",
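# --- Editor's illustrative sketch (not part of the committed diff) ---
# The diff only adds the "datetime_to_timestamp" enum value; the handler itself is
# not shown. A plausible plain-Python equivalent of that operation might look like
# this (the format string and UTC assumption are guesses, not the tool's actual API):
from datetime import datetime, timezone

def datetime_to_timestamp(input_value: str, fmt: str = "%Y-%m-%d %H:%M:%S") -> int:
    """Parse a datetime string and return a Unix timestamp in seconds (UTC assumed)."""
    dt = datetime.strptime(input_value, fmt).replace(tzinfo=timezone.utc)
    return int(dt.timestamp())

# Example: datetime_to_timestamp("2024-01-01 00:00:00") -> 1704067200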
@@ -32,13 +32,16 @@ from app.core.workflow.nodes.configs import (
|
||||
NoteNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
VariableAggregatorNodeConfig
|
||||
VariableAggregatorNodeConfig,
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.cycle_graph.config import (
|
||||
ConditionDetail as LoopConditionDetail,
|
||||
ConditionsConfig,
|
||||
CycleVariable
|
||||
)
|
||||
from app.core.workflow.nodes.list_operator.config import FilterCondition
|
||||
from app.core.workflow.nodes.enums import (
|
||||
ValueInputType,
|
||||
ComparisonOperator,
|
||||
@@ -90,6 +93,8 @@ class DifyConverter(BaseConverter):
|
||||
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
|
||||
NodeType.TOOL: self.convert_tool_node_config,
|
||||
NodeType.NOTES: self.convert_notes_config,
|
||||
NodeType.LIST_OPERATOR: self.convert_list_operator_node_config,
|
||||
NodeType.DOCUMENT_EXTRACTOR: self.convert_document_extractor_node_config,
|
||||
NodeType.CYCLE_START: lambda x: {},
|
||||
NodeType.BREAK: lambda x: {},
|
||||
}
|
||||
@@ -213,7 +218,9 @@ class DifyConverter(BaseConverter):
|
||||
"end with": ComparisonOperator.END_WITH,
|
||||
"not contains": ComparisonOperator.NOT_CONTAINS,
|
||||
"exists": ComparisonOperator.NOT_EMPTY,
|
||||
"not exists": ComparisonOperator.EMPTY
|
||||
"not exists": ComparisonOperator.EMPTY,
|
||||
"in": ComparisonOperator.IN,
|
||||
"not in": ComparisonOperator.NOT_IN,
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@@ -771,3 +778,46 @@ class DifyConverter(BaseConverter):
|
||||
show_author=node_data.get("showAuthor", True)
|
||||
).model_dump()
|
||||
return result
|
||||
|
||||
def convert_list_operator_node_config(self, node: dict) -> dict:
|
||||
"""Dify list-operator — convert variable path array to {{ }} selector format."""
|
||||
node_data = node["data"]
|
||||
variable_path = node_data.get("variable", [])
|
||||
input_list = self._process_list_variable_literal(variable_path) or ""
|
||||
filter_by = node_data.get("filter_by", {"enabled": False, "conditions": []})
|
||||
# Convert each condition's comparison_operator from Dify format to native
|
||||
if filter_by.get("conditions"):
|
||||
converted_conditions = []
|
||||
for cond in filter_by["conditions"]:
|
||||
converted_conditions.append({
|
||||
**cond,
|
||||
"comparison_operator": self.convert_compare_operator(
|
||||
cond.get("comparison_operator", "")
|
||||
)
|
||||
})
|
||||
filter_by = {**filter_by, "conditions": converted_conditions}
|
||||
result = {
|
||||
"input_list": input_list,
|
||||
"filter_by": filter_by,
|
||||
"order_by": node_data.get("order_by", {"enabled": False, "key": "", "value": "asc"}),
|
||||
"limit": node_data.get("limit", {"enabled": False, "size": -1}),
|
||||
"extract_by": node_data.get("extract_by", {"enabled": False, "serial": "1"}),
|
||||
}
|
||||
self.config_validate(node["id"], node["data"]["title"], ListOperatorNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_document_extractor_node_config(self, node: dict) -> dict:
|
||||
"""Convert Dify document-extractor node to MemoryBear DocExtractorNodeConfig.
|
||||
|
||||
Dify document-extractor data fields:
|
||||
variable_selector: list[str] - file variable path
|
||||
"""
|
||||
node_data = node["data"]
|
||||
file_selector = self._process_list_variable_literal(
|
||||
node_data.get("variable_selector", [])
|
||||
) or ""
|
||||
result = DocExtractorNodeConfig.model_construct(
|
||||
file_selector=file_selector,
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], DocExtractorNodeConfig, result)
|
||||
return result
|
||||
|
||||
@@ -45,6 +45,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"variable-aggregator": NodeType.VAR_AGGREGATOR,
|
||||
"tool": NodeType.TOOL,
|
||||
"list-operator": NodeType.LIST_OPERATOR,
|
||||
"document-extractor": NodeType.DOCUMENT_EXTRACTOR,
|
||||
"": NodeType.NOTES
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ from app.core.workflow.nodes.configs import (
|
||||
MemoryReadNodeConfig,
|
||||
MemoryWriteNodeConfig,
|
||||
NoteNodeConfig,
|
||||
ListOperatorNodeConfig,
|
||||
DocExtractorNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
@@ -51,6 +53,8 @@ class MemoryBearConverter(BaseConverter):
|
||||
NodeType.MEMORY_READ: MemoryReadNodeConfig,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
|
||||
NodeType.NOTES: NoteNodeConfig,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNodeConfig,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNodeConfig,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -59,6 +59,9 @@ class WorkflowResultBuilder:
conversation_vars = variable_pool.get_all_conversation_vars()
sys_vars = variable_pool.get_all_system_vars()
# 汇总所有 knowledge 节点的 citations
citations = self.aggregate_citations(node_outputs)
return {
"status": "completed" if success else "failed",
"output": final_output,
@@ -71,9 +74,25 @@
"conversation_id": execution_context.conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"citations": citations,
"error": result.get("error"),
}
@staticmethod
def aggregate_citations(node_outputs: dict) -> list:
"""从所有 knowledge 节点的输出中汇总 citations,去重"""
seen = set()
citations = []
for node_output in node_outputs.values():
if not isinstance(node_output, dict):
continue
for c in node_output.get("citations", []):
key = c.get("document_id")
if key and key not in seen:
seen.add(key)
citations.append(c)
return citations
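# --- Editor's illustrative sketch (not part of the committed diff) ---
# Expected behaviour of aggregate_citations: citations from multiple knowledge
# nodes are flattened and de-duplicated by document_id. The node ids and
# document ids below are made-up sample values.
_sample_node_outputs = {
    "knowledge_1": {"citations": [{"document_id": "doc-1", "file_name": "a.pdf"}]},
    "knowledge_2": {"citations": [{"document_id": "doc-1", "file_name": "a.pdf"},
                                  {"document_id": "doc-2", "file_name": "b.pdf"}]},
    "llm_1": {"output": "no citations here"},
}
# WorkflowResultBuilder.aggregate_citations(_sample_node_outputs)
# -> [{"document_id": "doc-1", ...}, {"document_id": "doc-2", ...}]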
@staticmethod
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
"""
@@ -318,7 +318,7 @@ class VariablePool:
|
||||
namespace: str,
|
||||
key: str,
|
||||
value: Any,
|
||||
var_type: VariableType,
|
||||
var_type: VariableType | None,
|
||||
mut: bool
|
||||
):
|
||||
if self.has(f"{namespace}.{key}"):
|
||||
@@ -493,6 +493,23 @@ class VariablePoolInitializer:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
# Convert FileInput-format dicts to full FileObject dicts
|
||||
if var_type == VariableType.FILE:
|
||||
if not var_value:
|
||||
continue
|
||||
var_value = await self._resolve_file_default(var_value)
|
||||
if not var_value:
|
||||
continue
|
||||
elif var_type == VariableType.ARRAY_FILE:
|
||||
if not var_value:
|
||||
var_value = []
|
||||
else:
|
||||
resolved = []
|
||||
for item in var_value:
|
||||
f = await self._resolve_file_default(item)
|
||||
if f:
|
||||
resolved.append(f)
|
||||
var_value = resolved
|
||||
await variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
@@ -501,6 +518,17 @@ class VariablePoolInitializer:
|
||||
mut=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_file_default(file_def: dict) -> dict | None:
|
||||
"""Accept only already-resolved FileObject dicts (is_file=True).
|
||||
FileInput-format dicts are converted at save time by WorkflowService._resolve_variables_file_defaults.
|
||||
"""
|
||||
if not isinstance(file_def, dict):
|
||||
return None
|
||||
if file_def.get("is_file"):
|
||||
return file_def
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _init_system_vars(
|
||||
variable_pool: VariablePool,
|
||||
|
||||
@@ -395,7 +395,8 @@ class BaseNode(ABC):
|
||||
"output": output,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": None
|
||||
"error": None,
|
||||
**self._extract_extra_fields(business_result),
|
||||
}
|
||||
final_output = {
|
||||
"node_outputs": {self.node_id: node_output},
|
||||
@@ -498,6 +499,13 @@ class BaseNode(ABC):
|
||||
# Default implementation returns the business result directly
|
||||
return business_result
|
||||
|
||||
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||
"""Extracts extra fields to merge into node_output (e.g. citations).
|
||||
|
||||
Subclasses may override to inject additional metadata.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
"""Extracts token usage information from the business result.
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import BaseNode
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
logger = logging.getLogger(__name__)
@@ -70,7 +70,8 @@ class CodeNode(BaseNode):
for output in self.typed_config.output_variables:
value = exec_result.get(output.name)
if value is None:
raise RuntimeError(f"Return value {output.name} does not exist")
result[output.name] = DEFAULT_VALUE(output.type)
continue
match output.type:
case VariableType.STRING:
if not isinstance(value, str):
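# --- Editor's illustrative sketch (not part of the committed diff) ---
# Behaviour change in the hunk above: a missing return value no longer raises
# RuntimeError; the declared output falls back to DEFAULT_VALUE(output.type).
# The helper below is a simplified stand-in for that loop, with made-up names.
def _fill_missing_outputs(exec_result: dict, output_variables: list) -> dict:
    result = {}
    for output in output_variables:
        value = exec_result.get(output.name)
        # A missing value gets a type-appropriate default instead of aborting the node.
        result[output.name] = DEFAULT_VALUE(output.type) if value is None else value
    return result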
@@ -24,6 +24,8 @@ from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
||||
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig
|
||||
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -49,5 +51,7 @@ __all__ = [
|
||||
"MemoryReadNodeConfig",
|
||||
"MemoryWriteNodeConfig",
|
||||
"CodeNodeConfig",
|
||||
"NoteNodeConfig"
|
||||
"NoteNodeConfig",
|
||||
"ListOperatorNodeConfig",
|
||||
"DocExtractorNodeConfig",
|
||||
]
|
||||
|
||||
@@ -14,12 +14,22 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def _file_object_to_file_input(f: FileObject) -> FileInput:
|
||||
"""Convert workflow FileObject to multimodal FileInput."""
|
||||
file_type = f.origin_file_type or ""
|
||||
# Prefer mime_type for more accurate type detection
|
||||
if not file_type and f.mime_type:
|
||||
file_type = f.mime_type
|
||||
resolved_type = FileType.trans(f.type) if isinstance(f.type, str) else f.type
|
||||
if resolved_type != FileType.DOCUMENT:
|
||||
raise ValueError(
|
||||
f"Document extractor only supports document files, got type '{f.type}' "
|
||||
f"(name={f.name or f.file_id or f.url})"
|
||||
)
|
||||
return FileInput(
|
||||
type=FileType.DOCUMENT,
|
||||
type=resolved_type,
|
||||
transfer_method=TransferMethod(f.transfer_method),
|
||||
url=f.url or None,
|
||||
upload_file_id=f.file_id or None,
|
||||
file_type=f.origin_file_type or "",
|
||||
file_type=file_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -81,6 +91,7 @@ class DocExtractorNode(BaseNode):
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
svc = MultimodalService(db)
|
||||
for f in files:
|
||||
label = f.name or f.url or f.file_id
|
||||
try:
|
||||
file_input = _file_object_to_file_input(f)
|
||||
# Ensure URL is populated for local files
|
||||
@@ -93,7 +104,7 @@ class DocExtractorNode(BaseNode):
|
||||
chunks.append(text)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
|
||||
f"Node {self.node_id}: failed to extract file '{label}': {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
chunks.append("")
|
||||
|
||||
@@ -24,6 +24,7 @@ class NodeType(StrEnum):
|
||||
MEMORY_READ = "memory-read"
|
||||
MEMORY_WRITE = "memory-write"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
NOTES = "notes"
|
||||
@@ -45,6 +46,8 @@ class ComparisonOperator(StrEnum):
|
||||
LE = "le"
|
||||
GT = "gt"
|
||||
GE = "ge"
|
||||
IN = "in"
|
||||
NOT_IN = "not_in"
|
||||
|
||||
|
||||
class LogicOperator(StrEnum):
|
||||
|
||||
@@ -34,6 +34,20 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"output": VariableType.ARRAY_STRING
|
||||
}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
"""下游节点只拿 chunks 列表"""
|
||||
if isinstance(business_result, dict) and "chunks" in business_result:
|
||||
return business_result["chunks"]
|
||||
return business_result
|
||||
|
||||
def _extract_citations(self, business_result: Any) -> list:
|
||||
if isinstance(business_result, dict):
|
||||
return business_result.get("citations", [])
|
||||
return []
|
||||
|
||||
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||
return {"citations": self._extract_citations(business_result)}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
return {
|
||||
"query": self._render_template(self.typed_config.query, variable_pool),
|
||||
@@ -314,4 +328,20 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
return [chunk.page_content for chunk in final_rs]
|
||||
citations = []
|
||||
seen_doc_ids = set()
|
||||
for chunk in final_rs:
|
||||
meta = chunk.metadata or {}
|
||||
doc_id = meta.get("document_id") or meta.get("doc_id")
|
||||
if doc_id and doc_id not in seen_doc_ids:
|
||||
seen_doc_ids.add(doc_id)
|
||||
citations.append({
|
||||
"document_id": str(doc_id),
|
||||
"file_name": meta.get("file_name", ""),
|
||||
"knowledge_id": str(meta.get("knowledge_id", kb_config.kb_id)),
|
||||
"score": meta.get("score", 0.0),
|
||||
})
|
||||
return {
|
||||
"chunks": [chunk.page_content for chunk in final_rs],
|
||||
"citations": citations,
|
||||
}
|
||||
|
||||
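The dedup-by-document_id logic added above can be checked in isolation with a small sketch (FakeChunk is a hypothetical stand-in for the retriever's result objects):

# Hedged, self-contained sketch of the citation-building loop.
class FakeChunk:
    def __init__(self, content, metadata):
        self.page_content = content
        self.metadata = metadata


final_rs = [
    FakeChunk("text a", {"document_id": "d1", "file_name": "a.pdf", "score": 0.9}),
    FakeChunk("text b", {"document_id": "d1", "file_name": "a.pdf", "score": 0.8}),
    FakeChunk("text c", {"document_id": "d2", "file_name": "b.pdf", "score": 0.7}),
]

citations, seen_doc_ids = [], set()
for chunk in final_rs:
    meta = chunk.metadata or {}
    doc_id = meta.get("document_id") or meta.get("doc_id")
    if doc_id and doc_id not in seen_doc_ids:
        seen_doc_ids.add(doc_id)
        citations.append({
            "document_id": str(doc_id),
            "file_name": meta.get("file_name", ""),
            "score": meta.get("score", 0.0),
        })

print({"chunks": [c.page_content for c in final_rs], "citations": citations})
# Three chunks are returned, but only two citations (d1, d2) survive deduplication.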
api/app/core/workflow/nodes/list_operator/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
|
||||
from .node import ListOperatorNode
|
||||
|
||||
__all__ = ["ListOperatorNode"]
|
||||
api/app/core/workflow/nodes/list_operator/config.py (new file, 44 lines)
@@ -0,0 +1,44 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
|
||||
|
||||
class FilterCondition(BaseModel):
|
||||
key: str = ""
|
||||
comparison_operator: ComparisonOperator = ComparisonOperator.CONTAINS
|
||||
value: str | list[str] | bool = ""
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
|
||||
enabled: bool = False
|
||||
conditions: list[FilterCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrderByConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
key: str = ""
|
||||
value: str = "asc" # "asc" | "desc"
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
enabled: bool = False
|
||||
size: int = -1
|
||||
|
||||
|
||||
class ExtractConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
serial: str = "1" # 1-based index string, e.g. "1" = first
|
||||
|
||||
|
||||
class ListOperatorNodeConfig(BaseNodeConfig):
|
||||
"""
|
||||
List Operator node config.
|
||||
Operation order: filter -> extract -> order -> limit
|
||||
"""
|
||||
input_list: str = Field(..., description="Variable selector, e.g. {{ sys.files }} or {{ conv.uploaded_files }}")
|
||||
filter_by: FilterBy = Field(default_factory=FilterBy)
|
||||
order_by: OrderByConfig = Field(default_factory=OrderByConfig)
|
||||
limit: Limit = Field(default_factory=Limit)
|
||||
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
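For reference, a node config dict in the shape this model expects might look like the following. Values are illustrative only; the string form of the CONTAINS operator depends on the enum definition, which is not shown in this diff.

# Hedged example payload, not a real workflow export.
list_operator_config = {
    "input_list": "{{ sys.files }}",
    "filter_by": {
        "enabled": True,
        "conditions": [
            {"key": "extension", "comparison_operator": "contains", "value": "pdf"}
        ],
    },
    "order_by": {"enabled": True, "key": "size", "value": "desc"},
    "limit": {"enabled": True, "size": 5},
    "extract_by": {"enabled": False, "serial": "1"},
}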
api/app/core/workflow/nodes/list_operator/node.py (new file, 143 lines)
@@ -0,0 +1,143 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig, FilterCondition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# File object fields that hold string values
|
||||
_FILE_STRING_KEYS = {"name", "extension", "mime_type", "url", "transfer_method", "origin_file_type", "file_id"}
|
||||
_FILE_NUMBER_KEYS = {"size"}
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode):
|
||||
def __init__(self, node_config: dict, workflow_config: dict, down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: ListOperatorNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"result": VariableType.ANY,
|
||||
"first_record": VariableType.ANY,
|
||||
"last_record": VariableType.ANY,
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = ListOperatorNodeConfig(**self.config)
|
||||
cfg = self.typed_config
|
||||
|
||||
# Resolve input variable from path selector
|
||||
items: list = self.get_variable(cfg.input_list, variable_pool)
|
||||
if not isinstance(items, list):
|
||||
raise TypeError(f"Variable '{cfg.input_list}' must be an array, got {type(items)}")
|
||||
|
||||
result = list(items)
|
||||
|
||||
# 1. Filter
|
||||
if cfg.filter_by.enabled and cfg.filter_by.conditions:
|
||||
for condition in cfg.filter_by.conditions:
|
||||
result = [item for item in result if self._match_condition(item, condition, variable_pool)]
|
||||
|
||||
# 2. Extract (take single item by 1-based serial index)
|
||||
if cfg.extract_by.enabled:
|
||||
serial_str = self._resolve_value(cfg.extract_by.serial, variable_pool)
|
||||
idx = int(serial_str) - 1
|
||||
if idx < 0 or idx >= len(result):
|
||||
raise ValueError(f"extract_by.serial={cfg.extract_by.serial} out of range (list length={len(result)})")
|
||||
result = [result[idx]]
|
||||
|
||||
# 3. Order
|
||||
if cfg.order_by.enabled and cfg.order_by.key:
|
||||
reverse = cfg.order_by.value == "desc"
|
||||
key_fn = self._make_sort_key(cfg.order_by.key)
|
||||
result = sorted(result, key=key_fn, reverse=reverse)
|
||||
|
||||
# 4. Limit (take first N)
|
||||
if cfg.limit.enabled and cfg.limit.size > 0:
|
||||
result = result[:cfg.limit.size]
|
||||
|
||||
return {
|
||||
"result": result,
|
||||
"first_record": result[0] if result else None,
|
||||
"last_record": result[-1] if result else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _resolve_value(value: str, variable_pool: VariablePool) -> Any:
|
||||
"""If value is a {{ namespace.key }} variable selector, resolve it from the pool.
|
||||
Otherwise return the raw string."""
|
||||
import re
|
||||
m = re.fullmatch(r"\{\{\s*(\w+\.\w+)\s*}}", value.strip())
|
||||
if m:
|
||||
resolved = variable_pool.get_value(value, default=value, strict=False)
|
||||
return resolved
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _make_sort_key(key: str):
|
||||
def key_fn(item):
|
||||
if isinstance(item, dict):
|
||||
return item.get(key) or ""
|
||||
return item
|
||||
return key_fn
|
||||
|
||||
def _match_condition(self, item: Any, cond: FilterCondition, variable_pool: VariablePool) -> bool:
|
||||
op = cond.comparison_operator
|
||||
value = cond.value
|
||||
|
||||
# Resolve value if it's a variable reference {{ namespace.key }}
|
||||
if isinstance(value, str):
|
||||
value = self._resolve_value(value, variable_pool)
|
||||
|
||||
# Resolve left value
|
||||
if isinstance(item, dict):
|
||||
left = item.get(cond.key)
|
||||
else:
|
||||
left = item # primitive array: compare element directly
|
||||
|
||||
# Numeric operators
|
||||
if op == ComparisonOperator.EQ:
|
||||
return self._safe_num(left) == self._safe_num(value)
|
||||
if op == ComparisonOperator.NE:
|
||||
return self._safe_num(left) != self._safe_num(value)
|
||||
if op == ComparisonOperator.LT:
|
||||
return self._safe_num(left) < self._safe_num(value)
|
||||
if op == ComparisonOperator.LE:
|
||||
return self._safe_num(left) <= self._safe_num(value)
|
||||
if op == ComparisonOperator.GT:
|
||||
return self._safe_num(left) > self._safe_num(value)
|
||||
if op == ComparisonOperator.GE:
|
||||
return self._safe_num(left) >= self._safe_num(value)
|
||||
|
||||
# String / sequence operators
|
||||
left_str = str(left) if left is not None else ""
|
||||
if op == ComparisonOperator.CONTAINS:
|
||||
return str(value) in left_str
|
||||
if op == ComparisonOperator.NOT_CONTAINS:
|
||||
return str(value) not in left_str
|
||||
if op == ComparisonOperator.START_WITH:
|
||||
return left_str.startswith(str(value))
|
||||
if op == ComparisonOperator.END_WITH:
|
||||
return left_str.endswith(str(value))
|
||||
if op == ComparisonOperator.IN:
|
||||
return left_str in (value if isinstance(value, list) else [str(value)])
|
||||
if op == ComparisonOperator.NOT_IN:
|
||||
return left_str not in (value if isinstance(value, list) else [str(value)])
|
||||
if op == ComparisonOperator.EMPTY:
|
||||
return not left
|
||||
if op == ComparisonOperator.NOT_EMPTY:
|
||||
return bool(left)
|
||||
|
||||
raise ValueError(f"Unsupported operator: {op}")
|
||||
|
||||
@staticmethod
|
||||
def _safe_num(v) -> float:
|
||||
try:
|
||||
return float(v)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
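The operation order the node applies (filter -> extract -> order -> limit) can be illustrated with a standalone sketch on plain dicts, independent of the workflow engine and variable pool:

# Hedged, self-contained sketch of the filter -> extract -> order -> limit order.
files = [
    {"name": "a.pdf", "extension": "pdf", "size": 300},
    {"name": "b.txt", "extension": "txt", "size": 100},
    {"name": "c.pdf", "extension": "pdf", "size": 200},
]

result = list(files)

# 1. Filter: keep items whose "extension" contains "pdf"
result = [item for item in result if "pdf" in str(item.get("extension", ""))]

# 2. Extract: disabled here (when enabled it reduces the list to one item by 1-based index)

# 3. Order: sort by "size" descending
result = sorted(result, key=lambda item: item.get("size") or 0, reverse=True)

# 4. Limit: take the first 2
result = result[:2]

print({
    "result": result,
    "first_record": result[0] if result else None,
    "last_record": result[-1] if result else None,
})
# result keeps a.pdf (300) then c.pdf (200); first_record/last_record mirror the ends.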
@@ -135,8 +135,7 @@ class LLMNode(BaseNode):
|
||||
api_key=model_info.api_key,
|
||||
base_url=model_info.api_base,
|
||||
extra_params=extra_params,
|
||||
is_omni=model_info.is_omni,
|
||||
support_thinking="thinking" in (model_info.capability or []),
|
||||
is_omni=model_info.is_omni
|
||||
),
|
||||
type=model_info.model_type
|
||||
)
|
||||
@@ -214,9 +213,10 @@ class LLMNode(BaseNode):
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
# 使用简单的 prompt 格式(向后兼容)——包装为标准消息列表以兼容所有 provider
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
self.messages = self._render_template(prompt_template, variable_pool)
|
||||
rendered = self._render_template(prompt_template, variable_pool)
|
||||
self.messages = [{"role": "user", "content": rendered}]
|
||||
|
||||
return llm
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
from app.core.workflow.nodes.document_extractor import DocExtractorNode
|
||||
from app.core.workflow.nodes.list_operator import ListOperatorNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,7 +52,8 @@ WorkflowNode = Union[
|
||||
MemoryReadNode,
|
||||
MemoryWriteNode,
|
||||
CodeNode,
|
||||
DocExtractorNode
|
||||
DocExtractorNode,
|
||||
ListOperatorNode
|
||||
]
|
||||
|
||||
|
||||
@@ -83,7 +85,8 @@ class NodeFactory:
|
||||
NodeType.MEMORY_READ: MemoryReadNode,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNode
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -118,8 +118,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -71,8 +71,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
is_omni=is_omni,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/3/10 13:36
|
||||
import mimetypes
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
TRANSFORM_FILE_TYPE = {
|
||||
'text/plain': 'document/text',
|
||||
'text/markdown': 'document/markdown',
|
||||
@@ -52,5 +55,143 @@ ALLOWED_FILE_TYPES = [
|
||||
def mime_to_file_type(mime_type):
|
||||
if mime_type not in ALLOWED_FILE_TYPES:
|
||||
return None
|
||||
|
||||
return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
|
||||
|
||||
|
||||
def build_file_object_dict_from_url(url: str, file_type: str, origin_file_type: str) -> dict[str, Any]:
|
||||
"""Build a FileObject dict for a remote_url file using only URL parsing (no HTTP request).
|
||||
Used as fallback when HTTP request fails.
|
||||
"""
|
||||
raw_path = url.split("?")[0]
|
||||
name = unquote(os.path.basename(urlparse(url).path)) or None
|
||||
_, ext = os.path.splitext(name or "")
|
||||
extension = ext.lstrip(".").lower() if ext else None
|
||||
guessed_mime = mimetypes.guess_type(url)[0]
|
||||
return {
|
||||
"type": file_type,
|
||||
"url": url,
|
||||
"transfer_method": "remote_url",
|
||||
"origin_file_type": origin_file_type,
|
||||
"file_id": None,
|
||||
"name": name,
|
||||
"size": None,
|
||||
"extension": extension,
|
||||
"mime_type": guessed_mime or origin_file_type,
|
||||
"is_file": True,
|
||||
}
|
||||
|
||||
|
||||
async def fetch_remote_file_meta(
|
||||
url: str,
|
||||
file_type: str,
|
||||
origin_file_type: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch remote file metadata via HEAD (fallback GET) and build a FileObject dict.
|
||||
Falls back to URL-only parsing if the HTTP request fails.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
name = size = mime_type = extension = None
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.head(url, follow_redirects=True)
|
||||
if resp.status_code != 200:
|
||||
resp = await client.get(url, follow_redirects=True)
|
||||
|
||||
cl = resp.headers.get("Content-Length")
|
||||
size = int(cl) if cl else None
|
||||
|
||||
ct = resp.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
mime_type = ct or origin_file_type
|
||||
|
||||
cd = resp.headers.get("Content-Disposition", "")
|
||||
if "filename=" in cd:
|
||||
name = cd.split("filename=")[-1].strip('"').strip("'")
|
||||
if not name:
|
||||
name = unquote(os.path.basename(urlparse(url).path)) or None
|
||||
|
||||
if name:
|
||||
_, ext = os.path.splitext(name)
|
||||
extension = ext.lstrip(".").lower() if ext else None
|
||||
if not extension and mime_type:
|
||||
ext = mimetypes.guess_extension(mime_type)
|
||||
extension = ext.lstrip(".").lower() if ext else None
|
||||
except Exception:
|
||||
return build_file_object_dict_from_url(url, file_type, origin_file_type)
|
||||
|
||||
return build_file_object_dict_from_meta(
|
||||
file_type=file_type,
|
||||
transfer_method="remote_url",
|
||||
origin_file_type=origin_file_type,
|
||||
file_id=None,
|
||||
url=url,
|
||||
file_name=name,
|
||||
file_size=size,
|
||||
file_ext=extension,
|
||||
content_type=mime_type,
|
||||
)
|
||||
|
||||
|
||||
def build_file_object_dict_from_meta(
|
||||
file_type: str,
|
||||
transfer_method: str,
|
||||
origin_file_type: str,
|
||||
file_id: str,
|
||||
url: str,
|
||||
file_name: str | None,
|
||||
file_size: int | None,
|
||||
file_ext: str | None,
|
||||
content_type: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a FileObject dict from already-fetched FileMetadata fields."""
|
||||
ext = (file_ext or "").lstrip(".")
|
||||
return {
|
||||
"type": file_type,
|
||||
"url": url,
|
||||
"transfer_method": transfer_method,
|
||||
"origin_file_type": content_type or origin_file_type,
|
||||
"file_id": file_id,
|
||||
"name": file_name,
|
||||
"size": file_size,
|
||||
"extension": ext.lower() if ext else None,
|
||||
"mime_type": content_type,
|
||||
"is_file": True,
|
||||
}
|
||||
|
||||
|
||||
def resolve_local_file_object_dict(
|
||||
db,
|
||||
upload_file_id: str | uuid.UUID,
|
||||
file_type: str,
|
||||
origin_file_type: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Query FileMetadata and build a FileObject dict for a local_file.
|
||||
Returns None if the file is not found or not completed.
|
||||
"""
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.core.config import settings
|
||||
|
||||
try:
|
||||
fid = uuid.UUID(str(upload_file_id))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
meta = db.query(FileMetadata).filter(
|
||||
FileMetadata.id == fid,
|
||||
FileMetadata.status == "completed"
|
||||
).first()
|
||||
if not meta:
|
||||
return None
|
||||
|
||||
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{fid}"
|
||||
return build_file_object_dict_from_meta(
|
||||
file_type=file_type,
|
||||
transfer_method="local_file",
|
||||
origin_file_type=origin_file_type,
|
||||
file_id=str(fid),
|
||||
url=url,
|
||||
file_name=meta.file_name,
|
||||
file_size=meta.file_size,
|
||||
file_ext=meta.file_ext,
|
||||
content_type=meta.content_type,
|
||||
)
|
||||
|
||||
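A hedged usage sketch for the helpers above; the import path and the URL are assumptions, not confirmed by this diff, and the call is meant to run next to the helpers in the same module:

# Sketch only: module path and URL are placeholders.
import asyncio
# from app.core.workflow.utils.file_utils import fetch_remote_file_meta  # assumed path


async def main():
    meta = await fetch_remote_file_meta(
        url="https://example.com/files/report.pdf",
        file_type="file",
        origin_file_type="document/pdf",
    )
    # If the HEAD/GET request fails, the helper falls back to
    # build_file_object_dict_from_url, so name/extension still come from the URL path.
    print(meta["name"], meta["extension"], meta["mime_type"], meta["transfer_method"])


asyncio.run(main())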
@@ -301,7 +301,7 @@ class WorkflowValidator:
|
||||
for node in nodes:
|
||||
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
||||
f"节点 {node.get('name')} 缺少名称(发布时必须提供)"
|
||||
)
|
||||
|
||||
# 2. 验证所有非 start/end 节点都有配置
|
||||
@@ -311,7 +311,7 @@ class WorkflowValidator:
|
||||
config = node.get("config")
|
||||
if not config or not isinstance(config, dict):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
|
||||
f"节点 {node.get('name')} 缺少配置(发布时必须提供)"
|
||||
)
|
||||
|
||||
# 3. 验证必填变量
|
||||
|
||||
@@ -91,7 +91,7 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
|
||||
case VariableType.OBJECT:
|
||||
return {}
|
||||
case VariableType.FILE:
|
||||
return None
|
||||
return {}
|
||||
case VariableType.ARRAY_STRING:
|
||||
return []
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
@@ -113,6 +113,12 @@ class FileObject(BaseModel):
|
||||
origin_file_type: str
|
||||
file_id: str | None
|
||||
|
||||
# Extended file metadata
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
is_file: bool
|
||||
|
||||
|
||||
@@ -66,20 +66,10 @@ class FileVariable(BaseVariable):
|
||||
type = 'file'
|
||||
|
||||
def valid_value(self, value) -> FileObject:
|
||||
|
||||
if isinstance(value, dict):
|
||||
if not value.get("is_file"):
|
||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
return FileObject(
|
||||
**{
|
||||
"type": str(value.get('type')),
|
||||
"transfer_method": value.get("transfer_method"),
|
||||
"url": value.get('url'),
|
||||
"file_id": value.get("file_id"),
|
||||
"origin_file_type": value.get("origin_file_type"),
|
||||
"is_file": True
|
||||
}
|
||||
)
|
||||
return FileObject(**value)
|
||||
if isinstance(value, FileObject):
|
||||
return value
|
||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
@@ -88,7 +78,7 @@ class FileVariable(BaseVariable):
|
||||
return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})'
|
||||
|
||||
def get_value(self) -> Any:
|
||||
return self.value.model_dump()
|
||||
return self.value.model_dump(exclude={"content_cache"})
|
||||
|
||||
async def get_content(self):
|
||||
total_bytes = 0
|
||||
@@ -186,6 +176,8 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
return BooleanVariable(value)
|
||||
case VariableType.OBJECT:
|
||||
return DictVariable(value)
|
||||
case VariableType.FILE:
|
||||
return FileVariable(value)
|
||||
case VariableType.ARRAY_STRING:
|
||||
return make_array(StringVariable, value)
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
|
||||
@@ -62,6 +62,7 @@ async def lifespan(app: FastAPI):
|
||||
else:
|
||||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||
await create_all_indexes()
|
||||
logger.info("All neo4j indexes and constraints created successfully!")
|
||||
logger.info("应用程序启动完成")
|
||||
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import asyncio
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def create_fulltext_indexes():
|
||||
"""Create full-text indexes for keyword search with BM25 scoring."""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
|
||||
|
||||
# 创建 Statements 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# # 创建 Dialogues 索引
|
||||
# await connector.execute_query("""
|
||||
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
|
||||
@@ -21,27 +21,35 @@ async def create_fulltext_indexes():
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# 创建 Chunks 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# 创建 MemorySummary 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
""")
|
||||
# 创建 Community 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
|
||||
# 创建 Perceptual 感知记忆索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_vector_indexes():
|
||||
"""Create vector indexes for fast embedding similarity search.
|
||||
|
||||
@@ -50,8 +58,7 @@ async def create_vector_indexes():
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
|
||||
|
||||
|
||||
# Statement embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
|
||||
@@ -62,8 +69,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
|
||||
# Chunk embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
|
||||
@@ -75,7 +81,6 @@ async def create_vector_indexes():
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
# Entity name embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
|
||||
@@ -86,8 +91,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
|
||||
# Memory summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
|
||||
@@ -98,7 +102,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
# Community summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||
@@ -108,8 +112,8 @@ async def create_vector_indexes():
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# Dialogue embedding index (optional)
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
|
||||
@@ -120,15 +124,27 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
# Perceptual summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS
|
||||
FOR (p:Perceptual)
|
||||
ON p.summary_embedding
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_unique_constraints():
|
||||
"""Create uniqueness constraints for core node identifiers.
|
||||
Ensures concurrent MERGE operations remain safe and prevents duplicates.
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
try:
|
||||
# Dialogue.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -136,7 +152,7 @@ async def create_unique_constraints():
|
||||
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# Statement.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -144,7 +160,7 @@ async def create_unique_constraints():
|
||||
FOR (s:Statement) REQUIRE s.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# Chunk.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -152,13 +168,13 @@ async def create_unique_constraints():
|
||||
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_all_indexes():
|
||||
"""Create all indexes and constraints in one go."""
|
||||
await create_fulltext_indexes()
|
||||
await create_vector_indexes()
|
||||
await create_unique_constraints()
|
||||
print("✓ All indexes and constraints created successfully!")
|
||||
|
||||
|
||||
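A quick, hedged way to exercise the index bootstrap outside of the FastAPI lifespan (this is the same call the lifespan hook makes; the module path is an assumption and Neo4j must be reachable with the credentials in settings):

# Standalone run of the index/constraint bootstrap.
import asyncio
# from app.repositories.neo4j.indexes import create_all_indexes  # assumed module path

if __name__ == "__main__":
    asyncio.run(create_all_indexes())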
@@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
|
||||
r.created_at = edge.created_at
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
|
||||
WHERE p.end_user_id = $end_user_id
|
||||
RETURN p.id AS id,
|
||||
p.end_user_id AS end_user_id,
|
||||
p.perceptual_type AS perceptual_type,
|
||||
p.file_path AS file_path,
|
||||
p.file_name AS file_name,
|
||||
p.file_ext AS file_ext,
|
||||
p.summary AS summary,
|
||||
p.keywords AS keywords,
|
||||
p.topic AS topic,
|
||||
p.domain AS domain,
|
||||
p.created_at AS created_at,
|
||||
p.file_type AS file_type,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
PERCEPTUAL_EMBEDDING_SEARCH = """
|
||||
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
|
||||
YIELD node AS p, score
|
||||
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
|
||||
RETURN p.id AS id,
|
||||
p.end_user_id AS end_user_id,
|
||||
p.perceptual_type AS perceptual_type,
|
||||
p.file_path AS file_path,
|
||||
p.file_name AS file_name,
|
||||
p.file_ext AS file_ext,
|
||||
p.summary AS summary,
|
||||
p.keywords AS keywords,
|
||||
p.topic AS topic,
|
||||
p.domain AS domain,
|
||||
p.created_at AS created_at,
|
||||
p.file_type AS file_type,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
EXPAND_COMMUNITY_STATEMENTS,
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
PERCEPTUAL_EMBEDDING_SEARCH,
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _update_activation_values_batch(
|
||||
connector: Neo4jConnector,
|
||||
nodes: List[Dict[str, Any]],
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
max_retries: int = 3
|
||||
connector: Neo4jConnector,
|
||||
nodes: List[Dict[str, Any]],
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
max_retries: int = 3
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量更新节点的激活值
|
||||
@@ -58,7 +60,7 @@ async def _update_activation_values_batch(
|
||||
"""
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
|
||||
AccessHistoryManager,
|
||||
@@ -66,7 +68,7 @@ async def _update_activation_values_batch(
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
|
||||
ACTRCalculator,
|
||||
)
|
||||
|
||||
|
||||
# 创建计算器和管理器实例
|
||||
actr_calculator = ACTRCalculator()
|
||||
access_manager = AccessHistoryManager(
|
||||
@@ -74,7 +76,7 @@ async def _update_activation_values_batch(
|
||||
actr_calculator=actr_calculator,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
|
||||
# 提取节点ID列表并去重(保持原始顺序)
|
||||
seen_ids = set()
|
||||
unique_node_ids = []
|
||||
@@ -83,7 +85,7 @@ async def _update_activation_values_batch(
|
||||
if node_id and node_id not in seen_ids:
|
||||
seen_ids.add(node_id)
|
||||
unique_node_ids.append(node_id)
|
||||
|
||||
|
||||
if not unique_node_ids:
|
||||
logger.warning(f"批量更新激活值:没有有效的节点ID")
|
||||
return nodes
|
||||
@@ -95,7 +97,7 @@ async def _update_activation_values_batch(
|
||||
f"批量更新激活值:检测到重复节点,具有有效ID的节点数量={id_nodes_count}, "
|
||||
f"去重后唯一ID数量={len(unique_node_ids)}"
|
||||
)
|
||||
|
||||
|
||||
# 批量记录访问
|
||||
try:
|
||||
updated_nodes = await access_manager.record_batch_access(
|
||||
@@ -103,14 +105,14 @@ async def _update_activation_values_batch(
|
||||
node_label=node_label,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"批量更新激活值成功: {node_label}, "
|
||||
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
|
||||
)
|
||||
|
||||
|
||||
return updated_nodes
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"批量更新激活值失败: {node_label}, 错误: {str(e)}"
|
||||
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
|
||||
|
||||
|
||||
async def _update_search_results_activation(
|
||||
connector: Neo4jConnector,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
end_user_id: Optional[str] = None
|
||||
connector: Neo4jConnector,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
更新搜索结果中所有知识节点的激活值
|
||||
@@ -144,11 +146,11 @@ async def _update_search_results_activation(
|
||||
'entities': 'ExtractedEntity',
|
||||
'summaries': 'MemorySummary'
|
||||
}
|
||||
|
||||
|
||||
# 并行更新所有类型的节点
|
||||
update_tasks = []
|
||||
update_keys = []
|
||||
|
||||
|
||||
for key, label in knowledge_node_types.items():
|
||||
if key in results and results[key]:
|
||||
update_tasks.append(
|
||||
@@ -160,13 +162,13 @@ async def _update_search_results_activation(
|
||||
)
|
||||
)
|
||||
update_keys.append(key)
|
||||
|
||||
|
||||
if not update_tasks:
|
||||
return results
|
||||
|
||||
|
||||
# 并行执行所有更新
|
||||
update_results = await asyncio.gather(*update_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 更新结果字典,保留原始搜索分数
|
||||
updated_results = results.copy()
|
||||
for key, update_result in zip(update_keys, update_results):
|
||||
@@ -175,10 +177,10 @@ async def _update_search_results_activation(
|
||||
# 保留原始的 score 字段(BM25/Embedding 分数)
|
||||
original_nodes = results[key]
|
||||
updated_nodes = update_result
|
||||
|
||||
|
||||
# 创建 ID 到更新节点的映射(用于快速查找激活值数据)
|
||||
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
|
||||
|
||||
|
||||
# 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充
|
||||
merged_nodes = []
|
||||
for original_node in original_nodes:
|
||||
@@ -186,7 +188,7 @@ async def _update_search_results_activation(
|
||||
if node_id and node_id in updated_map:
|
||||
# 从原始节点开始,用更新后的激活值数据覆盖
|
||||
merged_node = original_node.copy()
|
||||
|
||||
|
||||
# 更新激活值相关字段
|
||||
activation_fields = {
|
||||
'activation_value',
|
||||
@@ -196,35 +198,35 @@ async def _update_search_results_activation(
|
||||
'importance_score',
|
||||
'version',
|
||||
'statement', # Statement 节点的内容字段
|
||||
'content' # MemorySummary 节点的内容字段
|
||||
'content' # MemorySummary 节点的内容字段
|
||||
}
|
||||
|
||||
|
||||
# 只更新激活值相关字段,保留原始节点的其他字段
|
||||
for field in activation_fields:
|
||||
if field in updated_map[node_id]:
|
||||
merged_node[field] = updated_map[node_id][field]
|
||||
|
||||
|
||||
merged_nodes.append(merged_node)
|
||||
else:
|
||||
# 如果没有更新数据,保留原始节点
|
||||
merged_nodes.append(original_node)
|
||||
|
||||
|
||||
updated_results[key] = merged_nodes
|
||||
else:
|
||||
# 更新失败,记录错误但保留原始结果
|
||||
logger.warning(
|
||||
f"更新 {key} 激活值失败: {str(update_result)}"
|
||||
)
|
||||
|
||||
|
||||
return updated_results
|
||||
|
||||
|
||||
async def search_graph(
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = None,
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
|
||||
@@ -249,41 +251,45 @@ async def search_graph(
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
|
||||
# Prepare tasks for parallel execution
|
||||
tasks = []
|
||||
task_keys = []
|
||||
|
||||
|
||||
if "statements" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("statements")
|
||||
|
||||
|
||||
if "entities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("entities")
|
||||
|
||||
|
||||
if "chunks" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("chunks")
|
||||
|
||||
|
||||
if "summaries" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -293,15 +299,16 @@ async def search_graph(
|
||||
if "communities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("communities")
|
||||
|
||||
|
||||
# Execute all queries in parallel
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# Build results dictionary
|
||||
results = {}
|
||||
for key, result in zip(task_keys, task_results):
|
||||
@@ -310,14 +317,14 @@ async def search_graph(
|
||||
results[key] = []
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
|
||||
# Deduplicate results before updating activation values
|
||||
# This prevents duplicates from propagating through the pipeline
|
||||
from app.core.memory.src.search import _deduplicate_results
|
||||
for key in results:
|
||||
if isinstance(results[key], list):
|
||||
results[key] = _deduplicate_results(results[key])
|
||||
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
# Skip activation updates if only searching summaries (optimization)
|
||||
needs_activation_update = any(
|
||||
@@ -331,17 +338,17 @@ async def search_graph(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_embedding(
|
||||
connector: Neo4jConnector,
|
||||
embedder_client,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = ["statements", "chunks", "entities","summaries"],
|
||||
connector: Neo4jConnector,
|
||||
embedder_client,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = ["statements", "chunks", "entities", "summaries"],
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Embedding-based semantic search across Statements, Chunks, and Entities.
|
||||
@@ -355,13 +362,13 @@ async def search_graph_by_embedding(
|
||||
- Returns up to 'limit' per included type
|
||||
"""
|
||||
import time
|
||||
|
||||
|
||||
# Get embedding for the query
|
||||
embed_start = time.time()
|
||||
embeddings = await embedder_client.response([query_text])
|
||||
embed_time = time.time() - embed_start
|
||||
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
|
||||
logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
|
||||
if not embeddings or not embeddings[0]:
|
||||
logger.warning(
|
||||
f"search_graph_by_embedding: embedding 生成失败或为空,"
|
||||
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
|
||||
if "statements" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
STATEMENT_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
|
||||
if "chunks" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
|
||||
if "entities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
|
||||
if "summaries" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
|
||||
if "communities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
COMMUNITY_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -428,8 +440,8 @@ async def search_graph_by_embedding(
|
||||
query_start = time.time()
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
query_time = time.time() - query_start
|
||||
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||
|
||||
logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||
|
||||
# Build results dictionary
|
||||
results: Dict[str, List[Dict[str, Any]]] = {
|
||||
"statements": [],
|
||||
@@ -438,7 +450,7 @@ async def search_graph_by_embedding(
|
||||
"summaries": [],
|
||||
"communities": [],
|
||||
}
|
||||
|
||||
|
||||
for key, result in zip(task_keys, task_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}")
|
||||
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
|
||||
logger.info(f"[PERF] Skipping activation updates (only summaries)")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
entities: List[Dict[str, Any]],
|
||||
use_contains_fallback: bool = True,
|
||||
batch_size: int = 500,
|
||||
max_concurrency: int = 5,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
entities: List[Dict[str, Any]],
|
||||
use_contains_fallback: bool = True,
|
||||
batch_size: int = 500,
|
||||
max_concurrency: int = 5,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries):
|
||||
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
||||
|
||||
|
||||
async def search_graph_by_keyword_temporal(
|
||||
connector: Neo4jConnector,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
connector: Neo4jConnector,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
Temporal keyword search across Statements.
|
||||
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
if not query_text:
|
||||
print(f"query_text不能为空")
|
||||
logger.warning(f"query_text不能为空")
|
||||
return {"statements": []}
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
|
||||
invalid_date=invalid_date,
|
||||
limit=limit,
|
||||
)
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
|
||||
|
||||
|
||||
async def search_graph_by_temporal(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -643,15 +653,15 @@ async def search_graph_by_temporal(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_dialog_id(
|
||||
connector: Neo4jConnector,
|
||||
dialog_id: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
dialog_id: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Dialogues.
|
||||
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
|
||||
- Returns up to 'limit' dialogues
|
||||
"""
|
||||
if not dialog_id:
|
||||
print(f"dialog_id不能为空")
|
||||
logger.warning(f"dialog_id不能为空")
|
||||
return {"dialogues": []}
|
||||
|
||||
dialogues = await connector.execute_query(
|
||||
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
|
||||
|
||||
|
||||
async def search_graph_by_chunk_id(
|
||||
connector: Neo4jConnector,
|
||||
chunk_id : str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
chunk_id: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
if not chunk_id:
|
||||
print(f"chunk_id不能为空")
|
||||
logger.warning(f"chunk_id不能为空")
|
||||
return {"chunks": []}
|
||||
chunks = await connector.execute_query(
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
|
||||
|
||||
|
||||
async def search_graph_community_expand(
|
||||
connector: Neo4jConnector,
|
||||
community_ids: List[str],
|
||||
end_user_id: str,
|
||||
limit: int = 10,
|
||||
connector: Neo4jConnector,
|
||||
community_ids: List[str],
|
||||
end_user_id: str,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
三期:社区展开检索 —— 主题 → 细节两级检索。
|
||||
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
|
||||
|
||||
|
||||
async def search_graph_by_created_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -767,16 +776,11 @@ async def search_graph_by_created_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -784,16 +788,16 @@ async def search_graph_by_created_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -807,16 +811,11 @@ async def search_graph_by_valid_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -824,16 +823,16 @@ async def search_graph_by_valid_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_g_created_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -847,16 +846,11 @@ async def search_graph_g_created_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_G_CREATED_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -864,16 +858,16 @@ async def search_graph_g_created_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_g_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_G_VALID_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -904,16 +892,16 @@ async def search_graph_g_valid_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_l_created_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -927,16 +915,11 @@ async def search_graph_l_created_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -944,16 +927,16 @@ async def search_graph_l_created_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_l_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -967,16 +950,11 @@ async def search_graph_l_valid_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -984,5 +962,89 @@ async def search_graph_l_valid_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_perceptual(
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Search Perceptual memory nodes using fulltext keyword search.
|
||||
|
||||
Matches against summary, topic, and domain fields via the perceptualFulltext index.
|
||||
|
||||
Args:
|
||||
connector: Neo4j connector
|
||||
q: Query text
|
||||
end_user_id: Optional user filter
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
|
||||
"""
|
||||
try:
|
||||
perceptuals = await connector.execute_query(
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"search_perceptual: keyword search failed: {e}")
|
||||
perceptuals = []
|
||||
|
||||
# Deduplicate
|
||||
from app.core.memory.src.search import _deduplicate_results
|
||||
perceptuals = _deduplicate_results(perceptuals)
|
||||
|
||||
return {"perceptuals": perceptuals}
|
||||
|
||||
|
||||
async def search_perceptual_by_embedding(
    connector: Neo4jConnector,
    embedder_client,
    query_text: str,
    end_user_id: Optional[str] = None,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Search Perceptual memory nodes using embedding-based semantic search.

    Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.

    Args:
        connector: Neo4j connector
        embedder_client: Embedding client with async response() method
        query_text: Query text to embed
        end_user_id: Optional user filter
        limit: Max results

    Returns:
        Dictionary with 'perceptuals' key containing matched perceptual memory nodes
    """
    embeddings = await embedder_client.response([query_text])
    if not embeddings or not embeddings[0]:
        logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
        return {"perceptuals": []}

    embedding = embeddings[0]

    try:
        perceptuals = await connector.execute_query(
            PERCEPTUAL_EMBEDDING_SEARCH,
            embedding=embedding,
            end_user_id=end_user_id,
            limit=limit,
        )
    except Exception as e:
        logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
        perceptuals = []

    from app.core.memory.src.search import _deduplicate_results
    perceptuals = _deduplicate_results(perceptuals)

    return {"perceptuals": perceptuals}
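A minimal usage sketch for the two perceptual search helpers above, assuming an already-connected Neo4jConnector and an embedder client with an async response() method (the caller and the query text are illustrative, not part of this change):

    async def demo_perceptual_search(connector, embedder_client):
        # Fulltext lookup over summary / topic / domain fields.
        by_keyword = await search_perceptual(connector, q="travel plans", end_user_id="user-1", limit=5)
        # Semantic lookup over summary_embedding; falls back to an empty list if the vector query fails.
        by_vector = await search_perceptual_by_embedding(
            connector, embedder_client, "travel plans", end_user_id="user-1", limit=5
        )
        return by_keyword["perceptuals"], by_vector["perceptuals"]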
@@ -11,10 +11,28 @@ Classes:
from typing import Any, List, Dict

from neo4j import AsyncGraphDatabase, basic_auth
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration

from app.core.config import settings


def _convert_neo4j_types(value: Any) -> Any:
    """递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。"""
    if isinstance(value, Neo4jDateTime):
        return value.to_native().isoformat() if value.tzinfo else value.iso_format()
    if isinstance(value, Neo4jDate):
        return value.iso_format()
    if isinstance(value, Neo4jTime):
        return value.iso_format()
    if isinstance(value, Neo4jDuration):
        return str(value)
    if isinstance(value, dict):
        return {k: _convert_neo4j_types(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_convert_neo4j_types(item) for item in value]
    return value
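A small sketch of what _convert_neo4j_types does to a nested driver record (the record contents are illustrative, not part of this change):

    import json
    from neo4j.time import DateTime as Neo4jDateTime

    record = {
        "statement": {"fact": "moved to Berlin", "valid_at": Neo4jDateTime(2026, 1, 5, 12, 0, 0)},
        "tags": ["location"],
    }
    clean = _convert_neo4j_types(record)
    json.dumps(clean)  # temporal values are now ISO-8601 strings, so serialization no longer raises TypeError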
class Neo4jConnector:
    """Neo4j数据库连接器

@@ -59,11 +77,12 @@ class Neo4jConnector:
        """
        await self.driver.close()

    async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
    async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
        """执行Cypher查询

        Args:
            query: Cypher查询语句
            json_format: json格式化
            **kwargs: 查询参数,将作为参数传递给Cypher查询

        Returns:
@@ -78,7 +97,10 @@ class Neo4jConnector:
            **kwargs
        )
        records, summary, keys = result
        return [record.data() for record in records]
        if json_format:
            return [_convert_neo4j_types(record.data()) for record in records]
        else:
            return [record.data() for record in records]

    async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
        """在写事务中执行操作
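A minimal sketch of the new json_format flag on execute_query (the Cypher query is a placeholder):

    import json

    async def dump_statements(connector: Neo4jConnector) -> str:
        rows = await connector.execute_query(
            "MATCH (s:Statement) RETURN s LIMIT 5",  # placeholder query
            json_format=True,  # run results through _convert_neo4j_types before returning
        )
        return json.dumps(rows, ensure_ascii=False)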
@@ -4,10 +4,6 @@ from typing import Optional, Any, List, Dict, Union
from enum import Enum, StrEnum

from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator

from app.schemas.workflow_schema import WorkflowConfigCreate


# ---------- Multimodal File Support ----------

class FileType(StrEnum):
@@ -317,7 +313,7 @@ class AppCreate(BaseModel):
    # only for type=multi_agent
    multi_agent_config: Optional[Dict[str, Any]] = None

    workflow_config: Optional[WorkflowConfigCreate] = None
    workflow_config: Optional[Dict[str, Any]] = None


class AppUpdate(BaseModel):
@@ -644,6 +640,7 @@ class CitationSource(BaseModel):
class DraftRunResponse(BaseModel):
    """试运行响应(非流式)"""
    message: str = Field(..., description="AI 回复消息")
    reasoning_content: Optional[str] = Field(default=None, description="深度思考内容")
    conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
    usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
    elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
@@ -651,6 +648,12 @@ class DraftRunResponse(BaseModel):
    citations: List[CitationSource] = Field(default_factory=list, description="引用来源")
    audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")

    def model_dump(self, **kwargs):
        data = super().model_dump(**kwargs)
        if not data.get("reasoning_content"):
            data.pop("reasoning_content", None)
        return data


class OpeningResponse(BaseModel):
    """应用开场白响应"""
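A short sketch of the overridden model_dump on DraftRunResponse (field values are illustrative): when reasoning_content is empty it is dropped from the serialized payload, otherwise it is kept.

    resp = DraftRunResponse(message="hello")
    assert "reasoning_content" not in resp.model_dump()

    resp = DraftRunResponse(message="hello", reasoning_content="step-by-step reasoning")
    assert resp.model_dump()["reasoning_content"] == "step-by-step reasoning"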
@@ -401,7 +401,7 @@ class AppService:
    def _create_workflow_config(
        self,
        app_id: uuid.UUID,
        data: app_schema.WorkflowConfigCreate,
        data,
        now: datetime.datetime
    ):
        workflow_cfg = WorkflowConfig(
@@ -678,7 +678,9 @@ class AppService:
            self._create_multi_agent_config(app.id, data.multi_agent_config, now)

        if app.type == "workflow" and data.workflow_config:
            self._create_workflow_config(app.id, data.workflow_config, now)
            from app.schemas.workflow_schema import WorkflowConfigCreate
            wf_data = WorkflowConfigCreate(**data.workflow_config) if isinstance(data.workflow_config, dict) else data.workflow_config
            self._create_workflow_config(app.id, wf_data, now)

        self.db.commit()
        self.db.refresh(app)
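A minimal sketch of the dict-to-model coercion above, which is needed now that AppCreate.workflow_config is a plain dict (the config keys shown are assumptions for illustration only):

    from app.schemas.workflow_schema import WorkflowConfigCreate

    raw = {"nodes": [], "edges": [], "variables": []}  # illustrative minimal payload
    wf_data = WorkflowConfigCreate(**raw) if isinstance(raw, dict) else raw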
@@ -462,11 +462,6 @@ class MemoryAgentService:

        logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")

        # 导入审计日志记录器

        config_load_start = time.time()
        try:
            # Use a separate database session to avoid transaction failures
@@ -507,10 +502,13 @@ class MemoryAgentService:
        async with make_read_graph() as graph:
            config = {"configurable": {"thread_id": end_user_id}}
            # 初始状态 - 包含所有必要字段
            initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
                             "end_user_id": end_user_id
                , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
                             "memory_config": memory_config}
            initial_state = {
                "messages": [HumanMessage(content=message)],
                "search_switch": search_switch,
                "end_user_id": end_user_id
                , "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id,
                "memory_config": memory_config}
            # 获取节点更新信息
            _intermediate_outputs = []
            summary = ''
@@ -522,7 +520,7 @@ class MemoryAgentService:
            for node_name, node_data in update_event.items():
                # if 'save_neo4j' == node_name:
                #     massages = node_data
                print(f"处理节点: {node_name}")
                logger.info(f"处理节点: {node_name}")

                # 处理不同Summary节点的返回结构
                if 'Summary' in node_name:
@@ -549,6 +547,11 @@ class MemoryAgentService:
                if retrieve_node and retrieve_node != [] and retrieve_node != {}:
                    _intermediate_outputs.extend(retrieve_node)

                # Perceptual_Retrieve 节点
                perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None)
                if perceptual_node and perceptual_node != [] and perceptual_node != {}:
                    _intermediate_outputs.append(perceptual_node)

                # Verify 节点
                verify_n = node_data.get('verify', {}).get('_intermediate', None)
                if verify_n and verify_n != [] and verify_n != {}:
@@ -16,7 +16,6 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config
from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
@@ -453,22 +452,70 @@ class WorkflowService:
            "success_rate": completed / total if total > 0 else 0
        }

    async def _resolve_variables_file_defaults(
        self,
        variables: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Convert FileInput-format defaults in workflow variables to full FileObject dicts."""
        from app.core.workflow.utils.file_processor import (
            resolve_local_file_object_dict,
            fetch_remote_file_meta,
        )

        async def _resolve_one(item: dict) -> dict | None:
            if not isinstance(item, dict) or item.get("is_file"):
                return item
            transfer_method = item.get("transfer_method", "remote_url")
            file_type = item.get("type", "document")
            origin_file_type = item.get("file_type") or file_type
            if transfer_method == "remote_url":
                url = item.get("url", "")
                return await fetch_remote_file_meta(url, file_type, origin_file_type) if url else None
            else:
                return resolve_local_file_object_dict(self.db, item.get("upload_file_id"), file_type, origin_file_type)

        result = []
        for var_def in variables:
            var_type = var_def.get("type", "")
            default = var_def.get("default")
            if var_type == "file" and isinstance(default, dict) and not default.get("is_file"):
                var_def = {**var_def, "default": await _resolve_one(default)}
            elif var_type == "array[file]" and isinstance(default, list):
                resolved = []
                for item in default:
                    r = await _resolve_one(item)
                    if r is not None:
                        resolved.append(r)
                var_def = {**var_def, "default": resolved}
            result.append(var_def)
        return result
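A minimal sketch of the variable shape handled by _resolve_variables_file_defaults (the variable name and URL are placeholders; service is a WorkflowService instance):

    async def demo(service) -> list[dict]:
        variables = [{
            "name": "attachment",  # placeholder variable name
            "type": "file",
            "default": {"transfer_method": "remote_url", "type": "document",
                        "url": "https://example.com/a.pdf"},  # placeholder URL
        }]
        # The default is rewritten into a full FileObject dict (is_file=True plus
        # resolved file metadata) that the workflow executor can consume directly.
        return await service._resolve_variables_file_defaults(variables)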
    async def _handle_file_input(self, files: list[FileInput]):
        if not files:
            return []

        from app.core.workflow.utils.file_processor import (
            resolve_local_file_object_dict,
            build_file_object_dict_from_meta,
            fetch_remote_file_meta,
        )

        files_struct = []
        for file in files:
            files_struct.append(
                FileObject(
                    type=file.type,
                    url=await self.multimodal_service.get_file_url(file),
                    transfer_method=file.transfer_method,
                    file_id=str(file.upload_file_id) if file.upload_file_id else None,
                    origin_file_type=file.file_type,
                    is_file=True
                ).model_dump()
            )
            url = await self.multimodal_service.get_file_url(file)
            file_type = str(file.type)
            origin_file_type = file.file_type or file_type

            if file.transfer_method.value == "local_file" and file.upload_file_id:
                fo = resolve_local_file_object_dict(self.db, file.upload_file_id, file_type, origin_file_type)
                files_struct.append(fo or build_file_object_dict_from_meta(
                    file_type=file_type, transfer_method="local_file",
                    origin_file_type=origin_file_type,
                    file_id=str(file.upload_file_id), url=url,
                    file_name=None, file_size=None, file_ext=None, content_type=None,
                ))
            else:
                files_struct.append(await fetch_remote_file_meta(url, file_type, origin_file_type))
        return files_struct

    @staticmethod
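A short sketch of how the reworked _handle_file_input routes the two transfer methods (the FileInput arguments are illustrative): local uploads are resolved from the database via resolve_local_file_object_dict, with build_file_object_dict_from_meta as a fallback, while remote URLs go through fetch_remote_file_meta.

    async def demo_file_input(service, local_file, remote_file):
        # local_file.transfer_method == "local_file" -> metadata comes from the uploaded record
        # remote_file.transfer_method == "remote_url" -> metadata is fetched from the URL
        return await service._handle_file_input([local_file, remote_file])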
@@ -545,6 +592,12 @@ class WorkflowService:
    def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
        storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
        user_rag_memory_id = ""
        # 如果 storage_type 为 None,使用默认值 'neo4j'
        if not storage_type:
            storage_type = 'neo4j'
            logger.warning(
                f"Storage type not set for workspace {workspace_id}, using default: neo4j"
            )
        if storage_type == "rag":
            knowledge = knowledge_repository.get_knowledge_by_name(
                db=self.db,
@@ -659,6 +712,26 @@ class WorkflowService:
            input_data["conv_messages"] = conv_messages
            init_message_length = len(input_data.get("conv_messages", []))

            # 新会话时写入开场白
            is_new_conversation = init_message_length == 0
            if is_new_conversation:
                opening_cfg = feature_configs.get("opening_statement", {})
                if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"):
                    statement = opening_cfg["statement"]
                    suggested_questions = opening_cfg.get("suggested_questions", [])
                    if payload.variables:
                        for var_name, var_value in payload.variables.items():
                            statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value))
                    self.conversation_service.add_message(
                        conversation_id=conversation_id_uuid,
                        role="assistant",
                        content=statement,
                        meta_data={"suggested_questions": suggested_questions}
                    )
                    # 注入到 conv_messages,让 LLM 感知开场白
                    input_data["conv_messages"] = [{"role": "assistant", "content": statement}]
                    init_message_length = 1

            result = await execute_workflow(
                workflow_config=workflow_config_dict,
                input_data=input_data,
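A small sketch of the {{variable}} substitution applied to the opening statement before it is stored (the template and values are illustrative):

    statement = "Hello {{user_name}}, what would you like to work on?"  # illustrative template
    variables = {"user_name": "Alice"}
    for var_name, var_value in variables.items():
        statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value))
    assert statement == "Hello Alice, what would you like to work on?"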
@@ -696,12 +769,21 @@ class WorkflowService:
                content=human_message,
                meta_data=human_meta
            )
            # 过滤 citations
            citations = result.get("citations", [])
            citation_cfg = feature_configs.get("citation", {})
            filtered_citations = (
                citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else []
            )
            assistant_meta = {"usage": token_usage, "audio_url": None}
            if filtered_citations:
                assistant_meta["citations"] = filtered_citations
            self.conversation_service.add_message(
                message_id=message_id,
                conversation_id=conversation_id_uuid,
                role="assistant",
                content=assistant_message,
                meta_data={"usage": token_usage, "audio_url": None}
                meta_data=assistant_meta
            )
            self.update_execution_status(
                execution.execution_id,
@@ -720,6 +802,7 @@ class WorkflowService:
            )
            logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id},"
                         f" error: {result.get('error')}")
            filtered_citations = []

        # 返回增强的响应结构
        return {
@@ -734,7 +817,8 @@ class WorkflowService:
            "conversation_id": result.get("conversation_id"),  # 会话 ID
            "error_message": result.get("error"),
            "elapsed_time": result.get("elapsed_time"),
            "token_usage": result.get("token_usage")
            "token_usage": result.get("token_usage"),
            "citations": filtered_citations,
        }

    except Exception as e:
@@ -825,6 +909,27 @@ class WorkflowService:
            input_data["conv_messages"] = conv_messages
            init_message_length = len(input_data.get("conv_messages", []))
            message_id = uuid.uuid4()

            # 新会话时写入开场白
            is_new_conversation = init_message_length == 0
            if is_new_conversation:
                opening_cfg = feature_configs.get("opening_statement", {})
                if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"):
                    statement = opening_cfg["statement"]
                    suggested_questions = opening_cfg.get("suggested_questions", [])
                    if payload.variables:
                        for var_name, var_value in payload.variables.items():
                            statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value))
                    self.conversation_service.add_message(
                        conversation_id=conversation_id_uuid,
                        role="assistant",
                        content=statement,
                        meta_data={"suggested_questions": suggested_questions}
                    )
                    # 注入到 conv_messages,让 LLM 感知开场白
                    input_data["conv_messages"] = [{"role": "assistant", "content": statement}]
                    init_message_length = 1
            async for event in execute_workflow_stream(
                workflow_config=workflow_config_dict,
                input_data=input_data,
@@ -862,12 +967,21 @@ class WorkflowService:
                        content=human_message,
                        meta_data=human_meta
                    )
                    # 过滤 citations
                    citations = event.get("data", {}).get("citations", [])
                    citation_cfg = feature_configs.get("citation", {})
                    filtered_citations = (
                        citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else []
                    )
                    assistant_meta = {"usage": token_usage, "audio_url": None}
                    if filtered_citations:
                        assistant_meta["citations"] = filtered_citations
                    self.conversation_service.add_message(
                        message_id=message_id,
                        conversation_id=conversation_id_uuid,
                        role="assistant",
                        content=assistant_message,
                        meta_data={"usage": token_usage, "audio_url": None}
                        meta_data=assistant_meta
                    )
                    self.update_execution_status(
                        execution.execution_id,
@@ -875,6 +989,7 @@ class WorkflowService:
                        output_data=event.get("data"),
                        token_usage=token_usage.get("total_tokens", None)
                    )
                    event.setdefault("data", {})["citations"] = filtered_citations
                    logger.info(f"Workflow Run Success, "
                                f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
                elif status == "failed":
@@ -480,21 +480,21 @@ def create_workspace_invite(
    try:
        # 检查权限
        _check_workspace_admin_permission(db, workspace_id, user)
        if settings.ENABLE_SINGLE_WORKSPACE:
            # 检查被邀请用户是否已经在工作空间中
            from app.repositories import user_repository
            invited_user = user_repository.get_user_by_email(db, invite_data.email)
        # if settings.ENABLE_SINGLE_WORKSPACE:
        # 检查被邀请用户是否已经在工作空间中
        from app.repositories import user_repository
        invited_user = user_repository.get_user_by_email(db, invite_data.email)

            if invited_user:
                # 用户存在,检查是否已经是工作空间成员
                existing_member = workspace_repository.get_member_in_workspace(
                    db=db,
                    user_id=invited_user.id,
                    workspace_id=workspace_id
                )
                if existing_member:
                    business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
                    raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
        if invited_user:
            # 用户存在,检查是否已经是工作空间成员
            existing_member = workspace_repository.get_member_in_workspace(
                db=db,
                user_id=invited_user.id,
                workspace_id=workspace_id
            )
            if existing_member:
                business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
                raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)

        # 检查是否已有待处理的邀请
        invite_repo = WorkspaceInviteRepository(db)
@@ -153,7 +153,8 @@ def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig:
        edges=config_dict.get("edges", []),
        variables=config_dict.get("variables", []),
        execution_config=config_dict.get("execution_config", {}),
        triggers=config_dict.get("triggers", [])
        triggers=config_dict.get("triggers", []),
        features=config_dict.get("features", {})
    )

    return config
@@ -11,13 +11,13 @@ import { ContentEditable } from '@lexical/react/LexicalContentEditable';
import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin';
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary';

import AutocompletePlugin, { type Suggestion } from './plugin/AutocompletePlugin';
import { type Suggestion } from './plugin/AutocompletePlugin';
import CharacterCountPlugin from './plugin/CharacterCountPlugin';
import InitialValuePlugin from './plugin/InitialValuePlugin';
import CommandPlugin from './plugin/CommandPlugin';
import Jinja2InitialValuePlugin from './plugin/Jinja2InitialValuePlugin';
import Jinja2AutocompletePlugin from './plugin/Jinja2AutocompletePlugin';
import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin';
import Jinja2BlurPlugin from './plugin/Jinja2BlurPlugin';
import LineNumberPlugin from './plugin/LineNumberPlugin';
import BlurPlugin from './plugin/BlurPlugin';

const jinja2Theme = {
  paragraph: 'editor-paragraph',
@@ -171,13 +171,12 @@ const Jinja2Editor: FC<Jinja2EditorProps> = ({
          ErrorBoundary={LexicalErrorBoundary}
        />
        <HistoryPlugin />
        <CommandPlugin />
        <Jinja2HighlightPlugin />
        <LineNumberPlugin />
        <AutocompletePlugin options={options} enableJinja2 />
        <CharacterCountPlugin setCount={() => {}} onChange={onChange} />
        <InitialValuePlugin value={value} options={options} enableLineNumbers />
        <BlurPlugin enableJinja2 />
        <Jinja2AutocompletePlugin options={options} />
        <CharacterCountPlugin setCount={() => {}} onChange={onChange} waitForInit />
        <Jinja2InitialValuePlugin value={value} />
        <Jinja2BlurPlugin />
      </div>
    </LexicalComposer>
  );
@@ -33,6 +33,7 @@ export interface LexicalEditorProps {
  type?: 'input' | 'textarea';
  language?: 'string' | 'jinja2';
  className?: string;
  waitForInit?: boolean;
}

// Default theme for editor
@@ -55,8 +56,10 @@ const Editor: FC<LexicalEditorProps> =({
  type = 'textarea',
  language = 'string',
  height,
  className
  className,
  waitForInit = false,
}) => {
  console.log('Editor value', value)
  const [_count, setCount] = useState(0);

  if (language === 'jinja2') {
@@ -2,11 +2,11 @@
 * @Author: ZhaoYing
 * @Date: 2025-12-23 16:22:51
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-25 16:13:37
 * @Last Modified time: 2026-04-02 17:12:41
 */
import { useEffect, useLayoutEffect, useState, useRef, type FC } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { $getSelection, $isRangeSelection, $isTextNode, COMMAND_PRIORITY_HIGH, KEY_ENTER_COMMAND, KEY_ARROW_DOWN_COMMAND, KEY_ARROW_UP_COMMAND, KEY_ESCAPE_COMMAND } from 'lexical';
import { $getSelection, $isRangeSelection, COMMAND_PRIORITY_HIGH, KEY_ENTER_COMMAND, KEY_ARROW_DOWN_COMMAND, KEY_ARROW_UP_COMMAND, KEY_ESCAPE_COMMAND } from 'lexical';
import { Space, Flex } from 'antd';

import { INSERT_VARIABLE_COMMAND, CLOSE_AUTOCOMPLETE_COMMAND } from '../commands';
@@ -28,7 +28,7 @@ export interface Suggestion {
}

// Autocomplete plugin for variable suggestions triggered by '/' character
const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> = ({ options, enableJinja2 = false }) => {
const AutocompletePlugin: FC<{ options: Suggestion[] }> = ({ options }) => {
  const [editor] = useLexicalComposerContext();
  const [showSuggestions, setShowSuggestions] = useState(false);
  const [selectedIndex, setSelectedIndex] = useState(0);
@@ -159,34 +159,7 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }>

  // Insert selected suggestion into editor
  const insertMention = (suggestion: Suggestion) => {
    if (enableJinja2) {
      // In Jinja2 mode, insert {{variable}} format text
      editor.update(() => {
        const selection = $getSelection();
        if ($isRangeSelection(selection)) {
          const anchorNode = selection.anchor.getNode();
          const anchorOffset = selection.anchor.offset;
          const nodeText = anchorNode.getTextContent();

          // Remove trigger character '/'
          const textBefore = nodeText.substring(0, anchorOffset - 1);
          const textAfter = nodeText.substring(anchorOffset);
          const newText = textBefore + `{{${suggestion.value}}}` + textAfter;

          if ($isTextNode(anchorNode)) {
            anchorNode.setTextContent(newText);
          }

          // Set cursor position after inserted text
          const newOffset = textBefore.length + `{{${suggestion.value}}}`.length;
          selection.anchor.offset = newOffset;
          selection.focus.offset = newOffset;
        }
      });
    } else {
      // In normal mode, use VariableNode
      editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion });
    }
    editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion });
    setShowSuggestions(false);
    setExpandedParent(null);
    setChildPanelTop(0);
@@ -1,64 +1,33 @@
/*
 * @Author: ZhaoYing
 * @Date: 2026-01-20 10:42:13
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-03 10:12:10
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-04-02 17:13:08
 */
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { useEffect } from 'react';
import { $setSelection } from 'lexical';
import { CLOSE_AUTOCOMPLETE_COMMAND } from '../commands';

// Plugin to handle blur events and close autocomplete when clicking outside
export default function BlurPlugin({ enableJinja2 }: { enableJinja2: boolean }) {
export default function BlurPlugin() {
  const [editor] = useLexicalComposerContext();

  useEffect(() => {
    // Close autocomplete when clicking outside the popup
    const handleClickOutside = (e: MouseEvent) => {
      const target = e.target as HTMLElement;
      if (target?.closest('[data-autocomplete-popup="true"]')) {
        return;
      }
      if ((e.target as HTMLElement)?.closest('[data-autocomplete-popup="true"]')) return;
      editor.dispatchCommand(CLOSE_AUTOCOMPLETE_COMMAND, undefined);
    };

    document.addEventListener('mousedown', handleClickOutside);

    return editor.registerRootListener((rootElement) => {
      if (rootElement) {
        const handleBlur = (e: FocusEvent) => {
          if (enableJinja2) {
            // Check if autocomplete popup was clicked
            const target = e.target as HTMLElement;
            if (target?.closest('[data-autocomplete-popup="true"]')) {
              return;
            }

            // Check if blur was caused by paste operation
            const relatedTarget = e.relatedTarget as HTMLElement;
            if (!relatedTarget || relatedTarget === document.body) {
              return;
            }

            // Clear selection on blur
            editor.update(() => {
              $setSelection(null);
            });
          }
        };

        rootElement.addEventListener('blur', handleBlur);
        return () => {
          document.removeEventListener('mousedown', handleClickOutside);
          rootElement.removeEventListener('blur', handleBlur);
        };
      }
      return () => {
        document.removeEventListener('mousedown', handleClickOutside);
      };
      return () => { document.removeEventListener('mousedown', handleClickOutside); };
    });
  }, [editor, enableJinja2]);
  }, [editor]);

  return null;
}
@@ -1,3 +1,9 @@
/*
 * @Author: ZhaoYing
 * @Date: 2025-12-23 16:22:51
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-04-02 17:14:15
 */
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical';
@@ -8,19 +14,17 @@ import { type Suggestion } from '../plugin/AutocompletePlugin'
interface InitialValuePluginProps {
  value: string;
  options?: Suggestion[];
  enableLineNumbers?: boolean;
}

const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [], enableLineNumbers = false }) => {
const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [] }) => {
  const [editor] = useLexicalComposerContext();
  const prevValueRef = useRef<string>('');
  const prevEnableLineNumbersRef = useRef<boolean>(enableLineNumbers);
  const isUserInputRef = useRef(false);
  const optionsRef = useRef(options);
  optionsRef.current = options;

  useEffect(() => {
    const removeListener = editor.registerUpdateListener(({ editorState, tags }) => {
    return editor.registerUpdateListener(({ editorState, tags }) => {
      if (tags.has('programmatic')) return;
      editorState.read(() => {
        const root = $getRoot();
@@ -31,21 +35,16 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
        }
      });
    });

    return removeListener;
  }, [editor]);

  useEffect(() => {
    if (value !== prevValueRef.current || enableLineNumbers !== prevEnableLineNumbersRef.current) {
      // Skip reset if the change was triggered by user input (avoid cursor jump)
      if (isUserInputRef.current && enableLineNumbers === prevEnableLineNumbersRef.current) {
    if (value !== prevValueRef.current) {
      if (isUserInputRef.current) {
        prevValueRef.current = value;
        isUserInputRef.current = false;
        return;
      }
      // Update refs BEFORE editor.update to prevent re-entry
      prevValueRef.current = value;
      prevEnableLineNumbersRef.current = enableLineNumbers;
      isUserInputRef.current = false;

      queueMicrotask(() => {
@@ -54,16 +53,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
          root.clear();

          const parts = value.split(/(\{\{[^}]+\}\}|\n)/);

          if (enableLineNumbers) {
            const lines = value.split('\n');
            lines.forEach((line) => {
              const paragraph = $createParagraphNode();
              paragraph.append($createTextNode(line));
              root.append(paragraph);
            });
          } else {
            let paragraph = $createParagraphNode();
          let paragraph = $createParagraphNode();

          parts.forEach(part => {
            if (part === '\n') {
@@ -129,15 +119,10 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
            }
          });
          root.append(paragraph);
          }
        }, { tag: 'programmatic' });
      });
    } else {
      prevValueRef.current = value;
      prevEnableLineNumbersRef.current = enableLineNumbers;
      isUserInputRef.current = false;
    }
  }, [value, editor, enableLineNumbers]);
  }, [value, editor]);

  return null;
};
@@ -0,0 +1,199 @@
/*
 * @Author: ZhaoYing
 * @Date: 2026-04-02 17:10:59
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-04-02 17:10:59
 */
import { useEffect, useState, useRef, type FC } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import {
  $getSelection, $isRangeSelection, $isTextNode,
  COMMAND_PRIORITY_HIGH, KEY_ENTER_COMMAND, KEY_ARROW_DOWN_COMMAND,
  KEY_ARROW_UP_COMMAND, KEY_ESCAPE_COMMAND,
} from 'lexical';
import { Space, Flex } from 'antd';

import { CLOSE_AUTOCOMPLETE_COMMAND } from '../commands';
import type { Suggestion } from './AutocompletePlugin';

const Jinja2AutocompletePlugin: FC<{ options: Suggestion[] }> = ({ options }) => {
  const [editor] = useLexicalComposerContext();
  const [showSuggestions, setShowSuggestions] = useState(false);
  const [selectedIndex, setSelectedIndex] = useState(0);
  const [popupPosition, setPopupPosition] = useState({ top: 0, left: 0 });
  const popupRef = useRef<HTMLDivElement>(null);

  const scrollSelectedIntoView = () => {
    if (!popupRef.current) return;
    const selectedElement = popupRef.current.querySelector('[data-selected="true"]');
    if (!selectedElement) return;
    const container = popupRef.current;
    const element = selectedElement as HTMLElement;
    const containerRect = container.getBoundingClientRect();
    const elementRect = element.getBoundingClientRect();
    if (elementRect.bottom > containerRect.bottom) {
      container.scrollTop += elementRect.bottom - containerRect.bottom;
    } else if (elementRect.top < containerRect.top) {
      container.scrollTop -= containerRect.top - elementRect.top;
    }
  };
  useEffect(() => {
    return editor.registerUpdateListener(({ editorState }) => {
      editorState.read(() => {
        const selection = $getSelection();
        if (!selection || !$isRangeSelection(selection)) {
          setShowSuggestions(false);
          return;
        }
        const anchorNode = selection.anchor.getNode();
        const anchorOffset = selection.anchor.offset;
        const textBeforeCursor = anchorNode.getTextContent().substring(0, anchorOffset);
        const shouldShow = textBeforeCursor.endsWith('/');
        setShowSuggestions(shouldShow);
        if (!shouldShow) { setSelectedIndex(0); return; }

        const domSelection = window.getSelection();
        if (domSelection && domSelection.rangeCount > 0) {
          const rect = domSelection.getRangeAt(0).getBoundingClientRect();
          const popupWidth = 280, popupHeight = 200;
          const vw = window.innerWidth, vh = window.innerHeight;
          let left = Math.min(Math.max(rect.left, 10), vw - popupWidth - 10);
          let top = rect.top - 10;
          if (top - popupHeight < 10) {
            top = Math.min(rect.bottom + 10, vh - popupHeight - 10);
          }
          setPopupPosition({ top, left });
        }
      });
    });
  }, [editor]);

  useEffect(() => {
    return editor.registerCommand(
      CLOSE_AUTOCOMPLETE_COMMAND,
      () => { setShowSuggestions(false); return true; },
      COMMAND_PRIORITY_HIGH,
    );
  }, [editor]);

  const insertMention = (suggestion: Suggestion) => {
    editor.update(() => {
      const selection = $getSelection();
      if (!$isRangeSelection(selection)) return;
      const anchorNode = selection.anchor.getNode();
      const anchorOffset = selection.anchor.offset;
      const nodeText = anchorNode.getTextContent();
      const textBefore = nodeText.substring(0, anchorOffset - 1);
      const textAfter = nodeText.substring(anchorOffset);
      const inserted = `{{${suggestion.value}}}`;
      if ($isTextNode(anchorNode)) {
        anchorNode.setTextContent(textBefore + inserted + textAfter);
        const newOffset = textBefore.length + inserted.length;
        selection.anchor.offset = newOffset;
        selection.focus.offset = newOffset;
      }
    });
    setShowSuggestions(false);
  };
  const groupedSuggestions = options.reduce((groups: Record<string, Suggestion[]>, s) => {
    const id = s.nodeData.id as string;
    if (!groups[id]) groups[id] = [];
    groups[id].push(s);
    return groups;
  }, {});

  const allOptions = Object.values(groupedSuggestions).flat();

  useEffect(() => {
    if (!showSuggestions) return;
    return editor.registerCommand(
      KEY_ENTER_COMMAND,
      (event) => {
        const opt = allOptions[selectedIndex];
        if (opt && !opt.disabled) { event?.preventDefault(); insertMention(opt); return true; }
        return false;
      },
      COMMAND_PRIORITY_HIGH,
    );
  }, [showSuggestions, selectedIndex, allOptions]);

  useEffect(() => {
    if (!showSuggestions) return;
    const down = editor.registerCommand(KEY_ARROW_DOWN_COMMAND, (e) => {
      e?.preventDefault();
      setSelectedIndex(prev => {
        let next = prev + 1;
        while (next < allOptions.length && allOptions[next].disabled) next++;
        setTimeout(scrollSelectedIntoView, 0);
        return next >= allOptions.length ? prev : next;
      });
      return true;
    }, COMMAND_PRIORITY_HIGH);
    const up = editor.registerCommand(KEY_ARROW_UP_COMMAND, (e) => {
      e?.preventDefault();
      setSelectedIndex(prev => {
        let p = prev - 1;
        while (p >= 0 && allOptions[p].disabled) p--;
        setTimeout(scrollSelectedIntoView, 0);
        return p < 0 ? prev : p;
      });
      return true;
    }, COMMAND_PRIORITY_HIGH);
    const esc = editor.registerCommand(KEY_ESCAPE_COMMAND, (e) => {
      e?.preventDefault(); setShowSuggestions(false); return true;
    }, COMMAND_PRIORITY_HIGH);
    return () => { down(); up(); esc(); };
  }, [showSuggestions, selectedIndex, allOptions, editor]);

  if (!showSuggestions || Object.keys(groupedSuggestions).length === 0) return null;
  return (
    <div
      ref={popupRef}
      data-autocomplete-popup="true"
      onMouseDown={(e) => e.preventDefault()}
      className="rb:fixed rb:z-1000 rb:py-1 rb:bg-white rb:rounded-xl rb:min-w-70 rb:max-h-50 rb:overflow-y-auto rb:transform-[translateY(-100%)] rb:shadow-[0px_2px_12px_0px_rgba(23,23,25,0.12)]"
      style={{ top: popupPosition.top, left: popupPosition.left }}
    >
      <Flex vertical gap={12}>
        {Object.entries(groupedSuggestions).map(([nodeId, nodeOptions]) => (
          <div key={nodeId}>
            <Flex align="center" gap={4} className="rb:px-3! rb:text-[12px] rb:py-1.25! rb:font-medium rb:text-[#5B6167]">
              {nodeOptions[0]?.nodeData?.icon && <img src={nodeOptions[0].nodeData.icon} className="rb:size-3" alt="" />}
              {nodeOptions[0]?.nodeData?.name || nodeId}
            </Flex>
            {nodeOptions.map((option) => {
              const globalIndex = allOptions.indexOf(option);
              return (
                <Flex
                  key={option.key}
                  data-selected={selectedIndex === globalIndex}
                  className="rb:pl-6! rb:pr-3! rb:py-2!"
                  align="center"
                  justify="space-between"
                  style={{
                    cursor: option.disabled ? 'not-allowed' : 'pointer',
                    background: selectedIndex === globalIndex ? '#f0f8ff' : 'white',
                    opacity: option.disabled ? 0.5 : 1,
                  }}
                  onClick={() => !option.disabled && insertMention(option)}
                  onMouseEnter={() => setSelectedIndex(globalIndex)}
                >
                  <Space size={4}>
                    <span className="rb:text-[#155EEF]">{option.isContext ? '📄' : '{x}'}</span>
                    <span>{option.label}</span>
                  </Space>
                  {option.dataType && <span className="rb:text-[#5B6167]">{option.dataType}</span>}
                </Flex>
              );
            })}
          </div>
        ))}
      </Flex>
    </div>
  );
};

export default Jinja2AutocompletePlugin;
@@ -0,0 +1,41 @@
/*
 * @Author: ZhaoYing
 * @Date: 2026-04-02 17:11:04
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-04-02 17:11:04
 */
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { useEffect } from 'react';
import { $setSelection } from 'lexical';
import { CLOSE_AUTOCOMPLETE_COMMAND } from '../commands';

export default function Jinja2BlurPlugin() {
  const [editor] = useLexicalComposerContext();

  useEffect(() => {
    const handleClickOutside = (e: MouseEvent) => {
      if ((e.target as HTMLElement)?.closest('[data-autocomplete-popup="true"]')) return;
      editor.dispatchCommand(CLOSE_AUTOCOMPLETE_COMMAND, undefined);
    };
    document.addEventListener('mousedown', handleClickOutside);

    return editor.registerRootListener((rootElement) => {
      if (rootElement) {
        const handleBlur = (e: FocusEvent) => {
          if ((e.target as HTMLElement)?.closest('[data-autocomplete-popup="true"]')) return;
          const relatedTarget = e.relatedTarget as HTMLElement;
          if (!relatedTarget || relatedTarget === document.body) return;
          editor.update(() => { $setSelection(null); });
        };
        rootElement.addEventListener('blur', handleBlur);
        return () => {
          document.removeEventListener('mousedown', handleClickOutside);
          rootElement.removeEventListener('blur', handleBlur);
        };
      }
      return () => { document.removeEventListener('mousedown', handleClickOutside); };
    });
  }, [editor]);

  return null;
}
@@ -0,0 +1,61 @@
/*
 * @Author: ZhaoYing
 * @Date: 2026-04-02 17:11:07
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-04-02 17:11:07
 */
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical';

interface Jinja2InitialValuePluginProps {
  value: string;
}

const Jinja2InitialValuePlugin: React.FC<Jinja2InitialValuePluginProps> = ({ value }) => {
  const [editor] = useLexicalComposerContext();
  const prevValueRef = useRef<string>('');
  const isUserInputRef = useRef(false);

  useEffect(() => {
    return editor.registerUpdateListener(({ editorState, tags }) => {
      if (tags.has('programmatic')) return;
      editorState.read(() => {
        const textContent = $getRoot().getTextContent();
        if (textContent !== prevValueRef.current) {
          isUserInputRef.current = true;
          prevValueRef.current = textContent;
        }
      });
    });
  }, [editor]);

  useEffect(() => {
    if (value === prevValueRef.current) return;

    if (isUserInputRef.current) {
      prevValueRef.current = value;
      isUserInputRef.current = false;
      return;
    }

    prevValueRef.current = value;
    isUserInputRef.current = false;

    queueMicrotask(() => {
      editor.update(() => {
        const root = $getRoot();
        root.clear();
        value.split('\n').forEach((line) => {
          const paragraph = $createParagraphNode();
          paragraph.append($createTextNode(line));
          root.append(paragraph);
        });
      }, { tag: 'programmatic' });
    });
  }, [value, editor]);

  return null;
};

export default Jinja2InitialValuePlugin;
@@ -2,7 +2,7 @@
 * @Author: ZhaoYing
 * @Date: 2026-02-09 18:35:43
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-20 11:32:44
 * @Last Modified time: 2026-04-02 17:17:06
 */
import { type FC, useRef, useState } from "react";
import { useTranslation } from 'react-i18next'
@@ -114,6 +114,7 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an
        <Col span={16}>
          <Form.Item name="url">
            <Editor
              key="url"
              options={options.filter(vo => vo.dataType === 'string' || vo.dataType === 'number')}
              variant="outlined"
              type="input"
@@ -212,13 +213,15 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an
      }
      {values?.body?.content_type === 'binary' &&
        <Form.Item name={['body', 'data']}
          className="rb:bg-[#F6F6F6] rb:border-[#F6F6F6]! rb:hover:bg-white rb:hover:border-[#171719]! rb:border rb:rounded-lg rb:px-2! rb:py-1.5! rb:mb-0!"
          className="rb:bg-[#F6F6F6] rb:border-[#F6F6F6]! rb:hover:bg-white rb:hover:border-[#171719]! rb:border rb:rounded-lg rb:mb-0!"
        >
          <Editor
            key={['body', 'data'].join('_')}
            placeholder={t('common.pleaseSelect')}
            options={options.filter(vo => vo.dataType.includes('file'))}
            type="input"
            size="small"
            height={28}
          />
        </Form.Item>
      }