diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 74991bcf..3fbdfbf1 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -292,10 +292,19 @@ def get_opening( ): """返回开场白文本和预设问题,供前端对话界面初始化时展示""" workspace_id = current_user.current_workspace_id - cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id) - features = cfg.features or {} - if hasattr(features, "model_dump"): - features = features.model_dump() + + # 根据应用类型获取 features + from app.models.app_model import App as AppModel + app = db.get(AppModel, app_id) + if app and app.type == "workflow": + cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id) + features = cfg.features or {} + else: + cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id) + features = cfg.features or {} + if hasattr(features, "model_dump"): + features = features.model_dump() + opening = features.get("opening_statement", {}) return success(data=app_schema.OpeningResponse( enabled=opening.get("enabled", False), diff --git a/api/app/controllers/document_controller.py b/api/app/controllers/document_controller.py index 72f9cb8f..350acc0e 100644 --- a/api/app/controllers/document_controller.py +++ b/api/app/controllers/document_controller.py @@ -314,8 +314,10 @@ async def parse_documents( ) # 4. Check if the file exists + api_logger.debug(f"Constructed file path: {file_path}") + api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}") if not os.path.exists(file_path): - api_logger.warning(f"File not found (possibly deleted): file_path={file_path}") + api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}") raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="File not found (possibly deleted)" diff --git a/api/app/core/workflow/engine/result_builder.py b/api/app/core/workflow/engine/result_builder.py index be0c957a..dc16df17 100644 --- a/api/app/core/workflow/engine/result_builder.py +++ b/api/app/core/workflow/engine/result_builder.py @@ -59,6 +59,9 @@ class WorkflowResultBuilder: conversation_vars = variable_pool.get_all_conversation_vars() sys_vars = variable_pool.get_all_system_vars() + # 汇总所有 knowledge 节点的 citations + citations = self.aggregate_citations(node_outputs) + return { "status": "completed" if success else "failed", "output": final_output, @@ -71,9 +74,25 @@ class WorkflowResultBuilder: "conversation_id": execution_context.conversation_id, "elapsed_time": elapsed_time, "token_usage": token_usage, + "citations": citations, "error": result.get("error"), } + @staticmethod + def aggregate_citations(node_outputs: dict) -> list: + """从所有 knowledge 节点的输出中汇总 citations,去重""" + seen = set() + citations = [] + for node_output in node_outputs.values(): + if not isinstance(node_output, dict): + continue + for c in node_output.get("citations", []): + key = c.get("document_id") + if key and key not in seen: + seen.add(key) + citations.append(c) + return citations + @staticmethod def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None: """ diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index bedf6165..5458a80c 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -395,7 +395,8 @@ class BaseNode(ABC): "output": output, "elapsed_time": elapsed_time, "token_usage": token_usage, - "error": None + "error": None, + **self._extract_extra_fields(business_result), } final_output = { "node_outputs": {self.node_id: node_output}, @@ -498,6 +499,13 @@ class BaseNode(ABC): # Default implementation returns the business result directly return business_result + def _extract_extra_fields(self, business_result: Any) -> dict: + """Extracts extra fields to merge into node_output (e.g. citations). + + Subclasses may override to inject additional metadata. + """ + return {} + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: """Extracts token usage information from the business result. diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index d0b6d098..97fa86cb 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -34,6 +34,20 @@ class KnowledgeRetrievalNode(BaseNode): "output": VariableType.ARRAY_STRING } + def _extract_output(self, business_result: Any) -> Any: + """下游节点只拿 chunks 列表""" + if isinstance(business_result, dict) and "chunks" in business_result: + return business_result["chunks"] + return business_result + + def _extract_citations(self, business_result: Any) -> list: + if isinstance(business_result, dict): + return business_result.get("citations", []) + return [] + + def _extract_extra_fields(self, business_result: Any) -> dict: + return {"citations": self._extract_citations(business_result)} + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: return { "query": self._render_template(self.typed_config.query, variable_pool), @@ -314,4 +328,20 @@ class KnowledgeRetrievalNode(BaseNode): logger.info( f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" ) - return [chunk.page_content for chunk in final_rs] + citations = [] + seen_doc_ids = set() + for chunk in final_rs: + meta = chunk.metadata or {} + doc_id = meta.get("document_id") or meta.get("doc_id") + if doc_id and doc_id not in seen_doc_ids: + seen_doc_ids.add(doc_id) + citations.append({ + "document_id": str(doc_id), + "file_name": meta.get("file_name", ""), + "knowledge_id": str(meta.get("knowledge_id", kb_config.kb_id)), + "score": meta.get("score", 0.0), + }) + return { + "chunks": [chunk.page_content for chunk in final_rs], + "citations": citations, + } diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 4ca3e7de..97665b9b 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -4,10 +4,6 @@ from typing import Optional, Any, List, Dict, Union from enum import Enum, StrEnum from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator - -from app.schemas.workflow_schema import WorkflowConfigCreate - - # ---------- Multimodal File Support ---------- class FileType(StrEnum): @@ -317,7 +313,7 @@ class AppCreate(BaseModel): # only for type=multi_agent multi_agent_config: Optional[Dict[str, Any]] = None - workflow_config: Optional[WorkflowConfigCreate] = None + workflow_config: Optional[Dict[str, Any]] = None class AppUpdate(BaseModel): diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 5af6bf41..36d7e614 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -401,7 +401,7 @@ class AppService: def _create_workflow_config( self, app_id: uuid.UUID, - data: app_schema.WorkflowConfigCreate, + data, now: datetime.datetime ): workflow_cfg = WorkflowConfig( @@ -678,7 +678,9 @@ class AppService: self._create_multi_agent_config(app.id, data.multi_agent_config, now) if app.type == "workflow" and data.workflow_config: - self._create_workflow_config(app.id, data.workflow_config, now) + from app.schemas.workflow_schema import WorkflowConfigCreate + wf_data = WorkflowConfigCreate(**data.workflow_config) if isinstance(data.workflow_config, dict) else data.workflow_config + self._create_workflow_config(app.id, wf_data, now) self.db.commit() self.db.refresh(app) diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 13267078..97715ee9 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -545,6 +545,12 @@ class WorkflowService: def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]: storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id) user_rag_memory_id = "" + # 如果 storage_type 为 None,使用默认值 'neo4j' + if not storage_type: + storage_type = 'neo4j' + logger.warning( + f"Storage type not set for workspace {workspace_id}, using default: neo4j" + ) if storage_type == "rag": knowledge = knowledge_repository.get_knowledge_by_name( db=self.db, @@ -659,6 +665,26 @@ class WorkflowService: input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) + # 新会话时写入开场白 + is_new_conversation = init_message_length == 0 + if is_new_conversation: + opening_cfg = feature_configs.get("opening_statement", {}) + if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"): + statement = opening_cfg["statement"] + suggested_questions = opening_cfg.get("suggested_questions", []) + if payload.variables: + for var_name, var_value in payload.variables.items(): + statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value)) + self.conversation_service.add_message( + conversation_id=conversation_id_uuid, + role="assistant", + content=statement, + meta_data={"suggested_questions": suggested_questions} + ) + # 注入到 conv_messages,让 LLM 感知开场白 + input_data["conv_messages"] = [{"role": "assistant", "content": statement}] + init_message_length = 1 + result = await execute_workflow( workflow_config=workflow_config_dict, input_data=input_data, @@ -721,6 +747,13 @@ class WorkflowService: logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id}," f" error: {result.get('error')}") + # 过滤 citations + citations = result.get("citations", []) + citation_cfg = feature_configs.get("citation", {}) + filtered_citations = ( + citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else [] + ) + # 返回增强的响应结构 return { "execution_id": execution.execution_id, @@ -734,7 +767,8 @@ class WorkflowService: "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID "error_message": result.get("error"), "elapsed_time": result.get("elapsed_time"), - "token_usage": result.get("token_usage") + "token_usage": result.get("token_usage"), + "citations": filtered_citations, } except Exception as e: @@ -825,6 +859,27 @@ class WorkflowService: input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) message_id = uuid.uuid4() + + # 新会话时写入开场白 + is_new_conversation = init_message_length == 0 + if is_new_conversation: + opening_cfg = feature_configs.get("opening_statement", {}) + if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"): + statement = opening_cfg["statement"] + suggested_questions = opening_cfg.get("suggested_questions", []) + if payload.variables: + for var_name, var_value in payload.variables.items(): + statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value)) + self.conversation_service.add_message( + conversation_id=conversation_id_uuid, + role="assistant", + content=statement, + meta_data={"suggested_questions": suggested_questions} + ) + # 注入到 conv_messages,让 LLM 感知开场白 + input_data["conv_messages"] = [{"role": "assistant", "content": statement}] + init_message_length = 1 + async for event in execute_workflow_stream( workflow_config=workflow_config_dict, input_data=input_data, @@ -875,6 +930,13 @@ class WorkflowService: output_data=event.get("data"), token_usage=token_usage.get("total_tokens", None) ) + # 注入 citations 到 workflow_end 事件 + citations = event.get("data", {}).get("citations", []) + citation_cfg = feature_configs.get("citation", {}) + filtered_citations = ( + citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else [] + ) + event.setdefault("data", {})["citations"] = filtered_citations logger.info(f"Workflow Run Success, " f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") elif status == "failed": diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index bc03bb28..4f88fb4d 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -153,7 +153,8 @@ def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig: edges=config_dict.get("edges", []), variables=config_dict.get("variables", []), execution_config=config_dict.get("execution_config", {}), - triggers=config_dict.get("triggers", []) + triggers=config_dict.get("triggers", []), + features=config_dict.get("features", {}) ) return config