feat: Add base project structure with API and web components
This commit is contained in:
0
api/app/core/rag/llm/__init__.py
Normal file
0
api/app/core/rag/llm/__init__.py
Normal file
670
api/app/core/rag/llm/chat_model.py
Normal file
670
api/app/core/rag/llm/chat_model.py
Normal file
@@ -0,0 +1,670 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import json_repair
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
from strenum import StrEnum
|
||||
|
||||
from app.core.rag.nlp import is_chinese, is_english
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string, total_token_count_from_response
|
||||
|
||||
|
||||
# Error message constants
|
||||
class LLMErrorCode(StrEnum):
|
||||
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
|
||||
ERROR_AUTHENTICATION = "AUTH_ERROR"
|
||||
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
|
||||
ERROR_SERVER = "SERVER_ERROR"
|
||||
ERROR_TIMEOUT = "TIMEOUT"
|
||||
ERROR_CONNECTION = "CONNECTION_ERROR"
|
||||
ERROR_MODEL = "MODEL_ERROR"
|
||||
ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS"
|
||||
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
|
||||
ERROR_QUOTA = "QUOTA_EXCEEDED"
|
||||
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
|
||||
ERROR_GENERIC = "GENERIC_ERROR"
|
||||
|
||||
|
||||
class ReActMode(StrEnum):
|
||||
FUNCTION_CALL = "function_call"
|
||||
REACT = "react"
|
||||
|
||||
|
||||
ERROR_PREFIX = "**ERROR**"
|
||||
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
|
||||
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
|
||||
|
||||
|
||||
class ToolCallSession(Protocol):
|
||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
|
||||
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
|
||||
self.model_name = model_name
|
||||
# Configure retry parameters
|
||||
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
|
||||
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
|
||||
self.max_rounds = kwargs.get("max_rounds", 5)
|
||||
self.is_tools = False
|
||||
self.tools = []
|
||||
self.toolcall_sessions = {}
|
||||
|
||||
def _get_delay(self):
|
||||
"""Calculate retry delay time"""
|
||||
return self.base_delay * random.uniform(10, 150)
|
||||
|
||||
def _classify_error(self, error):
|
||||
"""Classify error based on error message content"""
|
||||
error_str = str(error).lower()
|
||||
|
||||
keywords_mapping = [
|
||||
(["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
|
||||
(["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
|
||||
(["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
|
||||
(["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
|
||||
(["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
|
||||
(["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
|
||||
(["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
|
||||
(["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
|
||||
(["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
|
||||
(["max rounds"], LLMErrorCode.ERROR_MODEL),
|
||||
]
|
||||
for words, code in keywords_mapping:
|
||||
if re.search("({})".format("|".join(words)), error_str):
|
||||
return code
|
||||
|
||||
return LLMErrorCode.ERROR_GENERIC
|
||||
|
||||
def _clean_conf(self, gen_conf):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
|
||||
allowed_conf = {
|
||||
"temperature",
|
||||
"max_completion_tokens",
|
||||
"top_p",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"functions",
|
||||
"function_call",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"response_format",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"extra_headers"
|
||||
}
|
||||
|
||||
gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf}
|
||||
|
||||
return gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf, **kwargs):
|
||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
||||
if self.model_name.lower().find("qwq") >= 0:
|
||||
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
|
||||
|
||||
final_ans = ""
|
||||
tol_token = 0
|
||||
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
|
||||
if delta.startswith("<think>") or delta.endswith("</think>"):
|
||||
continue
|
||||
final_ans += delta
|
||||
tol_token = tol
|
||||
|
||||
if len(final_ans.strip()) == 0:
|
||||
final_ans = "**ERROR**: Empty response from reasoning model"
|
||||
|
||||
return final_ans.strip(), tol_token
|
||||
|
||||
if self.model_name.lower().find("qwen3") >= 0:
|
||||
kwargs["extra_body"] = {"enable_thinking": False}
|
||||
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
||||
|
||||
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
||||
return "", 0
|
||||
ans = response.choices[0].message.content.strip()
|
||||
if response.choices[0].finish_reason == "length":
|
||||
ans = self._length_stop(ans)
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def _chat_streamly(self, history, gen_conf, **kwargs):
|
||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
||||
reasoning_start = False
|
||||
|
||||
if kwargs.get("stop") or "stop" in gen_conf:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
|
||||
else:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||
|
||||
for resp in response:
|
||||
if not resp.choices:
|
||||
continue
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
||||
ans = ""
|
||||
if not reasoning_start:
|
||||
reasoning_start = True
|
||||
ans = "<think>"
|
||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
||||
else:
|
||||
reasoning_start = False
|
||||
ans = resp.choices[0].delta.content
|
||||
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
tol = num_tokens_from_string(resp.choices[0].delta.content)
|
||||
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
yield ans, tol
|
||||
|
||||
def _length_stop(self, ans):
|
||||
if is_chinese([ans]):
|
||||
return ans + LENGTH_NOTIFICATION_CN
|
||||
return ans + LENGTH_NOTIFICATION_EN
|
||||
|
||||
@property
|
||||
def _retryable_errors(self) -> set[str]:
|
||||
return {
|
||||
LLMErrorCode.ERROR_RATE_LIMIT,
|
||||
LLMErrorCode.ERROR_SERVER,
|
||||
}
|
||||
|
||||
def _should_retry(self, error_code: str) -> bool:
|
||||
return error_code in self._retryable_errors
|
||||
|
||||
def _exceptions(self, e, attempt) -> str | None:
|
||||
logging.exception("OpenAI chat_with_tools")
|
||||
# Classify the error
|
||||
error_code = self._classify_error(e)
|
||||
if attempt == self.max_retries:
|
||||
error_code = LLMErrorCode.ERROR_MAX_RETRIES
|
||||
|
||||
if self._should_retry(error_code):
|
||||
delay = self._get_delay()
|
||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
||||
time.sleep(delay)
|
||||
return None
|
||||
|
||||
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
|
||||
|
||||
def _verbose_tool_use(self, name, args, res):
|
||||
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
|
||||
|
||||
def _append_history(self, hist, tool_call, tool_res):
|
||||
hist.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": tool_call.index,
|
||||
"id": tool_call.id,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
"type": "function",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
try:
|
||||
if isinstance(tool_res, dict):
|
||||
tool_res = json.dumps(tool_res, ensure_ascii=False)
|
||||
finally:
|
||||
hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
|
||||
return hist
|
||||
|
||||
def bind_tools(self, toolcall_session, tools):
|
||||
if not (toolcall_session and tools):
|
||||
return
|
||||
self.is_tools = True
|
||||
self.toolcall_session = toolcall_session
|
||||
self.tools = tools
|
||||
|
||||
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
hist = deepcopy(history)
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = hist
|
||||
try:
|
||||
for _ in range(self.max_rounds + 1):
|
||||
logging.info(f"{self.tools=}")
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
|
||||
tk_count += total_token_count_from_response(response)
|
||||
if any([not response.choices, not response.choices[0].message]):
|
||||
raise Exception(f"500 response structure error. Response: {response}")
|
||||
|
||||
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
|
||||
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
|
||||
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
|
||||
|
||||
ans += response.choices[0].message.content
|
||||
if response.choices[0].finish_reason == "length":
|
||||
ans = self._length_stop(ans)
|
||||
|
||||
return ans, tk_count
|
||||
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
logging.info(f"Response {tool_call=}")
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
ans += self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||
ans += self._verbose_tool_use(name, {}, str(e))
|
||||
|
||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||
response, token_count = self._chat(history, gen_conf)
|
||||
ans += response
|
||||
tk_count += token_count
|
||||
return ans, tk_count
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
return e, tk_count
|
||||
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def chat(self, system, history, gen_conf={}, **kwargs):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return self._chat(history, gen_conf, **kwargs)
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
return e, 0
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def _wrap_toolcall_message(self, stream):
|
||||
final_tool_calls = {}
|
||||
|
||||
for chunk in stream:
|
||||
for tool_call in chunk.choices[0].delta.tool_calls or []:
|
||||
index = tool_call.index
|
||||
|
||||
if index not in final_tool_calls:
|
||||
final_tool_calls[index] = tool_call
|
||||
|
||||
final_tool_calls[index].function.arguments += tool_call.function.arguments
|
||||
|
||||
return final_tool_calls
|
||||
|
||||
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
tools = self.tools
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
total_tokens = 0
|
||||
hist = deepcopy(history)
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = hist
|
||||
try:
|
||||
for _ in range(self.max_rounds + 1):
|
||||
reasoning_start = False
|
||||
logging.info(f"{tools=}")
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
|
||||
final_tool_calls = {}
|
||||
answer = ""
|
||||
for resp in response:
|
||||
if resp.choices[0].delta.tool_calls:
|
||||
for tool_call in resp.choices[0].delta.tool_calls or []:
|
||||
index = tool_call.index
|
||||
|
||||
if index not in final_tool_calls:
|
||||
if not tool_call.function.arguments:
|
||||
tool_call.function.arguments = ""
|
||||
final_tool_calls[index] = tool_call
|
||||
else:
|
||||
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
|
||||
continue
|
||||
|
||||
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
|
||||
raise Exception("500 response structure error.")
|
||||
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
|
||||
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
||||
ans = ""
|
||||
if not reasoning_start:
|
||||
reasoning_start = True
|
||||
ans = "<think>"
|
||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
||||
yield ans
|
||||
else:
|
||||
reasoning_start = False
|
||||
answer += resp.choices[0].delta.content
|
||||
yield resp.choices[0].delta.content
|
||||
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
total_tokens = tol
|
||||
|
||||
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
|
||||
if finish_reason == "length":
|
||||
yield self._length_stop("")
|
||||
|
||||
if answer:
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
for tool_call in final_tool_calls.values():
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
yield self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||
yield self._verbose_tool_use(name, {}, str(e))
|
||||
|
||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||
for resp in response:
|
||||
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
|
||||
raise Exception("500 response structure error.")
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
continue
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
total_tokens = tol
|
||||
answer += resp.choices[0].delta.content
|
||||
yield resp.choices[0].delta.content
|
||||
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
yield e
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
|
||||
yield delta_ans
|
||||
total_tokens += tol
|
||||
except openai.APIError as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield total_tokens
|
||||
|
||||
def _calculate_dynamic_ctx(self, history):
|
||||
"""Calculate dynamic context window size"""
|
||||
|
||||
def count_tokens(text):
|
||||
"""Calculate token count for text"""
|
||||
# Simple calculation: 1 token per ASCII character
|
||||
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
||||
total = 0
|
||||
for char in text:
|
||||
if ord(char) < 128: # ASCII characters
|
||||
total += 1
|
||||
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
||||
total += 2
|
||||
return total
|
||||
|
||||
# Calculate total tokens for all messages
|
||||
total_tokens = 0
|
||||
for message in history:
|
||||
content = message.get("content", "")
|
||||
# Calculate content tokens
|
||||
content_tokens = count_tokens(content)
|
||||
# Add role marker token overhead
|
||||
role_tokens = 4
|
||||
total_tokens += content_tokens + role_tokens
|
||||
|
||||
# Apply 1.2x buffer ratio
|
||||
total_tokens_with_buffer = int(total_tokens * 1.2)
|
||||
|
||||
if total_tokens_with_buffer <= 8192:
|
||||
ctx_size = 8192
|
||||
else:
|
||||
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
|
||||
ctx_size = ctx_multiplier * 8192
|
||||
|
||||
return ctx_size
|
||||
|
||||
|
||||
class GptTurbo(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class XinferenceChat(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key=None, model_name="", base_url="", **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class HuggingFaceChat(Base):
|
||||
_FACTORY_NAME = "HuggingFace"
|
||||
|
||||
def __init__(self, key=None, model_name="", base_url="", **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
super().__init__(key, model_name.split("___")[0], base_url, **kwargs)
|
||||
|
||||
|
||||
class ModelScopeChat(Base):
|
||||
_FACTORY_NAME = "ModelScope"
|
||||
|
||||
def __init__(self, key=None, model_name="", base_url="", **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
super().__init__(key, model_name.split("___")[0], base_url, **kwargs)
|
||||
|
||||
|
||||
class AzureChat(Base):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
api_key = json.loads(key).get("api_key", "")
|
||||
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
self.model_name = model_name
|
||||
|
||||
@property
|
||||
def _retryable_errors(self) -> set[str]:
|
||||
return {
|
||||
LLMErrorCode.ERROR_RATE_LIMIT,
|
||||
LLMErrorCode.ERROR_SERVER,
|
||||
LLMErrorCode.ERROR_QUOTA,
|
||||
}
|
||||
|
||||
|
||||
class BaiChuanChat(Base):
|
||||
_FACTORY_NAME = "BaiChuan"
|
||||
|
||||
def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.baichuan-ai.com/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _format_params(params):
|
||||
return {
|
||||
"temperature": params.get("temperature", 0.3),
|
||||
"top_p": params.get("top_p", 0.85),
|
||||
}
|
||||
|
||||
def _clean_conf(self, gen_conf):
|
||||
return {
|
||||
"temperature": gen_conf.get("temperature", 0.3),
|
||||
"top_p": gen_conf.get("top_p", 0.85),
|
||||
}
|
||||
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
|
||||
**gen_conf,
|
||||
)
|
||||
ans = response.choices[0].message.content.strip()
|
||||
if response.choices[0].finish_reason == "length":
|
||||
if is_chinese([ans]):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
|
||||
stream=True,
|
||||
**self._format_params(gen_conf),
|
||||
)
|
||||
for resp in response:
|
||||
if not resp.choices:
|
||||
continue
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
ans = resp.choices[0].delta.content
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
total_tokens = tol
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese([ans]):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
yield ans
|
||||
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield total_tokens
|
||||
|
||||
|
||||
class LocalAIChat(Base):
|
||||
_FACTORY_NAME = "LocalAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
super().__init__(key, model_name, base_url=base_url, **kwargs)
|
||||
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key="empty", base_url=base_url)
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
|
||||
class VolcEngineChat(Base):
|
||||
_FACTORY_NAME = "VolcEngine"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
|
||||
"""
|
||||
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
|
||||
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
|
||||
model_name is for display only
|
||||
"""
|
||||
base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
|
||||
ark_api_key = json.loads(key).get("ark_api_key", "")
|
||||
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
|
||||
super().__init__(ark_api_key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class OpenAI_APIChat(Base):
|
||||
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
|
||||
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
model_name = model_name.split("___")[0]
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class GPUStackChat(Base):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key=None, model_name="", base_url="", **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
470
api/app/core/rag/llm/cv_model.py
Normal file
470
api/app/core/rag/llm/cv_model.py
Normal file
@@ -0,0 +1,470 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
from app.core.rag.nlp import is_english
|
||||
from app.core.rag.prompts.generator import vision_llm_describe_prompt
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string, total_token_count_from_response
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, **kwargs):
|
||||
# Configure retry parameters
|
||||
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
|
||||
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
|
||||
self.max_rounds = kwargs.get("max_rounds", 5)
|
||||
self.is_tools = False
|
||||
self.tools = []
|
||||
self.toolcall_sessions = {}
|
||||
self.extra_body = None
|
||||
|
||||
def describe(self, image):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def describe_with_prompt(self, image, prompt=None):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def _form_history(self, system, history, images=None):
|
||||
hist = []
|
||||
if system:
|
||||
hist.append({"role": "system", "content": system})
|
||||
for h in history:
|
||||
if images and h["role"] == "user":
|
||||
h["content"] = self._image_prompt(h["content"], images)
|
||||
images = []
|
||||
hist.append(h)
|
||||
return hist
|
||||
|
||||
def _image_prompt(self, text, images):
|
||||
if not images:
|
||||
return text
|
||||
|
||||
if isinstance(images, str) or "bytes" in type(images).__name__:
|
||||
images = [images]
|
||||
|
||||
pmpt = [{"type": "text", "text": text}]
|
||||
for img in images:
|
||||
pmpt.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
|
||||
}
|
||||
})
|
||||
return pmpt
|
||||
|
||||
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self._form_history(system, history, images),
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
return response.choices[0].message.content.strip(), response.usage.total_tokens
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self._form_history(system, history, images),
|
||||
stream=True,
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
for resp in response:
|
||||
if not resp.choices[0].delta.content:
|
||||
continue
|
||||
delta = resp.choices[0].delta.content
|
||||
ans = delta
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
if resp.choices[0].finish_reason == "stop":
|
||||
tk_count += resp.usage.total_tokens
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield tk_count
|
||||
|
||||
@staticmethod
|
||||
def image2base64_rawvalue(self, image):
|
||||
# Return a base64 string without data URL header
|
||||
if isinstance(image, bytes):
|
||||
b64 = base64.b64encode(image).decode("utf-8")
|
||||
return b64
|
||||
if isinstance(image, BytesIO):
|
||||
data = image.getvalue()
|
||||
b64 = base64.b64encode(data).decode("utf-8")
|
||||
return b64
|
||||
with BytesIO() as buffered:
|
||||
try:
|
||||
image.save(buffered, format="JPEG")
|
||||
except Exception:
|
||||
# reset buffer before saving PNG
|
||||
buffered.seek(0)
|
||||
buffered.truncate()
|
||||
image.save(buffered, format="PNG")
|
||||
data = buffered.getvalue()
|
||||
b64 = base64.b64encode(data).decode("utf-8")
|
||||
return b64
|
||||
|
||||
@staticmethod
|
||||
def image2base64(image):
|
||||
# Return a data URL with the correct MIME to avoid provider mismatches
|
||||
if isinstance(image, bytes):
|
||||
# Best-effort magic number sniffing
|
||||
mime = "image/png"
|
||||
if len(image) >= 2 and image[0] == 0xFF and image[1] == 0xD8:
|
||||
mime = "image/jpeg"
|
||||
b64 = base64.b64encode(image).decode("utf-8")
|
||||
return f"data:{mime};base64,{b64}"
|
||||
if isinstance(image, BytesIO):
|
||||
data = image.getvalue()
|
||||
mime = "image/png"
|
||||
if len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8:
|
||||
mime = "image/jpeg"
|
||||
b64 = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{mime};base64,{b64}"
|
||||
with BytesIO() as buffered:
|
||||
fmt = "jpeg"
|
||||
try:
|
||||
image.save(buffered, format="JPEG")
|
||||
except Exception:
|
||||
# reset buffer before saving PNG
|
||||
buffered.seek(0)
|
||||
buffered.truncate()
|
||||
image.save(buffered, format="PNG")
|
||||
fmt = "png"
|
||||
data = buffered.getvalue()
|
||||
b64 = base64.b64encode(data).decode("utf-8")
|
||||
mime = f"image/{fmt}"
|
||||
return f"data:{mime};base64,{b64}"
|
||||
|
||||
def prompt(self, b64):
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": self._image_prompt(
|
||||
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
|
||||
if self.lang.lower() == "chinese"
|
||||
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
||||
b64
|
||||
)
|
||||
}
|
||||
]
|
||||
|
||||
def vision_llm_prompt(self, b64, prompt=None):
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class GptV4(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.api_key = key
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe(self, image):
|
||||
b64 = self.image2base64(image)
|
||||
res = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self.prompt(b64),
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
|
||||
|
||||
def describe_with_prompt(self, image, prompt=None):
|
||||
b64 = self.image2base64(image)
|
||||
res = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self.vision_llm_prompt(b64, prompt),
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
return res.choices[0].message.content.strip(),total_token_count_from_response(res)
|
||||
|
||||
|
||||
class AzureGptV4(GptV4):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||
api_key = json.loads(key).get("api_key", "")
|
||||
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class QWenCV(GptV4):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
|
||||
|
||||
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
|
||||
if video_bytes:
|
||||
try:
|
||||
summary, summary_num_tokens = self._process_video(video_bytes, filename)
|
||||
return summary, summary_num_tokens
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
return "**ERROR**: Method chat not supported yet.", 0
|
||||
|
||||
def _process_video(self, video_bytes, filename):
|
||||
from dashscope import MultiModalConversation
|
||||
|
||||
video_suffix = Path(filename).suffix or ".mp4"
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
|
||||
tmp.write(video_bytes)
|
||||
tmp_path = tmp.name
|
||||
|
||||
video_path = f"file://{tmp_path}"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"video": video_path,
|
||||
"fps": 2,
|
||||
},
|
||||
{
|
||||
"text": "Please summarize this video in proper sentences.",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def call_api():
|
||||
response = MultiModalConversation.call(
|
||||
api_key=self.api_key,
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
)
|
||||
summary = response["output"]["choices"][0]["message"].content[0]["text"]
|
||||
return summary, num_tokens_from_string(summary)
|
||||
|
||||
try:
|
||||
return call_api()
|
||||
except Exception as e1:
|
||||
import dashscope
|
||||
|
||||
dashscope.base_http_api_url = "https://dashscope-intl.aliyuncs.com/api/v1"
|
||||
try:
|
||||
return call_api()
|
||||
except Exception as e2:
|
||||
raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}")
|
||||
|
||||
|
||||
class XinferenceCV(GptV4):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class GPUStackCV(GptV4):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class OllamaCV(Base):
|
||||
_FACTORY_NAME = "Ollama"
|
||||
|
||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||
from ollama import Client
|
||||
self.client = Client(host=kwargs["base_url"])
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
|
||||
Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
def _clean_img(self, img):
|
||||
if not isinstance(img, str):
|
||||
return img
|
||||
|
||||
#remove the header like "data/*;base64,"
|
||||
if img.startswith("data:") and ";base64," in img:
|
||||
img = img.split(";base64,")[1]
|
||||
return img
|
||||
|
||||
def _clean_conf(self, gen_conf):
|
||||
options = {}
|
||||
if "temperature" in gen_conf:
|
||||
options["temperature"] = gen_conf["temperature"]
|
||||
if "top_p" in gen_conf:
|
||||
options["top_k"] = gen_conf["top_p"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
return options
|
||||
|
||||
def _form_history(self, system, history, images=None):
|
||||
hist = deepcopy(history)
|
||||
if system and hist[0]["role"] == "user":
|
||||
hist.insert(0, {"role": "system", "content": system})
|
||||
if not images:
|
||||
return hist
|
||||
temp_images = []
|
||||
for img in images:
|
||||
temp_images.append(self._clean_img(img))
|
||||
for his in hist:
|
||||
if his["role"] == "user":
|
||||
his["images"] = temp_images
|
||||
break
|
||||
return hist
|
||||
|
||||
def describe(self, image):
|
||||
prompt = self.prompt("")
|
||||
try:
|
||||
response = self.client.generate(
|
||||
model=self.model_name,
|
||||
prompt=prompt[0]["content"],
|
||||
images=[image],
|
||||
)
|
||||
ans = response["response"].strip()
|
||||
return ans, 128
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def describe_with_prompt(self, image, prompt=None):
|
||||
vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
|
||||
try:
|
||||
response = self.client.generate(
|
||||
model=self.model_name,
|
||||
prompt=vision_prompt[0]["content"],
|
||||
images=[image],
|
||||
)
|
||||
ans = response["response"].strip()
|
||||
return ans, 128
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||
try:
|
||||
response = self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=self._form_history(system, history, images),
|
||||
options=self._clean_conf(gen_conf),
|
||||
keep_alive=self.keep_alive
|
||||
)
|
||||
|
||||
ans = response["message"]["content"].strip()
|
||||
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=self._form_history(system, history, images),
|
||||
stream=True,
|
||||
options=self._clean_conf(gen_conf),
|
||||
keep_alive=self.keep_alive
|
||||
)
|
||||
for resp in response:
|
||||
if resp["done"]:
|
||||
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
||||
ans = resp["message"]["content"]
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
yield 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# import sys
|
||||
# chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
||||
# # 准备配置vision_model信息
|
||||
# azure_config = {
|
||||
# "api_key": "xxxxx",
|
||||
# "api_version": "2024-02-01"
|
||||
# }
|
||||
# # 转换为 JSON 字符串,因为类中使用 json.loads(key) 解析
|
||||
# key = json.dumps(azure_config)
|
||||
# # 初始化 AzureGptV4
|
||||
# vision_model = AzureGptV4(
|
||||
# key=key, # JSON 字符串形式的配置
|
||||
# model_name="gpt-4o",
|
||||
# lang="Chinese", # 默认使用中文
|
||||
# base_url="https://fosun-openai-east-us-001.openai.azure.com/" # Azure OpenAI 端点
|
||||
# )
|
||||
# try:
|
||||
# # 测试图像描述功能
|
||||
# image_path = "/Users/sbtjfdn/Downloads/记忆科学/files/aippt.cn.png"
|
||||
# with open(image_path, "rb") as image_file:
|
||||
# image_data = image_file.read()
|
||||
#
|
||||
# # 使用 describe 方法
|
||||
# description, token_count = vision_model.describe(image_data)
|
||||
# # from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
|
||||
# # description, token_count = vision_model.describe_with_prompt(image_data, prompt=vision_llm_figure_describe_prompt())
|
||||
# print(f"描述: {description}")
|
||||
# print(f"使用的令牌数: {token_count}")
|
||||
#
|
||||
# except Exception as e:
|
||||
# print(f"初始化或处理过程中出错: {str(e)}")
|
||||
|
||||
# 准备配置vision_model信息
|
||||
# 初始化 QWenCV
|
||||
vision_model = QWenCV(
|
||||
key="sk-8e9e40cd171749858ce2d3722ea75669",
|
||||
model_name="qwen-vl-max",
|
||||
lang="Chinese", # 默认使用中文
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
)
|
||||
try:
|
||||
# 测试图像描述功能
|
||||
image_path = "/Users/sbtjfdn/Downloads/记忆科学/files/10.png"
|
||||
with open(image_path, "rb") as image_file:
|
||||
image_data = image_file.read()
|
||||
|
||||
# 使用 describe 方法
|
||||
description, token_count = vision_model.describe(image_data)
|
||||
# from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
|
||||
# description, token_count = vision_model.describe_with_prompt(image_data, prompt=vision_llm_figure_describe_prompt())
|
||||
print(f"描述: {description}")
|
||||
print(f"使用的令牌数: {token_count}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"初始化或处理过程中出错: {str(e)}")
|
||||
179
api/app/core/rag/llm/sequence2txt_model.py
Normal file
179
api/app/core/rag/llm/sequence2txt_model.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
"""
|
||||
Abstract base class constructor.
|
||||
Parameters are not stored; initialization is left to subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
def transcription(self, audio_path, **kwargs):
|
||||
audio_file = open(audio_path, "rb")
|
||||
transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file)
|
||||
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
|
||||
|
||||
def audio2base64(self, audio):
|
||||
if isinstance(audio, bytes):
|
||||
return base64.b64encode(audio).decode("utf-8")
|
||||
if isinstance(audio, io.BytesIO):
|
||||
return base64.b64encode(audio.getvalue()).decode("utf-8")
|
||||
raise TypeError("The input audio file should be in binary format.")
|
||||
|
||||
|
||||
class GPTSeq2txt(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class QWenSeq2txt(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="qwen-audio-asr", **kwargs):
|
||||
import dashscope
|
||||
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def transcription(self, audio_path):
|
||||
if "paraformer" in self.model_name or "sensevoice" in self.model_name:
|
||||
return f"**ERROR**: model {self.model_name} is not suppported yet.", 0
|
||||
|
||||
from dashscope import MultiModalConversation
|
||||
|
||||
audio_path = f"file://{audio_path}"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"audio": audio_path}],
|
||||
}
|
||||
]
|
||||
|
||||
response = None
|
||||
full_content = ""
|
||||
try:
|
||||
response = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True)
|
||||
for response in response:
|
||||
try:
|
||||
full_content += response["output"]["choices"][0]["message"].content[0]["text"]
|
||||
except Exception:
|
||||
pass
|
||||
return full_content, num_tokens_from_string(full_content)
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
|
||||
class AzureSeq2txt(Base):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
|
||||
|
||||
class XinferenceSeq2txt(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name="whisper-small", **kwargs):
|
||||
self.base_url = kwargs.get("base_url", None)
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
|
||||
def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
|
||||
if isinstance(audio, str):
|
||||
audio_file = open(audio, "rb")
|
||||
audio_data = audio_file.read()
|
||||
audio_file_name = audio.split("/")[-1]
|
||||
else:
|
||||
audio_data = audio
|
||||
audio_file_name = "audio.wav"
|
||||
|
||||
payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
|
||||
|
||||
files = {"file": (audio_file_name, audio_data, "audio/wav")}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if "text" in result:
|
||||
transcription_text = result["text"].strip()
|
||||
return transcription_text, num_tokens_from_string(transcription_text)
|
||||
else:
|
||||
return "**ERROR**: Failed to retrieve transcription.", 0
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"**ERROR**: {str(e)}", 0
|
||||
|
||||
|
||||
class GPUStackSeq2txt(Base):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
if base_url.split("/")[-1] != "v1":
|
||||
base_url = os.path.join(base_url, "v1")
|
||||
self.base_url = base_url
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
|
||||
|
||||
class ZhipuSeq2txt(Base):
|
||||
_FACTORY_NAME = "ZHIPU-AI"
|
||||
|
||||
def __init__(self, key, model_name="glm-asr", base_url="https://open.bigmodel.cn/api/paas/v4", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://open.bigmodel.cn/api/paas/v4"
|
||||
self.base_url = base_url
|
||||
self.api_key = key
|
||||
self.model_name = model_name
|
||||
self.gen_conf = kwargs.get("gen_conf", {})
|
||||
self.stream = kwargs.get("stream", False)
|
||||
|
||||
def transcription(self, audio_path):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"temperature": str(self.gen_conf.get("temperature", 0.75)) or "0.75",
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
files = {"file": audio_file}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url=f"{self.base_url}/audio/transcriptions",
|
||||
data=payload,
|
||||
files=files,
|
||||
headers=headers,
|
||||
)
|
||||
body = response.json()
|
||||
if response.status_code == 200:
|
||||
full_content = body["text"]
|
||||
return full_content, num_tokens_from_string(full_content)
|
||||
else:
|
||||
error = body["error"]
|
||||
return f"**ERROR**: code: {error['code']}, message: {error['message']}", 0
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
Reference in New Issue
Block a user