diff --git a/api/app/core/rag/prompts/generator.py b/api/app/core/rag/prompts/generator.py index fe928d8d..4838bf82 100644 --- a/api/app/core/rag/prompts/generator.py +++ b/api/app/core/rag/prompts/generator.py @@ -119,7 +119,7 @@ def keyword_extraction(chat_mdl, content, topn=3): rendered_prompt = template.render(content=content, topn=topn) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] - _, msg = message_fit_in(msg, chat_mdl.max_length) + _, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096)) kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): kwd = kwd[0] @@ -194,7 +194,7 @@ def content_tagging(chat_mdl, content, all_tags, examples, topn=3): ) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] - _, msg = message_fit_in(msg, chat_mdl.max_length) + _, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096)) kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5}) if isinstance(kwd, tuple): kwd = kwd[0] @@ -314,7 +314,7 @@ def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defi hist[-1]["content"] += user_prompt else: hist.append({"role": "user", "content": user_prompt}) - _, msg = message_fit_in(hist, chat_mdl.max_length) + _, msg = message_fit_in(hist, getattr(chat_mdl, 'max_length', 8096)) ans = chat_mdl.chat(msg[0]["content"], msg[1:]) ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) return """ @@ -341,7 +341,7 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defin params=json.dumps(params, ensure_ascii=False, indent=2), result=result) user_prompt = "→ Summary: " - _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) + _, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096)) ans = chat_mdl.chat(msg[0]["content"], msg[1:]) return re.sub(r"^.*", "", ans, flags=re.DOTALL) @@ -350,7 +350,7 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY) system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)]) user_prompt = " → rank: " - _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) + _, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096)) ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>") return re.sub(r"^.*", "", ans, flags=re.DOTALL) @@ -378,7 +378,7 @@ def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None): cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf) if cached: return json_repair.loads(cached) - _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) + _, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096)) ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf) ans = re.sub(r"(^.*|```json\n|```\n*$)", "", ans, flags=re.DOTALL) try: @@ -641,7 +641,7 @@ def split_chunks(chunks, max_length: int): async def run_toc_from_text(chunks, chat_mdl, callback=None): - input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string( + input_budget = int(getattr(chat_mdl, 'max_length', 8096) * INPUT_UTILIZATION) - num_tokens_from_string( TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM ) diff --git a/api/app/models/document_model.py b/api/app/models/document_model.py index 44012a56..a415bad8 100644 --- a/api/app/models/document_model.py +++ b/api/app/models/document_model.py @@ -16,7 +16,26 @@ class Document(Base): file_size = Column(Integer, default=0, comment="file size(byte)") file_meta = Column(JSON, nullable=False, default={}) parser_id = Column(String, index=True, nullable=False, comment="default parser ID") - parser_config = Column(JSON, nullable=False, default={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n"}, comment="default parser config") + parser_config = Column(JSON, nullable=False, + default={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": False, + "graphrag": { + "use_graphrag": False, + "entity_types": [ + "organization", + "person", + "geo", + "event", + "category", + ], + "method": "general", + } + }, comment="default parser config") chunk_num = Column(Integer, default=0, comment="chunk num") progress = Column(Float, default=0) progress_msg = Column(String, default="", comment="process message") diff --git a/api/app/models/knowledge_model.py b/api/app/models/knowledge_model.py index 0587da53..e3c1ece1 100644 --- a/api/app/models/knowledge_model.py +++ b/api/app/models/knowledge_model.py @@ -56,7 +56,25 @@ class Knowledge(Base): chunk_num = Column(Integer, default=0, comment="chunk num") parser_id = Column(String, index=True, default="naive", comment="default parser ID") parser_config = Column(JSON, nullable=False, - default={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n"}, + default={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": False, + "graphrag": { + "use_graphrag": False, + "entity_types": [ + "organization", + "person", + "geo", + "event", + "category", + ], + "method": "general", + } + }, comment="default parser config") status = Column(Integer, index=True, default=1, comment="is it validate(0: disable, 1: enable, 2:Soft-delete)") created_at = Column(DateTime, default=datetime.datetime.now)