Merge pull request #785 from wanxunyang/feat/app-log-wxy
feat(workflow): add opening statement and citation support
This commit is contained in:
@@ -292,10 +292,19 @@ def get_opening(
|
||||
):
|
||||
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
|
||||
# 根据应用类型获取 features
|
||||
from app.models.app_model import App as AppModel
|
||||
app = db.get(AppModel, app_id)
|
||||
if app and app.type == "workflow":
|
||||
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
else:
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
|
||||
opening = features.get("opening_statement", {})
|
||||
return success(data=app_schema.OpeningResponse(
|
||||
enabled=opening.get("enabled", False),
|
||||
|
||||
@@ -314,8 +314,10 @@ async def parse_documents(
|
||||
)
|
||||
|
||||
# 4. Check if the file exists
|
||||
api_logger.debug(f"Constructed file path: {file_path}")
|
||||
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
||||
if not os.path.exists(file_path):
|
||||
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
|
||||
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found (possibly deleted)"
|
||||
|
||||
@@ -59,6 +59,9 @@ class WorkflowResultBuilder:
|
||||
conversation_vars = variable_pool.get_all_conversation_vars()
|
||||
sys_vars = variable_pool.get_all_system_vars()
|
||||
|
||||
# 汇总所有 knowledge 节点的 citations
|
||||
citations = self.aggregate_citations(node_outputs)
|
||||
|
||||
return {
|
||||
"status": "completed" if success else "failed",
|
||||
"output": final_output,
|
||||
@@ -71,9 +74,25 @@ class WorkflowResultBuilder:
|
||||
"conversation_id": execution_context.conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"citations": citations,
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def aggregate_citations(node_outputs: dict) -> list:
|
||||
"""从所有 knowledge 节点的输出中汇总 citations,去重"""
|
||||
seen = set()
|
||||
citations = []
|
||||
for node_output in node_outputs.values():
|
||||
if not isinstance(node_output, dict):
|
||||
continue
|
||||
for c in node_output.get("citations", []):
|
||||
key = c.get("document_id")
|
||||
if key and key not in seen:
|
||||
seen.add(key)
|
||||
citations.append(c)
|
||||
return citations
|
||||
|
||||
@staticmethod
|
||||
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
|
||||
"""
|
||||
|
||||
@@ -395,7 +395,8 @@ class BaseNode(ABC):
|
||||
"output": output,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": None
|
||||
"error": None,
|
||||
**self._extract_extra_fields(business_result),
|
||||
}
|
||||
final_output = {
|
||||
"node_outputs": {self.node_id: node_output},
|
||||
@@ -498,6 +499,13 @@ class BaseNode(ABC):
|
||||
# Default implementation returns the business result directly
|
||||
return business_result
|
||||
|
||||
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||
"""Extracts extra fields to merge into node_output (e.g. citations).
|
||||
|
||||
Subclasses may override to inject additional metadata.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
"""Extracts token usage information from the business result.
|
||||
|
||||
|
||||
@@ -34,6 +34,20 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"output": VariableType.ARRAY_STRING
|
||||
}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
"""下游节点只拿 chunks 列表"""
|
||||
if isinstance(business_result, dict) and "chunks" in business_result:
|
||||
return business_result["chunks"]
|
||||
return business_result
|
||||
|
||||
def _extract_citations(self, business_result: Any) -> list:
|
||||
if isinstance(business_result, dict):
|
||||
return business_result.get("citations", [])
|
||||
return []
|
||||
|
||||
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||
return {"citations": self._extract_citations(business_result)}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
return {
|
||||
"query": self._render_template(self.typed_config.query, variable_pool),
|
||||
@@ -314,4 +328,20 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
return [chunk.page_content for chunk in final_rs]
|
||||
citations = []
|
||||
seen_doc_ids = set()
|
||||
for chunk in final_rs:
|
||||
meta = chunk.metadata or {}
|
||||
doc_id = meta.get("document_id") or meta.get("doc_id")
|
||||
if doc_id and doc_id not in seen_doc_ids:
|
||||
seen_doc_ids.add(doc_id)
|
||||
citations.append({
|
||||
"document_id": str(doc_id),
|
||||
"file_name": meta.get("file_name", ""),
|
||||
"knowledge_id": str(meta.get("knowledge_id", kb_config.kb_id)),
|
||||
"score": meta.get("score", 0.0),
|
||||
})
|
||||
return {
|
||||
"chunks": [chunk.page_content for chunk in final_rs],
|
||||
"citations": citations,
|
||||
}
|
||||
|
||||
@@ -4,10 +4,6 @@ from typing import Optional, Any, List, Dict, Union
|
||||
from enum import Enum, StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||
|
||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||
|
||||
|
||||
# ---------- Multimodal File Support ----------
|
||||
|
||||
class FileType(StrEnum):
|
||||
@@ -317,7 +313,7 @@ class AppCreate(BaseModel):
|
||||
# only for type=multi_agent
|
||||
multi_agent_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
workflow_config: Optional[WorkflowConfigCreate] = None
|
||||
workflow_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class AppUpdate(BaseModel):
|
||||
|
||||
@@ -401,7 +401,7 @@ class AppService:
|
||||
def _create_workflow_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
data: app_schema.WorkflowConfigCreate,
|
||||
data,
|
||||
now: datetime.datetime
|
||||
):
|
||||
workflow_cfg = WorkflowConfig(
|
||||
@@ -678,7 +678,9 @@ class AppService:
|
||||
self._create_multi_agent_config(app.id, data.multi_agent_config, now)
|
||||
|
||||
if app.type == "workflow" and data.workflow_config:
|
||||
self._create_workflow_config(app.id, data.workflow_config, now)
|
||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||
wf_data = WorkflowConfigCreate(**data.workflow_config) if isinstance(data.workflow_config, dict) else data.workflow_config
|
||||
self._create_workflow_config(app.id, wf_data, now)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(app)
|
||||
|
||||
@@ -545,6 +545,12 @@ class WorkflowService:
|
||||
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
|
||||
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
|
||||
user_rag_memory_id = ""
|
||||
# 如果 storage_type 为 None,使用默认值 'neo4j'
|
||||
if not storage_type:
|
||||
storage_type = 'neo4j'
|
||||
logger.warning(
|
||||
f"Storage type not set for workspace {workspace_id}, using default: neo4j"
|
||||
)
|
||||
if storage_type == "rag":
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=self.db,
|
||||
@@ -659,6 +665,26 @@ class WorkflowService:
|
||||
input_data["conv_messages"] = conv_messages
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
|
||||
# 新会话时写入开场白
|
||||
is_new_conversation = init_message_length == 0
|
||||
if is_new_conversation:
|
||||
opening_cfg = feature_configs.get("opening_statement", {})
|
||||
if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"):
|
||||
statement = opening_cfg["statement"]
|
||||
suggested_questions = opening_cfg.get("suggested_questions", [])
|
||||
if payload.variables:
|
||||
for var_name, var_value in payload.variables.items():
|
||||
statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value))
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role="assistant",
|
||||
content=statement,
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
# 注入到 conv_messages,让 LLM 感知开场白
|
||||
input_data["conv_messages"] = [{"role": "assistant", "content": statement}]
|
||||
init_message_length = 1
|
||||
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
@@ -721,6 +747,13 @@ class WorkflowService:
|
||||
logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id},"
|
||||
f" error: {result.get('error')}")
|
||||
|
||||
# 过滤 citations
|
||||
citations = result.get("citations", [])
|
||||
citation_cfg = feature_configs.get("citation", {})
|
||||
filtered_citations = (
|
||||
citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else []
|
||||
)
|
||||
|
||||
# 返回增强的响应结构
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
@@ -734,7 +767,8 @@ class WorkflowService:
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
"error_message": result.get("error"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"token_usage": result.get("token_usage")
|
||||
"token_usage": result.get("token_usage"),
|
||||
"citations": filtered_citations,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -825,6 +859,27 @@ class WorkflowService:
|
||||
input_data["conv_messages"] = conv_messages
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
message_id = uuid.uuid4()
|
||||
|
||||
# 新会话时写入开场白
|
||||
is_new_conversation = init_message_length == 0
|
||||
if is_new_conversation:
|
||||
opening_cfg = feature_configs.get("opening_statement", {})
|
||||
if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"):
|
||||
statement = opening_cfg["statement"]
|
||||
suggested_questions = opening_cfg.get("suggested_questions", [])
|
||||
if payload.variables:
|
||||
for var_name, var_value in payload.variables.items():
|
||||
statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value))
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role="assistant",
|
||||
content=statement,
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
# 注入到 conv_messages,让 LLM 感知开场白
|
||||
input_data["conv_messages"] = [{"role": "assistant", "content": statement}]
|
||||
init_message_length = 1
|
||||
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
@@ -875,6 +930,13 @@ class WorkflowService:
|
||||
output_data=event.get("data"),
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
# 注入 citations 到 workflow_end 事件
|
||||
citations = event.get("data", {}).get("citations", [])
|
||||
citation_cfg = feature_configs.get("citation", {})
|
||||
filtered_citations = (
|
||||
citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else []
|
||||
)
|
||||
event.setdefault("data", {})["citations"] = filtered_citations
|
||||
logger.info(f"Workflow Run Success, "
|
||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||
elif status == "failed":
|
||||
|
||||
@@ -153,7 +153,8 @@ def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig:
|
||||
edges=config_dict.get("edges", []),
|
||||
variables=config_dict.get("variables", []),
|
||||
execution_config=config_dict.get("execution_config", {}),
|
||||
triggers=config_dict.get("triggers", [])
|
||||
triggers=config_dict.get("triggers", []),
|
||||
features=config_dict.get("features", {})
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
Reference in New Issue
Block a user