diff --git a/api/app/core/rag/prompts/generator.py b/api/app/core/rag/prompts/generator.py
index ef7f4474..0eab32cb 100644
--- a/api/app/core/rag/prompts/generator.py
+++ b/api/app/core/rag/prompts/generator.py
@@ -149,14 +149,12 @@ def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None):
     """
     if custom_prompt:
         template = PROMPT_JINJA_ENV.from_string(custom_prompt)
-        rendered_user = template.render(content=content, topn=topn)
-        msg = [{"role": "user", "content": rendered_user}]
-        sys_prompt = ""
+        sys_prompt = template.render(topn=topn)
     else:
         sys_prompt = QUESTION_PROMPT_TEMPLATE
-        msg = [{"role": "user", "content": f"## Text Content (topn: {topn})\n\n{content}"}]
-    _, msg = message_fit_in([{"role": "system", "content": sys_prompt}] + msg, getattr(chat_mdl, 'max_length', 8096))
-    raw = chat_mdl.chat(sys_prompt, msg, {"temperature": 0.2})
+    msg = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": content}]
+    _, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
+    raw = chat_mdl.chat(sys_prompt, msg[1:], {"temperature": 0.2})
     if isinstance(raw, tuple):
         raw = raw[0]
     raw = re.sub(r"^.*</think>", "", raw, flags=re.DOTALL)