Feature/memory work (#61)
* refactor(conversation): separate service and repository layers for conversation module - Split ConversationService and repository/UnitOfWork layers - Service layer now only handles business logic and orchestration - Repository layer handles all direct database operations - UnitOfWork encapsulates transactional operations for messages - Ensured all public methods have clear English docstrings with arguments, return values, and exceptions * feat(memory): implement work memory endpoints and services - Added API routes for conversation count, conversation list, messages, and detail. - Integrated ConversationService for database queries and LLM-based summary generation. * feat(memory): implement work memory endpoints and services - Added API routes for conversation count, conversation list, messages, and detail. - Integrated ConversationService for database queries and LLM-based summary generation. * feat(workflow): fix issues causing workflow failures if-else None value error knowledge empty list rerank end node output none node value assigner input none value * feat(memory): convert memory file creation time to timestamp and include title and first-line fields in file type * fix(memory): fix serialization output and default value issues * fix(workflow): fix issue with hybrid search logic in knowledge retrieval node
This commit is contained in:
@@ -22,7 +22,7 @@ class AssignmentItem(BaseModel):
|
||||
)
|
||||
|
||||
value: Any = Field(
|
||||
...,
|
||||
default=None,
|
||||
description="Value(s) to assign to the variable(s)",
|
||||
)
|
||||
|
||||
|
||||
@@ -534,7 +534,7 @@ class BaseNode(ABC):
|
||||
return edge
|
||||
return None
|
||||
|
||||
def _render_template(self, template: str, state: WorkflowState | None) -> str:
|
||||
def _render_template(self, template: str, state: WorkflowState | None, struct: bool = True) -> str:
|
||||
"""渲染模板
|
||||
|
||||
支持的变量命名空间:
|
||||
@@ -568,7 +568,8 @@ class BaseNode(ABC):
|
||||
template=template,
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars()
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
struct=struct
|
||||
)
|
||||
|
||||
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
|
||||
|
||||
@@ -55,7 +55,7 @@ class ConditionDetail(BaseModel):
|
||||
)
|
||||
|
||||
input_type: ValueInputType = Field(
|
||||
...,
|
||||
default=ValueInputType.CONSTANT,
|
||||
description="Input type of the loop variable"
|
||||
)
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
output = self._render_template(output_template, state, struct=False)
|
||||
else:
|
||||
output = "工作流已完成"
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class ConditionDetail(BaseModel):
|
||||
)
|
||||
|
||||
input_type: ValueInputType = Field(
|
||||
...,
|
||||
default=ValueInputType.CONSTANT,
|
||||
description="Value input type for comparison"
|
||||
)
|
||||
|
||||
|
||||
@@ -71,7 +71,10 @@ class IfElseNode(BaseNode):
|
||||
for expression in case_branch.expressions:
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
left_string = re.sub(pattern, r"\1", expression.left).strip()
|
||||
left_value = self.get_variable(left_string, state)
|
||||
try:
|
||||
left_value = self.get_variable(left_string, state)
|
||||
except KeyError:
|
||||
left_value = None
|
||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||
self.get_variable_pool(state),
|
||||
expression.left,
|
||||
|
||||
@@ -205,10 +205,14 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
# Deduplicate hy brid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
continue
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
if not rs:
|
||||
return []
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
# TODO:其他重排序方式支持
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
|
||||
@@ -65,11 +65,12 @@ class LLMNode(BaseNode):
|
||||
- user/human: 用户消息(HumanMessage)
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
|
||||
def _render_context(self, message,state):
|
||||
def _render_context(self, message, state):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||
return re.sub(r"{{context}}", context, message)
|
||||
|
||||
|
||||
@@ -387,6 +387,11 @@ class ArrayComparisonOperator(ConditionBase):
|
||||
return self.right_value not in self.left_value
|
||||
|
||||
|
||||
class NoneObjectComparisonOperator(ConditionBase):
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: False
|
||||
|
||||
|
||||
CompareOperatorInstance = Union[
|
||||
StringComparisonOperator,
|
||||
NumberComparisonOperator,
|
||||
@@ -405,6 +410,7 @@ class ConditionExpressionResolver:
|
||||
float: NumberComparisonOperator,
|
||||
list: ArrayComparisonOperator,
|
||||
dict: ObjectComparisonOperator,
|
||||
type(None): NoneObjectComparisonOperator
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class TemplateRenderer:
|
||||
"""模板渲染器"""
|
||||
|
||||
|
||||
def __init__(self, strict: bool = True):
|
||||
"""初始化渲染器
|
||||
|
||||
@@ -25,13 +25,13 @@ class TemplateRenderer:
|
||||
undefined=StrictUndefined if strict else Undefined,
|
||||
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
||||
)
|
||||
|
||||
|
||||
def render(
|
||||
self,
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
self,
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> str:
|
||||
"""渲染模板
|
||||
|
||||
@@ -69,40 +69,40 @@ class TemplateRenderer:
|
||||
# variables 的结构:{"sys": {...}, "conv": {...}}
|
||||
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
|
||||
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
|
||||
|
||||
|
||||
context = {
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
|
||||
}
|
||||
|
||||
|
||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||
# 将所有节点输出添加到顶层上下文
|
||||
if node_outputs:
|
||||
context.update(node_outputs)
|
||||
|
||||
|
||||
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||
if conv_vars:
|
||||
context.update(conv_vars)
|
||||
|
||||
|
||||
context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
|
||||
|
||||
try:
|
||||
tmpl = self.env.from_string(template)
|
||||
return tmpl.render(**context)
|
||||
|
||||
|
||||
except TemplateSyntaxError as e:
|
||||
logger.error(f"模板语法错误: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板语法错误: {e}")
|
||||
|
||||
|
||||
except UndefinedError as e:
|
||||
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
|
||||
raise ValueError(f"未定义的变量: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模板渲染异常: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板渲染失败: {e}")
|
||||
|
||||
|
||||
def validate(self, template: str) -> list[str]:
|
||||
"""验证模板语法
|
||||
|
||||
@@ -121,14 +121,14 @@ class TemplateRenderer:
|
||||
['模板语法错误: ...']
|
||||
"""
|
||||
errors = []
|
||||
|
||||
|
||||
try:
|
||||
self.env.from_string(template)
|
||||
except TemplateSyntaxError as e:
|
||||
errors.append(f"模板语法错误: {e}")
|
||||
except Exception as e:
|
||||
errors.append(f"模板验证失败: {e}")
|
||||
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
@@ -137,14 +137,16 @@ _default_renderer = TemplateRenderer(strict=True)
|
||||
|
||||
|
||||
def render_template(
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None,
|
||||
struct: bool = True
|
||||
) -> str:
|
||||
"""渲染模板(便捷函数)
|
||||
|
||||
Args:
|
||||
struct: 渲染模式
|
||||
template: 模板字符串
|
||||
variables: 用户变量
|
||||
node_outputs: 节点输出
|
||||
@@ -162,7 +164,8 @@ def render_template(
|
||||
... )
|
||||
'请分析: 这是一段文本'
|
||||
"""
|
||||
return _default_renderer.render(template, variables, node_outputs, system_vars)
|
||||
renderer = TemplateRenderer(strict=struct)
|
||||
return renderer.render(template, variables, node_outputs, system_vars)
|
||||
|
||||
|
||||
def validate_template(template: str) -> list[str]:
|
||||
|
||||
Reference in New Issue
Block a user