feat: Add base project structure with API and web components

This commit is contained in:
Ke Sun
2025-12-02 20:28:01 +08:00
parent f3de6d6cc9
commit c1adc62ec6
817 changed files with 111226 additions and 106 deletions

View File

View File

@@ -0,0 +1,670 @@
import json
import logging
import os
import random
import re
import time
from abc import ABC
from copy import deepcopy
from typing import Any, Protocol
from urllib.parse import urljoin
import json_repair
import openai
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from strenum import StrEnum
from app.core.rag.nlp import is_chinese, is_english
from app.core.rag.common.token_utils import num_tokens_from_string, total_token_count_from_response
# Error message constants
class LLMErrorCode(StrEnum):
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
ERROR_AUTHENTICATION = "AUTH_ERROR"
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
ERROR_SERVER = "SERVER_ERROR"
ERROR_TIMEOUT = "TIMEOUT"
ERROR_CONNECTION = "CONNECTION_ERROR"
ERROR_MODEL = "MODEL_ERROR"
ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS"
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
ERROR_QUOTA = "QUOTA_EXCEEDED"
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
ERROR_GENERIC = "GENERIC_ERROR"
class ReActMode(StrEnum):
FUNCTION_CALL = "function_call"
REACT = "react"
ERROR_PREFIX = "**ERROR**"
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name
# Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
self.max_rounds = kwargs.get("max_rounds", 5)
self.is_tools = False
self.tools = []
self.toolcall_sessions = {}
def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
keywords_mapping = [
(["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
(["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
(["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
(["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
(["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
(["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
(["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
(["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
(["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
(["max rounds"], LLMErrorCode.ERROR_MODEL),
]
for words, code in keywords_mapping:
if re.search("({})".format("|".join(words)), error_str):
return code
return LLMErrorCode.ERROR_GENERIC
def _clean_conf(self, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
allowed_conf = {
"temperature",
"max_completion_tokens",
"top_p",
"stream",
"stream_options",
"stop",
"n",
"presence_penalty",
"frequency_penalty",
"functions",
"function_call",
"logit_bias",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
"logprobs",
"top_logprobs",
"extra_headers"
}
gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf}
return gen_conf
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
final_ans = ""
tol_token = 0
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
tol_token = tol
if len(final_ans.strip()) == 0:
final_ans = "**ERROR**: Empty response from reasoning model"
return final_ans.strip(), tol_token
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
if kwargs.get("stop") or "stop" in gen_conf:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
else:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
def _length_stop(self, ans):
if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN
return ans + LENGTH_NOTIFICATION_EN
@property
def _retryable_errors(self) -> set[str]:
return {
LLMErrorCode.ERROR_RATE_LIMIT,
LLMErrorCode.ERROR_SERVER,
}
def _should_retry(self, error_code: str) -> bool:
return error_code in self._retryable_errors
def _exceptions(self, e, attempt) -> str | None:
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
return None
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
def _append_history(self, hist, tool_call, tool_res):
hist.append(
{
"role": "assistant",
"tool_calls": [
{
"index": tool_call.index,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
"type": "function",
},
],
}
)
try:
if isinstance(tool_res, dict):
tool_res = json.dumps(tool_res, ensure_ascii=False)
finally:
hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
return hist
def bind_tools(self, toolcall_session, tools):
if not (toolcall_session and tools):
return
self.is_tools = True
self.toolcall_session = toolcall_session
self.tools = tools
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
ans += response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, tk_count
for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
return self._chat(history, gen_conf, **kwargs)
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
total_tokens = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
for resp in response:
if resp.choices[0].delta.tool_calls:
for tool_call in resp.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
continue
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
yield ans
else:
reasoning_start = False
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
yield self._length_stop("")
if answer:
yield total_tokens
return
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
yield total_tokens
return
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield e
yield total_tokens
return
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total
# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens
# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)
if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192
return ctx_size
class GptTurbo(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
class XinferenceChat(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)
class HuggingFaceChat(Base):
_FACTORY_NAME = "HuggingFace"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name.split("___")[0], base_url, **kwargs)
class ModelScopeChat(Base):
_FACTORY_NAME = "ModelScope"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name.split("___")[0], base_url, **kwargs)
class AzureChat(Base):
_FACTORY_NAME = "Azure-OpenAI"
def __init__(self, key, model_name, base_url, **kwargs):
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
super().__init__(key, model_name, base_url, **kwargs)
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
self.model_name = model_name
@property
def _retryable_errors(self) -> set[str]:
return {
LLMErrorCode.ERROR_RATE_LIMIT,
LLMErrorCode.ERROR_SERVER,
LLMErrorCode.ERROR_QUOTA,
}
class BaiChuanChat(Base):
_FACTORY_NAME = "BaiChuan"
def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.baichuan-ai.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
@staticmethod
def _format_params(params):
return {
"temperature": params.get("temperature", 0.3),
"top_p": params.get("top_p", 0.85),
}
def _clean_conf(self, gen_conf):
return {
"temperature": gen_conf.get("temperature", 0.3),
"top_p": gen_conf.get("top_p", 0.85),
}
def _chat(self, history, gen_conf={}, **kwargs):
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
**gen_conf,
)
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
ans = ""
total_tokens = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
stream=True,
**self._format_params(gen_conf),
)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans = resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
if resp.choices[0].finish_reason == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
class LocalAIChat(Base):
_FACTORY_NAME = "LocalAI"
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key="empty", base_url=base_url)
self.model_name = model_name.split("___")[0]
class VolcEngineChat(Base):
_FACTORY_NAME = "VolcEngine"
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
"""
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
model_name is for display only
"""
base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
ark_api_key = json.loads(key).get("ark_api_key", "")
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
super().__init__(ark_api_key, model_name, base_url, **kwargs)
class OpenAI_APIChat(Base):
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
def __init__(self, key, model_name, base_url, **kwargs):
if not base_url:
raise ValueError("url cannot be None")
model_name = model_name.split("___")[0]
super().__init__(key, model_name, base_url, **kwargs)
class GPUStackChat(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)

View File

@@ -0,0 +1,470 @@
import base64
import json
import os
import tempfile
import logging
from abc import ABC
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from urllib.parse import urljoin
import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from app.core.rag.nlp import is_english
from app.core.rag.prompts.generator import vision_llm_describe_prompt
from app.core.rag.common.token_utils import num_tokens_from_string, total_token_count_from_response
class Base(ABC):
def __init__(self, **kwargs):
# Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
self.max_rounds = kwargs.get("max_rounds", 5)
self.is_tools = False
self.tools = []
self.toolcall_sessions = {}
self.extra_body = None
def describe(self, image):
raise NotImplementedError("Please implement encode method!")
def describe_with_prompt(self, image, prompt=None):
raise NotImplementedError("Please implement encode method!")
def _form_history(self, system, history, images=None):
hist = []
if system:
hist.append({"role": "system", "content": system})
for h in history:
if images and h["role"] == "user":
h["content"] = self._image_prompt(h["content"], images)
images = []
hist.append(h)
return hist
def _image_prompt(self, text, images):
if not images:
return text
if isinstance(images, str) or "bytes" in type(images).__name__:
images = [images]
pmpt = [{"type": "text", "text": text}]
for img in images:
pmpt.append({
"type": "image_url",
"image_url": {
"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
}
})
return pmpt
def chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images),
extra_body=self.extra_body,
)
return response.choices[0].message.content.strip(), response.usage.total_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = ""
tk_count = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True,
extra_body=self.extra_body,
)
for resp in response:
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans = delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if resp.choices[0].finish_reason == "stop":
tk_count += resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
@staticmethod
def image2base64_rawvalue(self, image):
# Return a base64 string without data URL header
if isinstance(image, bytes):
b64 = base64.b64encode(image).decode("utf-8")
return b64
if isinstance(image, BytesIO):
data = image.getvalue()
b64 = base64.b64encode(data).decode("utf-8")
return b64
with BytesIO() as buffered:
try:
image.save(buffered, format="JPEG")
except Exception:
# reset buffer before saving PNG
buffered.seek(0)
buffered.truncate()
image.save(buffered, format="PNG")
data = buffered.getvalue()
b64 = base64.b64encode(data).decode("utf-8")
return b64
@staticmethod
def image2base64(image):
# Return a data URL with the correct MIME to avoid provider mismatches
if isinstance(image, bytes):
# Best-effort magic number sniffing
mime = "image/png"
if len(image) >= 2 and image[0] == 0xFF and image[1] == 0xD8:
mime = "image/jpeg"
b64 = base64.b64encode(image).decode("utf-8")
return f"data:{mime};base64,{b64}"
if isinstance(image, BytesIO):
data = image.getvalue()
mime = "image/png"
if len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8:
mime = "image/jpeg"
b64 = base64.b64encode(data).decode("utf-8")
return f"data:{mime};base64,{b64}"
with BytesIO() as buffered:
fmt = "jpeg"
try:
image.save(buffered, format="JPEG")
except Exception:
# reset buffer before saving PNG
buffered.seek(0)
buffered.truncate()
image.save(buffered, format="PNG")
fmt = "png"
data = buffered.getvalue()
b64 = base64.b64encode(data).decode("utf-8")
mime = f"image/{fmt}"
return f"data:{mime};base64,{b64}"
def prompt(self, b64):
return [
{
"role": "user",
"content": self._image_prompt(
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
b64
)
}
]
def vision_llm_prompt(self, b64, prompt=None):
return [
{
"role": "user",
"content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
}
]
class GptV4(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
super().__init__(**kwargs)
def describe(self, image):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.prompt(b64),
extra_body=self.extra_body,
)
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.vision_llm_prompt(b64, prompt),
extra_body=self.extra_body,
)
return res.choices[0].message.content.strip(),total_token_count_from_response(res)
class AzureGptV4(GptV4):
_FACTORY_NAME = "Azure-OpenAI"
def __init__(self, key, model_name, lang="Chinese", **kwargs):
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class QWenCV(GptV4):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs):
if not base_url:
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
if video_bytes:
try:
summary, summary_num_tokens = self._process_video(video_bytes, filename)
return summary, summary_num_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
return "**ERROR**: Method chat not supported yet.", 0
def _process_video(self, video_bytes, filename):
from dashscope import MultiModalConversation
video_suffix = Path(filename).suffix or ".mp4"
with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
tmp.write(video_bytes)
tmp_path = tmp.name
video_path = f"file://{tmp_path}"
messages = [
{
"role": "user",
"content": [
{
"video": video_path,
"fps": 2,
},
{
"text": "Please summarize this video in proper sentences.",
},
],
}
]
def call_api():
response = MultiModalConversation.call(
api_key=self.api_key,
model=self.model_name,
messages=messages,
)
summary = response["output"]["choices"][0]["message"].content[0]["text"]
return summary, num_tokens_from_string(summary)
try:
return call_api()
except Exception as e1:
import dashscope
dashscope.base_http_api_url = "https://dashscope-intl.aliyuncs.com/api/v1"
try:
return call_api()
except Exception as e2:
raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}")
class XinferenceCV(GptV4):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class GPUStackCV(GptV4):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class OllamaCV(Base):
_FACTORY_NAME = "Ollama"
def __init__(self, key, model_name, lang="Chinese", **kwargs):
from ollama import Client
self.client = Client(host=kwargs["base_url"])
self.model_name = model_name
self.lang = lang
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
Base.__init__(self, **kwargs)
def _clean_img(self, img):
if not isinstance(img, str):
return img
#remove the header like "data/*;base64,"
if img.startswith("data:") and ";base64," in img:
img = img.split(";base64,")[1]
return img
def _clean_conf(self, gen_conf):
options = {}
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
return options
def _form_history(self, system, history, images=None):
hist = deepcopy(history)
if system and hist[0]["role"] == "user":
hist.insert(0, {"role": "system", "content": system})
if not images:
return hist
temp_images = []
for img in images:
temp_images.append(self._clean_img(img))
for his in hist:
if his["role"] == "user":
his["images"] = temp_images
break
return hist
def describe(self, image):
prompt = self.prompt("")
try:
response = self.client.generate(
model=self.model_name,
prompt=prompt[0]["content"],
images=[image],
)
ans = response["response"].strip()
return ans, 128
except Exception as e:
return "**ERROR**: " + str(e), 0
def describe_with_prompt(self, image, prompt=None):
vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
try:
response = self.client.generate(
model=self.model_name,
prompt=vision_prompt[0]["content"],
images=[image],
)
ans = response["response"].strip()
return ans, 128
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = self.client.chat(
model=self.model_name,
messages=self._form_history(system, history, images),
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
ans = response["message"]["content"].strip()
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = ""
try:
response = self.client.chat(
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True,
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
for resp in response:
if resp["done"]:
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
ans = resp["message"]["content"]
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield 0
if __name__ == "__main__":
# import sys
# chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
# # 准备配置vision_model信息
# azure_config = {
# "api_key": "xxxxx",
# "api_version": "2024-02-01"
# }
# # 转换为 JSON 字符串,因为类中使用 json.loads(key) 解析
# key = json.dumps(azure_config)
# # 初始化 AzureGptV4
# vision_model = AzureGptV4(
# key=key, # JSON 字符串形式的配置
# model_name="gpt-4o",
# lang="Chinese", # 默认使用中文
# base_url="https://fosun-openai-east-us-001.openai.azure.com/" # Azure OpenAI 端点
# )
# try:
# # 测试图像描述功能
# image_path = "/Users/sbtjfdn/Downloads/记忆科学/files/aippt.cn.png"
# with open(image_path, "rb") as image_file:
# image_data = image_file.read()
#
# # 使用 describe 方法
# description, token_count = vision_model.describe(image_data)
# # from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
# # description, token_count = vision_model.describe_with_prompt(image_data, prompt=vision_llm_figure_describe_prompt())
# print(f"描述: {description}")
# print(f"使用的令牌数: {token_count}")
#
# except Exception as e:
# print(f"初始化或处理过程中出错: {str(e)}")
# 准备配置vision_model信息
# 初始化 QWenCV
vision_model = QWenCV(
key="sk-8e9e40cd171749858ce2d3722ea75669",
model_name="qwen-vl-max",
lang="Chinese", # 默认使用中文
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
)
try:
# 测试图像描述功能
image_path = "/Users/sbtjfdn/Downloads/记忆科学/files/10.png"
with open(image_path, "rb") as image_file:
image_data = image_file.read()
# 使用 describe 方法
description, token_count = vision_model.describe(image_data)
# from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
# description, token_count = vision_model.describe_with_prompt(image_data, prompt=vision_llm_figure_describe_prompt())
print(f"描述: {description}")
print(f"使用的令牌数: {token_count}")
except Exception as e:
print(f"初始化或处理过程中出错: {str(e)}")

View File

@@ -0,0 +1,179 @@
import base64
import io
import json
import os
import re
from abc import ABC
import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from app.core.rag.common.token_utils import num_tokens_from_string
class Base(ABC):
def __init__(self, key, model_name, **kwargs):
"""
Abstract base class constructor.
Parameters are not stored; initialization is left to subclasses.
"""
pass
def transcription(self, audio_path, **kwargs):
audio_file = open(audio_path, "rb")
transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file)
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
def audio2base64(self, audio):
if isinstance(audio, bytes):
return base64.b64encode(audio).decode("utf-8")
if isinstance(audio, io.BytesIO):
return base64.b64encode(audio.getvalue()).decode("utf-8")
raise TypeError("The input audio file should be in binary format.")
class GPTSeq2txt(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
class QWenSeq2txt(Base):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name="qwen-audio-asr", **kwargs):
import dashscope
dashscope.api_key = key
self.model_name = model_name
def transcription(self, audio_path):
if "paraformer" in self.model_name or "sensevoice" in self.model_name:
return f"**ERROR**: model {self.model_name} is not suppported yet.", 0
from dashscope import MultiModalConversation
audio_path = f"file://{audio_path}"
messages = [
{
"role": "user",
"content": [{"audio": audio_path}],
}
]
response = None
full_content = ""
try:
response = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True)
for response in response:
try:
full_content += response["output"]["choices"][0]["message"].content[0]["text"]
except Exception:
pass
return full_content, num_tokens_from_string(full_content)
except Exception as e:
return "**ERROR**: " + str(e), 0
class AzureSeq2txt(Base):
_FACTORY_NAME = "Azure-OpenAI"
def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
self.model_name = model_name
self.lang = lang
class XinferenceSeq2txt(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name="whisper-small", **kwargs):
self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
self.key = key
def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
if isinstance(audio, str):
audio_file = open(audio, "rb")
audio_data = audio_file.read()
audio_file_name = audio.split("/")[-1]
else:
audio_data = audio
audio_file_name = "audio.wav"
payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
files = {"file": (audio_file_name, audio_data, "audio/wav")}
try:
response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
response.raise_for_status()
result = response.json()
if "text" in result:
transcription_text = result["text"].strip()
return transcription_text, num_tokens_from_string(transcription_text)
else:
return "**ERROR**: Failed to retrieve transcription.", 0
except requests.exceptions.RequestException as e:
return f"**ERROR**: {str(e)}", 0
class GPUStackSeq2txt(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.base_url = base_url
self.model_name = model_name
self.key = key
class ZhipuSeq2txt(Base):
_FACTORY_NAME = "ZHIPU-AI"
def __init__(self, key, model_name="glm-asr", base_url="https://open.bigmodel.cn/api/paas/v4", **kwargs):
if not base_url:
base_url = "https://open.bigmodel.cn/api/paas/v4"
self.base_url = base_url
self.api_key = key
self.model_name = model_name
self.gen_conf = kwargs.get("gen_conf", {})
self.stream = kwargs.get("stream", False)
def transcription(self, audio_path):
payload = {
"model": self.model_name,
"temperature": str(self.gen_conf.get("temperature", 0.75)) or "0.75",
"stream": self.stream,
}
headers = {"Authorization": f"Bearer {self.api_key}"}
with open(audio_path, "rb") as audio_file:
files = {"file": audio_file}
try:
response = requests.post(
url=f"{self.base_url}/audio/transcriptions",
data=payload,
files=files,
headers=headers,
)
body = response.json()
if response.status_code == 200:
full_content = body["text"]
return full_content, num_tokens_from_string(full_content)
else:
error = body["error"]
return f"**ERROR**: code: {error['code']}, message: {error['message']}", 0
except Exception as e:
return "**ERROR**: " + str(e), 0