Merge pull request #375 from SuanmoSuanyangTechnology/fix/workflow-memory-write
fix(workflow): adapt memory node write behavior
This commit is contained in:
@@ -1,10 +1,33 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, field_validator, Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class MessageConfig(BaseModel):
|
||||
"""消息配置"""
|
||||
|
||||
role: str = Field(
|
||||
default='user',
|
||||
description="消息角色:system, user, assistant"
|
||||
)
|
||||
|
||||
content: str = Field(
|
||||
default="",
|
||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||
)
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, v: str) -> str:
|
||||
"""验证角色"""
|
||||
allowed_roles = ["system", "user", "human", "assistant", "ai"]
|
||||
if v.lower() not in allowed_roles:
|
||||
raise ValueError(f"角色必须是以下之一: {', '.join(allowed_roles)}")
|
||||
return v.lower()
|
||||
|
||||
|
||||
class MemoryReadNodeConfig(BaseNodeConfig):
|
||||
message: str = Field(
|
||||
...
|
||||
@@ -25,6 +48,10 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
|
||||
...
|
||||
)
|
||||
|
||||
messages: list[MessageConfig] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
config_id: UUID | int = Field(
|
||||
...
|
||||
)
|
||||
|
||||
@@ -55,10 +55,22 @@ class MemoryWriteNode(BaseNode):
|
||||
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
messages = []
|
||||
if self.typed_config.message:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": self._render_template(self.typed_config.message, variable_pool)
|
||||
})
|
||||
|
||||
for message in self.typed_config.messages:
|
||||
messages.append({
|
||||
"role": message.role,
|
||||
"content": self._render_template(message.content, variable_pool)
|
||||
})
|
||||
|
||||
write_message_task.delay(
|
||||
end_user_id,
|
||||
self._render_template(self.typed_config.message, variable_pool),
|
||||
messages,
|
||||
str(self.typed_config.config_id),
|
||||
"neo4j",
|
||||
""
|
||||
|
||||
@@ -981,7 +981,7 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
||||
def write_message_task(self, end_user_id: str, message: str, config_id: str, storage_type: str, user_rag_memory_id: str,
|
||||
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str, storage_type: str, user_rag_memory_id: str,
|
||||
language: str = "zh") -> Dict[str, Any]:
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user