# (HuggingFace Spaces page header captured by extraction: "Spaces: Running on CPU Upgrade")
import os
import os.path as osp
import warnings
import multiprocessing
from functools import partial
from src.tools.api_lib.claude import complete_text_claude
from src.tools.api_lib.gpt import get_gpt_output
from src.tools.api_lib.openai_emb import get_openai_embedding, get_openai_embeddings
from src.tools.api_lib.huggingface import complete_text_hf
# setup parameters for retrying API calls and the sleep time between retries
MAX_OPENAI_RETRY, OPENAI_SLEEP_TIME = 5, 60
MAX_CLAUDE_RETRY, CLAUDE_SLEEP_TIME = 10, 0
# Models known to work with get_llm_output. Using an unregistered model only
# triggers a warning there, not an error.
registered_text_completion_llms = {
    "gpt-4-1106-preview",
    "gpt-4-0125-preview",
    "gpt-4-turbo-preview",
    "gpt-4-turbo",
    # Fix: a missing trailing comma after this entry used to implicitly
    # concatenate it with the following "gpt-4-turbo" string, registering the
    # bogus name "gpt-4-turbo-2024-04-09gpt-4-turbo" instead of this one.
    "gpt-4-turbo-2024-04-09",
    "claude-2.1",
    "claude-3-opus-20240229",
    "claude-3-sonnet-20240229",
    "claude-3-haiku-20240307",
    "huggingface/codellama/CodeLlama-7b-hf",
    "text-embedding-3-small",
    "text-embedding-3-large",
    "text-embedding-ada-002",
}
def parallel_func(func, n_max_nodes=5):
    '''
    Wrap ``func`` so it can be applied to a list of inputs in parallel.

    Args:
        func: function applied to each element of the input list. Must be
            picklable (defined at module top level) for multiprocessing.
        n_max_nodes: maximum number of worker processes to spawn.

    Returns:
        A function ``(inputs: list, **kwargs) -> list`` returning
        ``[func(x, **kwargs) for x in inputs]`` in input order.
    '''
    def _parallel_func(inputs: list, **kwargs):
        # Fix: an empty input list previously produced
        # multiprocessing.Pool(processes=0), which raises ValueError.
        if not inputs:
            return []
        partial_func = partial(func, **kwargs)
        # Never spawn more workers than there are inputs.
        processes = min(len(inputs), n_max_nodes)
        with multiprocessing.Pool(processes=processes) as pool:
            results = pool.map(partial_func, inputs)
        return results
    return _parallel_func
def get_llm_output(message,
                   model="gpt-4-0125-preview",
                   max_tokens=2048,
                   temperature=1,
                   json_object=False
                   ):
    '''
    Complete a prompt with the specified model, dispatching on the model
    name to the matching backend (OpenAI GPT, Anthropic Claude, or
    HuggingFace).

    Raises:
        ValueError: if the model name matches none of the known backends.
    '''
    if model not in registered_text_completion_llms:
        warnings.warn(f"Model {model} is not registered. You may still be able to use it.")
    call_args = dict(
        message=message,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        json_object=json_object,
    )
    # Backend-specific retry budgets are attached at the dispatch site.
    if 'gpt-4' in model:
        return get_gpt_output(**call_args,
                              max_retry=MAX_OPENAI_RETRY,
                              sleep_time=OPENAI_SLEEP_TIME)
    if 'claude' in model:
        return complete_text_claude(**call_args,
                                    max_retry=MAX_CLAUDE_RETRY,
                                    sleep_time=CLAUDE_SLEEP_TIME)
    if 'huggingface' in model:
        return complete_text_hf(**call_args)
    raise ValueError(f"Model {model} not recognized.")
# Parallel (multi-process) variants of the single-call completion helpers:
# each takes a list of inputs plus the same keyword arguments as its
# single-call counterpart, and returns the list of results in input order.
complete_texts_claude = parallel_func(complete_text_claude)
complete_texts_hf = parallel_func(complete_text_hf)
get_gpt_outputs = parallel_func(get_gpt_output)
get_llm_outputs = parallel_func(get_llm_output)