feat(workflow): add opening statement and citation support
- Trigger opening statement on new conversation in run/run_stream - Fix /opening endpoint to support workflow app type - Fix features field missing in workflow config release snapshot - Knowledge node returns citations alongside chunks - Aggregate citations from all knowledge nodes in result builder - Filter citations based on features.citation.enabled switch - Fix WorkflowConfigCreate circular import in app_schema
This commit is contained in:
@@ -292,10 +292,19 @@ def get_opening(
|
|||||||
):
|
):
|
||||||
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
|
||||||
features = cfg.features or {}
|
# 根据应用类型获取 features
|
||||||
if hasattr(features, "model_dump"):
|
from app.models.app_model import App as AppModel
|
||||||
features = features.model_dump()
|
app = db.get(AppModel, app_id)
|
||||||
|
if app and app.type == "workflow":
|
||||||
|
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
|
||||||
|
features = cfg.features or {}
|
||||||
|
else:
|
||||||
|
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||||
|
features = cfg.features or {}
|
||||||
|
if hasattr(features, "model_dump"):
|
||||||
|
features = features.model_dump()
|
||||||
|
|
||||||
opening = features.get("opening_statement", {})
|
opening = features.get("opening_statement", {})
|
||||||
return success(data=app_schema.OpeningResponse(
|
return success(data=app_schema.OpeningResponse(
|
||||||
enabled=opening.get("enabled", False),
|
enabled=opening.get("enabled", False),
|
||||||
|
|||||||
@@ -314,8 +314,10 @@ async def parse_documents(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 4. Check if the file exists
|
# 4. Check if the file exists
|
||||||
|
api_logger.debug(f"Constructed file path: {file_path}")
|
||||||
|
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
|
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="File not found (possibly deleted)"
|
detail="File not found (possibly deleted)"
|
||||||
|
|||||||
@@ -59,6 +59,9 @@ class WorkflowResultBuilder:
|
|||||||
conversation_vars = variable_pool.get_all_conversation_vars()
|
conversation_vars = variable_pool.get_all_conversation_vars()
|
||||||
sys_vars = variable_pool.get_all_system_vars()
|
sys_vars = variable_pool.get_all_system_vars()
|
||||||
|
|
||||||
|
# 汇总所有 knowledge 节点的 citations
|
||||||
|
citations = self.aggregate_citations(node_outputs)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "completed" if success else "failed",
|
"status": "completed" if success else "failed",
|
||||||
"output": final_output,
|
"output": final_output,
|
||||||
@@ -71,9 +74,25 @@ class WorkflowResultBuilder:
|
|||||||
"conversation_id": execution_context.conversation_id,
|
"conversation_id": execution_context.conversation_id,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"token_usage": token_usage,
|
"token_usage": token_usage,
|
||||||
|
"citations": citations,
|
||||||
"error": result.get("error"),
|
"error": result.get("error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def aggregate_citations(node_outputs: dict) -> list:
|
||||||
|
"""从所有 knowledge 节点的输出中汇总 citations,去重"""
|
||||||
|
seen = set()
|
||||||
|
citations = []
|
||||||
|
for node_output in node_outputs.values():
|
||||||
|
if not isinstance(node_output, dict):
|
||||||
|
continue
|
||||||
|
for c in node_output.get("citations", []):
|
||||||
|
key = c.get("document_id")
|
||||||
|
if key and key not in seen:
|
||||||
|
seen.add(key)
|
||||||
|
citations.append(c)
|
||||||
|
return citations
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
|
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -395,7 +395,8 @@ class BaseNode(ABC):
|
|||||||
"output": output,
|
"output": output,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"token_usage": token_usage,
|
"token_usage": token_usage,
|
||||||
"error": None
|
"error": None,
|
||||||
|
**self._extract_extra_fields(business_result),
|
||||||
}
|
}
|
||||||
final_output = {
|
final_output = {
|
||||||
"node_outputs": {self.node_id: node_output},
|
"node_outputs": {self.node_id: node_output},
|
||||||
@@ -498,6 +499,13 @@ class BaseNode(ABC):
|
|||||||
# Default implementation returns the business result directly
|
# Default implementation returns the business result directly
|
||||||
return business_result
|
return business_result
|
||||||
|
|
||||||
|
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||||
|
"""Extracts extra fields to merge into node_output (e.g. citations).
|
||||||
|
|
||||||
|
Subclasses may override to inject additional metadata.
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
"""Extracts token usage information from the business result.
|
"""Extracts token usage information from the business result.
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,20 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
"output": VariableType.ARRAY_STRING
|
"output": VariableType.ARRAY_STRING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
|
"""下游节点只拿 chunks 列表"""
|
||||||
|
if isinstance(business_result, dict) and "chunks" in business_result:
|
||||||
|
return business_result["chunks"]
|
||||||
|
return business_result
|
||||||
|
|
||||||
|
def _extract_citations(self, business_result: Any) -> list:
|
||||||
|
if isinstance(business_result, dict):
|
||||||
|
return business_result.get("citations", [])
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||||
|
return {"citations": self._extract_citations(business_result)}
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"query": self._render_template(self.typed_config.query, variable_pool),
|
"query": self._render_template(self.typed_config.query, variable_pool),
|
||||||
@@ -314,4 +328,20 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||||
)
|
)
|
||||||
return [chunk.page_content for chunk in final_rs]
|
citations = []
|
||||||
|
seen_doc_ids = set()
|
||||||
|
for chunk in final_rs:
|
||||||
|
meta = chunk.metadata or {}
|
||||||
|
doc_id = meta.get("document_id") or meta.get("doc_id")
|
||||||
|
if doc_id and doc_id not in seen_doc_ids:
|
||||||
|
seen_doc_ids.add(doc_id)
|
||||||
|
citations.append({
|
||||||
|
"document_id": str(doc_id),
|
||||||
|
"file_name": meta.get("file_name", ""),
|
||||||
|
"knowledge_id": str(meta.get("knowledge_id", kb_config.kb_id)),
|
||||||
|
"score": meta.get("score", 0.0),
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
"chunks": [chunk.page_content for chunk in final_rs],
|
||||||
|
"citations": citations,
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,10 +4,6 @@ from typing import Optional, Any, List, Dict, Union
|
|||||||
from enum import Enum, StrEnum
|
from enum import Enum, StrEnum
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||||
|
|
||||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
|
||||||
|
|
||||||
|
|
||||||
# ---------- Multimodal File Support ----------
|
# ---------- Multimodal File Support ----------
|
||||||
|
|
||||||
class FileType(StrEnum):
|
class FileType(StrEnum):
|
||||||
@@ -317,7 +313,7 @@ class AppCreate(BaseModel):
|
|||||||
# only for type=multi_agent
|
# only for type=multi_agent
|
||||||
multi_agent_config: Optional[Dict[str, Any]] = None
|
multi_agent_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
workflow_config: Optional[WorkflowConfigCreate] = None
|
workflow_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class AppUpdate(BaseModel):
|
class AppUpdate(BaseModel):
|
||||||
|
|||||||
@@ -401,7 +401,7 @@ class AppService:
|
|||||||
def _create_workflow_config(
|
def _create_workflow_config(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
data: app_schema.WorkflowConfigCreate,
|
data,
|
||||||
now: datetime.datetime
|
now: datetime.datetime
|
||||||
):
|
):
|
||||||
workflow_cfg = WorkflowConfig(
|
workflow_cfg = WorkflowConfig(
|
||||||
@@ -678,7 +678,9 @@ class AppService:
|
|||||||
self._create_multi_agent_config(app.id, data.multi_agent_config, now)
|
self._create_multi_agent_config(app.id, data.multi_agent_config, now)
|
||||||
|
|
||||||
if app.type == "workflow" and data.workflow_config:
|
if app.type == "workflow" and data.workflow_config:
|
||||||
self._create_workflow_config(app.id, data.workflow_config, now)
|
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||||
|
wf_data = WorkflowConfigCreate(**data.workflow_config) if isinstance(data.workflow_config, dict) else data.workflow_config
|
||||||
|
self._create_workflow_config(app.id, wf_data, now)
|
||||||
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(app)
|
self.db.refresh(app)
|
||||||
|
|||||||
@@ -545,6 +545,12 @@ class WorkflowService:
|
|||||||
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
|
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
|
||||||
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
|
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
|
||||||
user_rag_memory_id = ""
|
user_rag_memory_id = ""
|
||||||
|
# 如果 storage_type 为 None,使用默认值 'neo4j'
|
||||||
|
if not storage_type:
|
||||||
|
storage_type = 'neo4j'
|
||||||
|
logger.warning(
|
||||||
|
f"Storage type not set for workspace {workspace_id}, using default: neo4j"
|
||||||
|
)
|
||||||
if storage_type == "rag":
|
if storage_type == "rag":
|
||||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
db=self.db,
|
db=self.db,
|
||||||
@@ -659,6 +665,26 @@ class WorkflowService:
|
|||||||
input_data["conv_messages"] = conv_messages
|
input_data["conv_messages"] = conv_messages
|
||||||
init_message_length = len(input_data.get("conv_messages", []))
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
|
|
||||||
|
# 新会话时写入开场白
|
||||||
|
is_new_conversation = init_message_length == 0
|
||||||
|
if is_new_conversation:
|
||||||
|
opening_cfg = feature_configs.get("opening_statement", {})
|
||||||
|
if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"):
|
||||||
|
statement = opening_cfg["statement"]
|
||||||
|
suggested_questions = opening_cfg.get("suggested_questions", [])
|
||||||
|
if payload.variables:
|
||||||
|
for var_name, var_value in payload.variables.items():
|
||||||
|
statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value))
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
role="assistant",
|
||||||
|
content=statement,
|
||||||
|
meta_data={"suggested_questions": suggested_questions}
|
||||||
|
)
|
||||||
|
# 注入到 conv_messages,让 LLM 感知开场白
|
||||||
|
input_data["conv_messages"] = [{"role": "assistant", "content": statement}]
|
||||||
|
init_message_length = 1
|
||||||
|
|
||||||
result = await execute_workflow(
|
result = await execute_workflow(
|
||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
@@ -721,6 +747,13 @@ class WorkflowService:
|
|||||||
logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id},"
|
logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id},"
|
||||||
f" error: {result.get('error')}")
|
f" error: {result.get('error')}")
|
||||||
|
|
||||||
|
# 过滤 citations
|
||||||
|
citations = result.get("citations", [])
|
||||||
|
citation_cfg = feature_configs.get("citation", {})
|
||||||
|
filtered_citations = (
|
||||||
|
citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else []
|
||||||
|
)
|
||||||
|
|
||||||
# 返回增强的响应结构
|
# 返回增强的响应结构
|
||||||
return {
|
return {
|
||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
@@ -734,7 +767,8 @@ class WorkflowService:
|
|||||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||||
"error_message": result.get("error"),
|
"error_message": result.get("error"),
|
||||||
"elapsed_time": result.get("elapsed_time"),
|
"elapsed_time": result.get("elapsed_time"),
|
||||||
"token_usage": result.get("token_usage")
|
"token_usage": result.get("token_usage"),
|
||||||
|
"citations": filtered_citations,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -825,6 +859,27 @@ class WorkflowService:
|
|||||||
input_data["conv_messages"] = conv_messages
|
input_data["conv_messages"] = conv_messages
|
||||||
init_message_length = len(input_data.get("conv_messages", []))
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
|
|
||||||
|
# 新会话时写入开场白
|
||||||
|
is_new_conversation = init_message_length == 0
|
||||||
|
if is_new_conversation:
|
||||||
|
opening_cfg = feature_configs.get("opening_statement", {})
|
||||||
|
if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"):
|
||||||
|
statement = opening_cfg["statement"]
|
||||||
|
suggested_questions = opening_cfg.get("suggested_questions", [])
|
||||||
|
if payload.variables:
|
||||||
|
for var_name, var_value in payload.variables.items():
|
||||||
|
statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value))
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
role="assistant",
|
||||||
|
content=statement,
|
||||||
|
meta_data={"suggested_questions": suggested_questions}
|
||||||
|
)
|
||||||
|
# 注入到 conv_messages,让 LLM 感知开场白
|
||||||
|
input_data["conv_messages"] = [{"role": "assistant", "content": statement}]
|
||||||
|
init_message_length = 1
|
||||||
|
|
||||||
async for event in execute_workflow_stream(
|
async for event in execute_workflow_stream(
|
||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
@@ -875,6 +930,13 @@ class WorkflowService:
|
|||||||
output_data=event.get("data"),
|
output_data=event.get("data"),
|
||||||
token_usage=token_usage.get("total_tokens", None)
|
token_usage=token_usage.get("total_tokens", None)
|
||||||
)
|
)
|
||||||
|
# 注入 citations 到 workflow_end 事件
|
||||||
|
citations = event.get("data", {}).get("citations", [])
|
||||||
|
citation_cfg = feature_configs.get("citation", {})
|
||||||
|
filtered_citations = (
|
||||||
|
citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else []
|
||||||
|
)
|
||||||
|
event.setdefault("data", {})["citations"] = filtered_citations
|
||||||
logger.info(f"Workflow Run Success, "
|
logger.info(f"Workflow Run Success, "
|
||||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||||
elif status == "failed":
|
elif status == "failed":
|
||||||
|
|||||||
@@ -153,7 +153,8 @@ def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig:
|
|||||||
edges=config_dict.get("edges", []),
|
edges=config_dict.get("edges", []),
|
||||||
variables=config_dict.get("variables", []),
|
variables=config_dict.get("variables", []),
|
||||||
execution_config=config_dict.get("execution_config", {}),
|
execution_config=config_dict.get("execution_config", {}),
|
||||||
triggers=config_dict.get("triggers", [])
|
triggers=config_dict.get("triggers", []),
|
||||||
|
features=config_dict.get("features", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|||||||
Reference in New Issue
Block a user