Merge branch 'develop' of codeup.aliyun.com:redbearai/python/redbear-mem-open into develop
This commit is contained in:
@@ -149,6 +149,9 @@ class Settings:
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
|
||||
@@ -18,17 +18,10 @@ from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.config import get_model_config
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
from app.core.memory.utils.config.get_data import get_data_statement
|
||||
from app.core.memory.utils.config.get_data import get_data,get_data_statement,extract_and_process_changes
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
||||
@@ -45,7 +38,6 @@ from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
|
||||
|
||||
# 配置日志
|
||||
_root_logger = logging.getLogger()
|
||||
if not _root_logger.handlers:
|
||||
@@ -56,7 +48,9 @@ if not _root_logger.handlers:
|
||||
else:
|
||||
_root_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
"""翻译响应模型"""
|
||||
data: str
|
||||
class ReflectionRange(str, Enum):
|
||||
"""反思范围枚举"""
|
||||
PARTIAL = "partial" # 从检索结果中反思
|
||||
@@ -84,6 +78,7 @@ class ReflectionConfig(BaseModel):
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
language_type: str = "zh"
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
@@ -241,15 +236,12 @@ class ReflectionEngine:
|
||||
print(100 * '-')
|
||||
print(conflict_data)
|
||||
print(100 * '-')
|
||||
|
||||
# 检查是否真的有冲突
|
||||
has_conflict = conflict_data[0].get('conflict', False)
|
||||
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
|
||||
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
|
||||
# # 检查是否真的有冲突
|
||||
conflicts_found=''
|
||||
|
||||
# 记录冲突数据
|
||||
await self._log_data("conflict", conflict_data)
|
||||
|
||||
conflicts_found=''
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
|
||||
if not solved_data:
|
||||
@@ -270,7 +262,7 @@ class ReflectionEngine:
|
||||
await self._log_data("solved_data", solved_data)
|
||||
|
||||
# 4. 应用反思结果(更新记忆库)
|
||||
memories_updated = await self._apply_reflection_results(solved_data)
|
||||
memories_updated=await self._apply_reflection_results(solved_data)
|
||||
|
||||
execution_time = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
@@ -294,9 +286,60 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
async def Translate(self, text):
|
||||
# 翻译中文为英文
|
||||
translation_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}"
|
||||
}
|
||||
]
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=translation_messages,
|
||||
response_model=TranslationResponse
|
||||
)
|
||||
return response.data
|
||||
async def extract_translation(self,data):
|
||||
end_datas={}
|
||||
end_datas['source_data']=await self.Translate(data['source_data'])
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
reflexion_data=[]
|
||||
if data['memory_verifies']!=[]:
|
||||
for i in data['memory_verifies']:
|
||||
end_data={}
|
||||
end_data['has_privacy'] = i['has_privacy']
|
||||
privacy=i['privacy_types']
|
||||
privacy_types_=[]
|
||||
for pri in privacy:
|
||||
privacy_types_.append(await self.Translate(pri))
|
||||
end_data['privacy_types']=privacy_types_
|
||||
end_data['summary']=await self.Translate(i['summary'])
|
||||
memory_verifies.append(end_data)
|
||||
end_datas['memory_verifies']=memory_verifies
|
||||
|
||||
if data['quality_assessments']!=[]:
|
||||
for i in data['quality_assessments']:
|
||||
end_data = {}
|
||||
end_data['score']=i['score']
|
||||
end_data['summary'] = await self.Translate(i['summary'])
|
||||
quality_assessments.append(end_data)
|
||||
end_datas['quality_assessments'] = quality_assessments
|
||||
for i in data['reflexion_data']:
|
||||
end_data = {}
|
||||
end_data['reason'] = await self.Translate(i['reason'])
|
||||
end_data['solution'] = await self.Translate(i['solution'])
|
||||
reflexion_data.append(end_data)
|
||||
end_datas['reflexion_data'] = reflexion_data
|
||||
return end_datas
|
||||
|
||||
async def reflection_run(self):
|
||||
self._lazy_init()
|
||||
start_time = time.time()
|
||||
memory_verifies_flag = self.config.memory_verify
|
||||
quality_assessment=self.config.quality_assessment
|
||||
language_type=self.config.language_type
|
||||
|
||||
asyncio.get_event_loop().time()
|
||||
logging.info("====== 自我反思流程开始 ======")
|
||||
@@ -305,9 +348,8 @@ class ReflectionEngine:
|
||||
|
||||
source_data, databasets = await self.extract_fields_from_json()
|
||||
result_data['baseline'] = self.config.baseline
|
||||
result_data[
|
||||
'source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
|
||||
result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(databasets, source_data)
|
||||
# 遍历数据提取字段
|
||||
@@ -316,8 +358,8 @@ class ReflectionEngine:
|
||||
for item in conflict_data:
|
||||
quality_assessments.append(item['quality_assessment'])
|
||||
memory_verifies.append(item['memory_verify'])
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
|
||||
# 检查是否真的有冲突
|
||||
has_conflict = conflict_data[0].get('conflict', False)
|
||||
@@ -335,8 +377,6 @@ class ReflectionEngine:
|
||||
'conflict': item['conflict']
|
||||
}
|
||||
cleaned_conflict_data.append(cleaned_item)
|
||||
print(cleaned_conflict_data)
|
||||
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
|
||||
if not solved_data:
|
||||
@@ -354,6 +394,14 @@ class ReflectionEngine:
|
||||
for result in item['results']:
|
||||
reflexion_data.append(result['reflexion'])
|
||||
result_data['reflexion_data'] = reflexion_data
|
||||
if memory_verifies_flag==False:
|
||||
result_data['memory_verifies']=[]
|
||||
if quality_assessment==False:
|
||||
result_data['quality_assessments']=[]
|
||||
|
||||
if language_type=='en':
|
||||
result_data=await self.extract_translation(result_data)
|
||||
print(time.time()-start_time,'----------')
|
||||
return result_data
|
||||
|
||||
|
||||
@@ -436,6 +484,7 @@ class ReflectionEngine:
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
quality_assessment = self.config.quality_assessment
|
||||
language_type=self.config.language_type
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
@@ -445,7 +494,8 @@ class ReflectionEngine:
|
||||
self.config.baseline,
|
||||
memory_verify,
|
||||
quality_assessment,
|
||||
statement_databasets
|
||||
statement_databasets,
|
||||
language_type
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
@@ -561,7 +611,8 @@ class ReflectionEngine:
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
"""
|
||||
success_count = await neo4j_data(solved_data)
|
||||
changes = extract_and_process_changes(solved_data)
|
||||
success_count = await neo4j_data(changes)
|
||||
return success_count
|
||||
|
||||
async def _log_data(self, label: str, data: Any) -> None:
|
||||
@@ -668,4 +719,8 @@ class ReflectionEngine:
|
||||
execution_time=time_result.execution_time + fact_result.execution_time
|
||||
)
|
||||
else:
|
||||
|
||||
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,20 @@ import uuid
|
||||
import logging
|
||||
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from openai import BaseModel
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from pydantic import model_validator, Field
|
||||
|
||||
from app.schemas.memory_storage_schema import SingleReflexionResultSchema
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
from app.repositories.neo4j.neo4j_update import map_field_names
|
||||
# 添加项目根目录到 Python 路径
|
||||
sys.path.append(str(Path(__file__).parent))
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def _load_(data: List[Any]) -> List[Dict]:
|
||||
@@ -59,6 +73,14 @@ async def get_data(result):
|
||||
"""
|
||||
从数据库中获取数据
|
||||
"""
|
||||
EXCLUDE_FIELDS = {
|
||||
"user_id",
|
||||
"group_id",
|
||||
"entity_type",
|
||||
"connect_strength",
|
||||
"relationship_type",
|
||||
"apply_id"
|
||||
}
|
||||
neo4j_databasets=[]
|
||||
for item in result:
|
||||
filtered_item = {}
|
||||
@@ -73,14 +95,17 @@ async def get_data(result):
|
||||
rel_filtered['statement_id'] = value.get('statement_id')
|
||||
rel_filtered['expired_at'] = value.get('expired_at')
|
||||
rel_filtered['created_at'] = value.get('created_at')
|
||||
filtered_item[key] = rel_filtered
|
||||
filtered_item[key] = value
|
||||
elif key == 'entity2' and value is not None:
|
||||
# 过滤entity2的name_embedding字段
|
||||
entity2_filtered = {}
|
||||
if hasattr(value, 'items'):
|
||||
for e_key, e_value in value.items():
|
||||
if 'name_embedding' not in e_key.lower():
|
||||
entity2_filtered[e_key] = e_value
|
||||
if e_key in EXCLUDE_FIELDS:
|
||||
continue
|
||||
if 'name_embedding' in e_key.lower():
|
||||
continue
|
||||
entity2_filtered[e_key] = e_value
|
||||
filtered_item[key] = entity2_filtered
|
||||
else:
|
||||
filtered_item[key] = value
|
||||
@@ -94,8 +119,57 @@ async def get_data_statement( result):
|
||||
neo4j_databasets.append(i)
|
||||
return neo4j_databasets
|
||||
|
||||
class ReflexionResultSchema(BaseModel):
|
||||
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
||||
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _normalize_resolved(cls, v):
|
||||
if isinstance(v, dict):
|
||||
conflict = v.get("conflict")
|
||||
if isinstance(conflict, dict) and conflict.get("conflict") is False:
|
||||
v["resolved"] = None
|
||||
else:
|
||||
resolved = v.get("resolved")
|
||||
if isinstance(resolved, dict):
|
||||
orig = resolved.get("original_memory_id")
|
||||
mem = resolved.get("resolved_memory")
|
||||
if orig is None and (mem is None or mem == {}):
|
||||
v["resolved"] = None
|
||||
return v
|
||||
def extract_and_process_changes(DATA):
|
||||
"""提取并处理 change 字段"""
|
||||
all_changes = []
|
||||
for i, item in enumerate(DATA):
|
||||
try:
|
||||
result = ReflexionResultSchema(**item)
|
||||
for j, res in enumerate(result.results):
|
||||
if res.resolved and res.resolved.change:
|
||||
for k, change in enumerate(res.resolved.change):
|
||||
change_data = {}
|
||||
for field_item in change.field:
|
||||
for key, value in field_item.items():
|
||||
change_data[key] = value
|
||||
if isinstance(value, list):
|
||||
print(f" - {key}: {value[0]} -> {value[1]}")
|
||||
else:
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
all_changes.append({
|
||||
'data': change_data
|
||||
})
|
||||
|
||||
# 测试字段映射
|
||||
try:
|
||||
mapped = map_field_names(change_data)
|
||||
print(f" 映射结果: {mapped}")
|
||||
except Exception as e:
|
||||
print(f" 映射失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理结果 {i + 1} 失败: {e}")
|
||||
|
||||
return all_changes
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
- **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID)
|
||||
- **隐私审核**: {{ memory_verify }} (true/false)
|
||||
- **质量评估**: {{ quality_assessment }} (true/false)
|
||||
|
||||
- **语言类型**:{{language_type}}(zh/en)
|
||||
## 任务目标
|
||||
对用户记忆数据进行冲突检测、隐私审核和质量评估,输出结构化JSON结果。
|
||||
**数据关系**: statement_databasets中的statement_id对应evaluate_data中的记录,代表句子拆分后的实体关系。
|
||||
@@ -23,7 +23,7 @@
|
||||
- **身份冲突**: 同一实体被赋予不同类型或角色
|
||||
### 混合冲突
|
||||
检测所有逻辑不一致或相互矛盾的记录。
|
||||
**检测原则**:
|
||||
**检测原则**:
|
||||
- 重点检查相同实体的记录
|
||||
- 分析description字段语义冲突
|
||||
- 验证时间字段逻辑一致性
|
||||
@@ -54,7 +54,7 @@
|
||||
1. **conflict=true**: 存在冲突或隐私信息时,将所有相关记录放入data数组
|
||||
2. **conflict=false**: 无冲突且无隐私信息时,data为空数组
|
||||
3. **独立功能**: 冲突检测、隐私审核、质量评估三者完全独立
|
||||
4. **条件输出**:
|
||||
4. **条件输出**:
|
||||
- quality_assessment=true时输出评估对象,否则为null
|
||||
- memory_verify=true时输出隐私检测对象,否则为null
|
||||
5. **不输出conflict_memory字段**
|
||||
@@ -63,7 +63,6 @@
|
||||
2. 隐私审核(如启用) → 将隐私记录加入data
|
||||
3. 质量评估(如启用) → 独立输出评估结果
|
||||
4. 去重data数组中的记录
|
||||
|
||||
**输出结构**:
|
||||
```json
|
||||
{
|
||||
@@ -82,6 +81,8 @@
|
||||
```
|
||||
**字段说明**:
|
||||
- **data**: 包含冲突记录和隐私信息记录,无则为空数组
|
||||
- **quality_assessment**: quality_assessment=true时输出评估对象,否则为null
|
||||
- **quality_assessment**:
|
||||
quality_assessment=true时输出评估对象,否则为null(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
- **memory_verify**: memory_verify=true时输出隐私检测对象,否则为null
|
||||
(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
模式参考:{{ json_schema }}
|
||||
@@ -5,6 +5,7 @@
|
||||
- **原始句子**: {{ statement_databasets }}
|
||||
- **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID)
|
||||
- **隐私审核**: {{ memory_verify }} (true/false)
|
||||
- **语言类型**:{{language_type}}(zh/en)
|
||||
|
||||
## 任务目标
|
||||
作为数据冲突解决专家,分析冲突原因,按类型分组处理,为每种冲突生成独立解决方案。
|
||||
@@ -61,7 +62,7 @@
|
||||
- 微信号: user123456 → use****3456
|
||||
- 邮箱: zhang.san@example.com → zha****@example.com
|
||||
|
||||
**脱敏字段**: name、entity1_name、entity2_name、description
|
||||
**脱敏字段**: name、entity1_name、entity2_name、description、relationship
|
||||
|
||||
## 4. 处理流程
|
||||
|
||||
@@ -97,21 +98,11 @@
|
||||
|
||||
### 处理规则
|
||||
|
||||
**情况1: 正确答案存在于data中**
|
||||
- 保留正确记录不变
|
||||
- 时间冲突: 修改错误记录的expired_at为当前时间(2025-12-16T12:00:00)
|
||||
- 事实冲突: 同样处理
|
||||
- resolved.resolved_memory只包含被设为失效的错误记录
|
||||
- change字段只记录expired_at变更: `[{"expired_at": "2025-12-16T12:00:00"}]`
|
||||
- 注意: 如果已存在时间则不需要修改
|
||||
|
||||
**情况2: 正确答案不存在于data中**
|
||||
- 选择最合适记录进行修改
|
||||
- 更新相关字段:
|
||||
- description字段: 添加或修改描述信息{% if memory_verify %}(含隐私信息需脱敏){% endif %}
|
||||
- name字段: 修改名称字段{% if memory_verify %}(含隐私信息需脱敏){% endif %}
|
||||
- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %}
|
||||
- change字段记录所有被修改字段{% if memory_verify %},包括脱敏变更{% endif %}
|
||||
** baseline是TIME
|
||||
-保留正确记录不变修改错误记录的expired_at为当前时间(2025-12-16T12:00:00),以及name需要修改成正确的
|
||||
** baseline不是TIME
|
||||
- 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段,
|
||||
如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field
|
||||
|
||||
**核心原则**:
|
||||
- 只输出需要修改的记录
|
||||
@@ -120,14 +111,17 @@
|
||||
- 隐私保护优先: 所有输出记录必须完成隐私脱敏
|
||||
- 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %}
|
||||
- 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空
|
||||
- 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true)等原数据字段以及涉及需要修改的字段以及内容
|
||||
|
||||
**变更记录格式**:
|
||||
```json
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
{"id":修改字段对应的ID}
|
||||
{"statement_id":需要修改的对象对应的statement_id}
|
||||
{"字段名1": ["修改前的值1","修改后的值1"]},
|
||||
{"字段名2": ["修改前的值2","修改后的值2"]}
|
||||
]
|
||||
}
|
||||
]
|
||||
@@ -149,7 +143,8 @@
|
||||
|
||||
**嵌套字段映射**(系统自动处理):
|
||||
- `entity2.name` → 自动映射为 `name`
|
||||
- `entity1.name` → 自动映射为 `name`
|
||||
- `entity1.name` → 自动映射为 `name`
|
||||
- `relationship` → 自动映射为 `statement`
|
||||
- `entity1.description` → 自动映射为 `description`
|
||||
- `entity2.description` → 自动映射为 `description`
|
||||
|
||||
@@ -188,5 +183,4 @@
|
||||
- **resolved.change**: 包含详细变更信息
|
||||
- 无需修改的冲突类型resolved为null
|
||||
- 与baseline不匹配的冲突类型不包含在results中
|
||||
|
||||
模式参考: {{ json_schema }}
|
||||
@@ -9,7 +9,8 @@ prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any],
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,statement_databasets: List[str] = []) -> str:
|
||||
memory_verify: bool = False,quality_assessment:bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
@@ -30,12 +31,13 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
|
||||
baseline=baseline,
|
||||
memory_verify=memory_verify,
|
||||
quality_assessment=quality_assessment,
|
||||
statement_databasets=statement_databasets
|
||||
statement_databasets=statement_databasets,
|
||||
language_type=language_type
|
||||
)
|
||||
return rendered_prompt
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = []) -> str:
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
@@ -51,6 +53,6 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any],
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets)
|
||||
statement_databasets=statement_databasets,language_type=language_type)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
Reference in New Issue
Block a user