diff --git a/api/app/core/workflow/nodes/memory/config.py b/api/app/core/workflow/nodes/memory/config.py index 31881e24..25b5105d 100644 --- a/api/app/core/workflow/nodes/memory/config.py +++ b/api/app/core/workflow/nodes/memory/config.py @@ -1,10 +1,33 @@ from uuid import UUID -from pydantic import Field +from pydantic import BaseModel, field_validator, Field from app.core.workflow.nodes.base_config import BaseNodeConfig +class MessageConfig(BaseModel): + """消息配置""" + + role: str = Field( + default='user', + description="消息角色:system, user, assistant" + ) + + content: str = Field( + default="", + description="消息内容,支持模板变量,如:{{ sys.message }}" + ) + + @field_validator("role") + @classmethod + def validate_role(cls, v: str) -> str: + """验证角色""" + allowed_roles = ["system", "user", "human", "assistant", "ai"] + if v.lower() not in allowed_roles: + raise ValueError(f"角色必须是以下之一: {', '.join(allowed_roles)}") + return v.lower() + + class MemoryReadNodeConfig(BaseNodeConfig): message: str = Field( ... @@ -25,6 +48,10 @@ class MemoryWriteNodeConfig(BaseNodeConfig): ... ) + messages: list[MessageConfig] = Field( + default_factory=list + ) + config_id: UUID | int = Field( ... ) diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index ddbe4b99..654ea0c6 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -55,10 +55,22 @@ class MemoryWriteNode(BaseNode): if not end_user_id: raise RuntimeError("End user id is required") + messages = [] + if self.typed_config.message: + messages.append({ + "role": "user", + "content": self._render_template(self.typed_config.message, variable_pool) + }) + + for message in self.typed_config.messages: + messages.append({ + "role": message.role, + "content": self._render_template(message.content, variable_pool) + }) write_message_task.delay( end_user_id, - self._render_template(self.typed_config.message, variable_pool), + messages, str(self.typed_config.config_id), "neo4j", "" diff --git a/api/app/tasks.py b/api/app/tasks.py index 539a3700..3a9da918 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -981,7 +981,7 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, end_user_id: str, message: str, config_id: str, storage_type: str, user_rag_memory_id: str, +def write_message_task(self, end_user_id: str, message: list[dict], config_id: str, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService.