Spaces:
Configuration error
Configuration error
from typing import Literal, cast | |
from openai import OpenAI | |
from neollm.llm.abstract_llm import AbstractLLM | |
from neollm.llm.gpt.abstract_gpt import AbstractGPT | |
from neollm.types import ( | |
APIPricing, | |
ClientSettings, | |
LLMSettings, | |
Messages, | |
Response, | |
StreamResponse, | |
) | |
from neollm.utils.utils import cprint | |
# Models: https://platform.openai.com/docs/models/continuous-model-upgrades
# Pricing: https://openai.com/pricing
# Dated model names that get_openai_llm resolves directly to a wrapper class.
# Note: "gpt-4-turbo-0125", "gpt-4-turbo-1106" and "gpt-4v-turbo-1106" are this module's own names;
# the wrapper classes send different ids to the API ("gpt-4-0125-preview", "gpt-4-1106-preview",
# "gpt-4-1106-vision-preview").
# The member order is visible through typing.get_args(SUPPORTED_MODELS).
SUPPORTED_MODELS = Literal[
    "gpt-4o-2024-05-13",
    "gpt-4-turbo-2024-04-09",
    "gpt-3.5-turbo-0125",
    "gpt-4-turbo-0125",
    "gpt-3.5-turbo-1106",
    "gpt-4-turbo-1106",
    "gpt-4v-turbo-1106",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k-0613",
    "gpt-4-0613",
    "gpt-4-32k-0613",
]
def get_openai_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
    """Return the OpenAI LLM wrapper that corresponds to ``model_name``.

    Resolution order:

    1. Undated aliases (e.g. ``"gpt-4"``) are rewritten to a pinned dated model, and a warning is printed.
    2. OpenAI API ids of the GPT-4 Turbo previews (e.g. ``"gpt-4-0125-preview"``) are mapped to this
       module's names for the same models, so they are not mistaken for fine-tuned models.
    3. Names in ``SUPPORTED_MODELS`` return the matching wrapper.
    4. Anything else is treated as a fine-tuned model id and matched on a base-model substring.
       Unrecognised ids fall back to ``OpenAIGPT35TFT_1106``, and a warning is printed.

    Args:
        model_name: A dated model name, an undated alias, a GPT-4 Turbo preview API id,
            or a fine-tuned model id.
        client_settings: Passed unchanged to the selected wrapper's constructor.

    Returns:
        A newly constructed wrapper instance.
    """
    # Undated aliases -> pinned dated snapshots. The warning asks callers to specify the date themselves.
    replace_map_for_nodate: dict[str, SUPPORTED_MODELS] = {
        "gpt-4o": "gpt-4o-2024-05-13",
        "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
        "gpt-4": "gpt-4-0613",
        "gpt-4-32k": "gpt-4-32k-0613",
        "gpt-4-turbo": "gpt-4-turbo-1106",
        "gpt-4v-turbo": "gpt-4v-turbo-1106",
    }
    if model_name in replace_map_for_nodate:
        cprint("WARNING: model_nameに日付を指定してください", color="yellow", background=True)
        print(f"model_name: {model_name} -> {replace_map_for_nodate[model_name]}")
        model_name = replace_map_for_nodate[model_name]

    # The API ids of the GPT-4 Turbo previews differ from this module's names. Without this mapping
    # they would reach the "gpt4" fine-tune heuristic below and come back as OpenAIGPT4FT_0613,
    # i.e. with fine-tune pricing and an 8k context window.
    api_id_map: dict[str, SUPPORTED_MODELS] = {
        "gpt-4-0125-preview": "gpt-4-turbo-0125",
        "gpt-4-1106-preview": "gpt-4-turbo-1106",
        "gpt-4-1106-vision-preview": "gpt-4v-turbo-1106",
    }
    model_name = api_id_map.get(model_name, model_name)

    # Map names to wrapper *classes* and construct only the selected one,
    # instead of instantiating every supported wrapper on each call.
    supported_model_map: dict[SUPPORTED_MODELS, type[AbstractLLM]] = {
        "gpt-4o-2024-05-13": OpenAIGPT4O_20240513,
        "gpt-4-turbo-2024-04-09": OpenAIGPT4T_20240409,
        "gpt-3.5-turbo-0125": OpenAIGPT35T_0125,
        "gpt-4-turbo-0125": OpenAIGPT4T_0125,
        "gpt-3.5-turbo-1106": OpenAIGPT35T_1106,
        "gpt-4-turbo-1106": OpenAIGPT4T_1106,
        "gpt-4v-turbo-1106": OpenAIGPT4VT_1106,
        "gpt-3.5-turbo-0613": OpenAIGPT35T_0613,
        "gpt-3.5-turbo-16k-0613": OpenAIGPT35T16k_0613,
        "gpt-4-0613": OpenAIGPT4_0613,
        "gpt-4-32k-0613": OpenAIGPT432k_0613,
    }

    # Standard (non-fine-tuned) models
    if model_name in supported_model_map:
        llm_class = supported_model_map[cast(SUPPORTED_MODELS, model_name)]
        return llm_class(client_settings)

    # Fine-tuned models: infer the base model from a substring of the id.
    if "gpt-3.5-turbo-1106" in model_name:
        return OpenAIGPT35TFT_1106(model_name, client_settings)
    if "gpt-3.5-turbo-0613" in model_name:
        return OpenAIGPT35TFT_0613(model_name, client_settings)
    if "gpt-3.5-turbo-0125" in model_name:
        return OpenAIGPT35TFT_0125(model_name, client_settings)
    if "gpt4" in model_name.replace("-", ""):  # TODO: replace with a stricter condition
        return OpenAIGPT4FT_0613(model_name, client_settings)
    cprint(
        f"WARNING: このFTモデルは何?: {model_name} -> OpenAIGPT35TFT_1106として設定", color="yellow", background=True
    )
    return OpenAIGPT35TFT_1106(model_name, client_settings)
class OpenAILLM(AbstractGPT):
    """Base wrapper for chat models called through the OpenAI Python client.

    Subclasses set ``model`` (the id sent to the API) together with their pricing and
    context-window metadata.
    """

    model: str

    @property
    def client(self) -> OpenAI:
        """Build an ``OpenAI`` client from ``self.client_settings``.

        This must be a property: ``generate`` and ``generate_stream`` access
        ``self.client.chat``, which fails if ``client`` is a plain method.
        A new client is constructed on every access.
        """
        # Presumably client_settings carries options such as api_key / timeout / max_retries;
        # it is forwarded verbatim as keyword arguments.
        return OpenAI(**self.client_settings)

    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        """Run a non-streaming chat completion.

        Args:
            messages: Conversation, converted to the platform format before sending.
            llm_settings: Extra keyword arguments forwarded to ``chat.completions.create``.

        Returns:
            The API response converted by ``_convert_to_response``.
        """
        openai_response = self.client.chat.completions.create(
            model=self.model,
            messages=self._convert_to_platform_messages(messages),
            stream=False,
            **llm_settings,
        )
        return self._convert_to_response(openai_response)

    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
        """Run a streaming chat completion.

        Args:
            messages: Conversation, converted to the platform format before sending.
            llm_settings: Extra keyword arguments forwarded to ``chat.completions.create``.

        Returns:
            The platform stream converted by ``_convert_to_streamresponse``.
        """
        platform_stream_response = self.client.chat.completions.create(
            model=self.model,
            messages=self._convert_to_platform_messages(messages),
            stream=True,
            **llm_settings,
        )
        return self._convert_to_streamresponse(platform_stream_response)
# omni 2024-05-13 -------------------------------------------------------------------------------------------- | |
class OpenAIGPT4O_20240513(OpenAILLM):
    """GPT-4o snapshot 2024-05-13 (dollar_per_ktoken: input 0.005, output 0.015; 128k context)."""

    model: str = "gpt-4o-2024-05-13"
    context_window: int = 128000
    dollar_per_ktoken = APIPricing(input=0.005, output=0.015)
# 2024-04-09 -------------------------------------------------------------------------------------------- | |
class OpenAIGPT4T_20240409(OpenAILLM):
    """GPT-4 Turbo snapshot 2024-04-09 (dollar_per_ktoken: input 0.01, output 0.03; 128k context)."""

    model: str = "gpt-4-turbo-2024-04-09"
    context_window: int = 128000
    dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
# 0125 -------------------------------------------------------------------------------------------- | |
class OpenAIGPT35T_0125(OpenAILLM):
    """gpt-3.5-turbo snapshot 0125 (dollar_per_ktoken: input 0.0005, output 0.0015; 16,385-token context)."""

    model: str = "gpt-3.5-turbo-0125"
    context_window: int = 16385
    dollar_per_ktoken = APIPricing(input=0.0005, output=0.0015)
class OpenAIGPT4T_0125(OpenAILLM):
    """GPT-4 Turbo preview 0125, sent to the API as ``gpt-4-0125-preview``.

    dollar_per_ktoken: input 0.01, output 0.03; 128k context.
    """

    model: str = "gpt-4-0125-preview"
    context_window: int = 128000
    dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
class OpenAIGPT35TFT_0125(OpenAILLM):
    """Fine-tuned gpt-3.5-turbo-0125; the fine-tuned model id is supplied at construction.

    dollar_per_ktoken: input 0.003, output 0.006; 16,385-token context.
    """

    context_window: int = 16385
    dollar_per_ktoken = APIPricing(input=0.003, output=0.006)

    def __init__(self, model_name: str, client_setting: ClientSettings) -> None:
        """Initialise the base wrapper, then record the fine-tuned model id.

        Args:
            model_name: Fine-tuned model id, used as ``model`` in API requests.
            client_setting: Settings passed to the base-class constructor.
        """
        super().__init__(client_setting)
        self.model = model_name  # per-instance id; this class defines no class-level model
# 1106 -------------------------------------------------------------------------------------------- | |
class OpenAIGPT35T_1106(OpenAILLM):
    """gpt-3.5-turbo snapshot 1106 (dollar_per_ktoken: input 0.001, output 0.002; 16,385-token context)."""

    model: str = "gpt-3.5-turbo-1106"
    context_window: int = 16385
    dollar_per_ktoken = APIPricing(input=0.001, output=0.002)
class OpenAIGPT4T_1106(OpenAILLM):
    """GPT-4 Turbo preview 1106, sent to the API as ``gpt-4-1106-preview``.

    dollar_per_ktoken: input 0.01, output 0.03; 128k context.
    """

    model: str = "gpt-4-1106-preview"
    context_window: int = 128000
    dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
class OpenAIGPT4VT_1106(OpenAILLM):
    """GPT-4 Turbo vision preview 1106, sent to the API as ``gpt-4-1106-vision-preview``.

    dollar_per_ktoken: input 0.01, output 0.03; 128k context.
    """

    model: str = "gpt-4-1106-vision-preview"
    context_window: int = 128000
    dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
class OpenAIGPT35TFT_1106(OpenAILLM):
    """Fine-tuned gpt-3.5-turbo-1106; the fine-tuned model id is supplied at construction.

    dollar_per_ktoken: input 0.003, output 0.006.

    context_window matches the base gpt-3.5-turbo-1106 (16,385 tokens), as it does for the
    0125 fine-tune. The previous 4,096 value appears to have been copied from the 0613 fine-tune.
    """

    dollar_per_ktoken = APIPricing(input=0.003, output=0.006)
    context_window: int = 16_385

    def __init__(self, model_name: str, client_setting: ClientSettings) -> None:
        """Initialise the base wrapper, then record the fine-tuned model id.

        Args:
            model_name: Fine-tuned model id, used as ``model`` in API requests.
            client_setting: Settings passed to the base-class constructor.
        """
        super().__init__(client_setting)
        self.model = model_name
# 0613 -------------------------------------------------------------------------------------------- | |
class OpenAIGPT35T_0613(OpenAILLM):
    """gpt-3.5-turbo snapshot 0613 (dollar_per_ktoken: input 0.0015, output 0.002; 4,096-token context)."""

    model: str = "gpt-3.5-turbo-0613"
    context_window: int = 4096
    dollar_per_ktoken = APIPricing(input=0.0015, output=0.002)
class OpenAIGPT35T16k_0613(OpenAILLM):
    """gpt-3.5-turbo-16k snapshot 0613 (dollar_per_ktoken: input 0.003, output 0.004; 16,385-token context)."""

    model: str = "gpt-3.5-turbo-16k-0613"
    context_window: int = 16385
    dollar_per_ktoken = APIPricing(input=0.003, output=0.004)
class OpenAIGPT4_0613(OpenAILLM):
    """GPT-4 snapshot 0613 (dollar_per_ktoken: input 0.03, output 0.06; 8,192-token context)."""

    model: str = "gpt-4-0613"
    context_window: int = 8192
    dollar_per_ktoken = APIPricing(input=0.03, output=0.06)
class OpenAIGPT432k_0613(OpenAILLM):
    """GPT-4-32k snapshot 0613 (dollar_per_ktoken: input 0.06, output 0.12; 32,768-token context)."""

    model: str = "gpt-4-32k-0613"
    context_window: int = 32768
    dollar_per_ktoken = APIPricing(input=0.06, output=0.12)
class OpenAIGPT35TFT_0613(OpenAILLM):
    """Fine-tuned gpt-3.5-turbo-0613; the fine-tuned model id is supplied at construction.

    dollar_per_ktoken: input 0.003, output 0.006; 4,096-token context.
    """

    context_window: int = 4096
    dollar_per_ktoken = APIPricing(input=0.003, output=0.006)

    def __init__(self, model_name: str, client_setting: ClientSettings) -> None:
        """Initialise the base wrapper, then record the fine-tuned model id.

        Args:
            model_name: Fine-tuned model id, used as ``model`` in API requests.
            client_setting: Settings passed to the base-class constructor.
        """
        super().__init__(client_setting)
        self.model = model_name  # per-instance id; this class defines no class-level model
class OpenAIGPT4FT_0613(OpenAILLM):
    """Fine-tuned GPT-4 0613; the fine-tuned model id is supplied at construction.

    dollar_per_ktoken: input 0.045, output 0.09; 8,192-token context.
    """

    context_window: int = 8192
    dollar_per_ktoken = APIPricing(input=0.045, output=0.09)

    def __init__(self, model_name: str, client_setting: ClientSettings) -> None:
        """Initialise the base wrapper, then record the fine-tuned model id.

        Args:
            model_name: Fine-tuned model id, used as ``model`` in API requests.
            client_setting: Settings passed to the base-class constructor.
        """
        super().__init__(client_setting)
        self.model = model_name  # per-instance id; this class defines no class-level model