import abc
import asyncio
import math
from abc import abstractmethod

import backoff
import openai
import tiktoken
class LLM(abc.ABC):
    prompt_percent = 0.9

    @abstractmethod
    def __init__(self):
        raise NotImplementedError("Subclasses should implement this!")

    @abstractmethod
    def infer(self, prompts):
        raise NotImplementedError("Subclasses should implement this!")

    @abstractmethod
    def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
        raise NotImplementedError("Subclasses should implement this!")
class GPT(LLM):
    prompt_percent = 0.8

    openai_cxn_dict = {
        'default': {
            'endpoint': "INSERT YOUR AZURE OPENAI ENDPOINT HERE",
            'api_key': "INSERT YOUR AZURE OPENAI API KEY HERE",
        },
    }

    # Maximum context length (in tokens) per Azure deployment name.
    deployment_max_length_dict = {
        'gpt-4': 8192,
        'gpt-4-0314': 8192,
        'gpt-4-32k': 32768,
        'gpt-35-turbo': 4096,
        'gpt-35-turbo-16k': 16385,
    }
    def __init__(self, model_id):
        self.temperature = 0.0
        self.top_k = 1  # Not passed to the API call below; kept for reference.
        # Azure deployment names such as 'gpt-35-turbo-16k' must be mapped to
        # tiktoken model names: keep the first two dash-separated parts and
        # restore the decimal point ('gpt-35' -> 'gpt-3.5'; 'gpt-4-0314' -> 'gpt-4').
        self.encoding = tiktoken.encoding_for_model("-".join(model_id.split("-", 2)[:2]).replace('5', '.5'))
        self.openai_api = 'default'
        self.model_id = model_id
        self.max_length = self.deployment_max_length_dict[model_id]
        self.client = openai.AsyncAzureOpenAI(
            api_key=self.openai_cxn_dict[self.openai_api]['api_key'],
            api_version="2023-12-01-preview",
            azure_endpoint=self.openai_cxn_dict[self.openai_api]['endpoint'],
        )
    def gen_messages(self, fixed_instruction, few_shot_examples, input, input_header, output_header):
        # The system prompt carries the fixed instruction.
        messages = [
            {
                "role": "system",
                "content": fixed_instruction,
            },
        ]
        # Each few-shot example becomes a user/assistant turn pair.
        for example in few_shot_examples:
            messages.extend(
                [
                    {
                        "role": "user",
                        "content": input_header + '\n' + example['user'] + '\n\n' + output_header,
                    },
                    {
                        "role": "assistant",
                        "content": example['assistant'],
                    },
                ]
            )
        # The actual input goes last, framed by the same headers.
        messages.extend(
            [
                {
                    "role": "user",
                    "content": input_header + '\n' + input + '\n\n' + output_header,
                },
            ]
        )
        return messages
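
    # For illustration only (not from the source), gen_messages returns a
    # chat-format list shaped like:
    #   [{"role": "system",    "content": fixed_instruction},
    #    {"role": "user",      "content": "<input_header>\n<example user>\n\n<output_header>"},
    #    {"role": "assistant", "content": "<example assistant>"},
    #    ...one such pair per few-shot example...
    #    {"role": "user",      "content": "<input_header>\n<input>\n\n<output_header>"}]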
    # Coroutine that makes a single chat-completion API call. Retrying with
    # exponential backoff on rate-limit errors is presumably the intended use
    # of the imported (but previously unused) backoff module.
    @backoff.on_exception(backoff.expo, openai.RateLimitError)
    async def make_api_call_to_gpt(
        self,
        messages,
    ):
        response = await self.client.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=self.temperature,
        )
        return response.choices[0].message.content
    async def dispatch_openai_requests(
        self,
        messages_list,
    ):
        # Create one task per prompt and run them concurrently.
        tasks = [self.make_api_call_to_gpt(messages) for messages in messages_list]
        results = await asyncio.gather(*tasks)
        return results
    def infer(self, messages_list):
        return asyncio.run(self.dispatch_openai_requests(messages_list))
    def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
        # Tokenize the fixed instruction plus all few-shot examples.
        fixed_token_ids = self.encoding.encode(fixed_instruction + ' '.join([x['user'] + ' ' + x['assistant'] for x in few_shot_examples]))
        # Token budget for the splittable input: prompt_percent of the context
        # window, minus the tokens already spent on the fixed portion.
        remaining_token_len = math.ceil((self.prompt_percent * self.max_length) - len(fixed_token_ids))
        # Tokenize the splittable input.
        split_token_ids = self.encoding.encode(splittable_input)
        # Chunk the token ids to fit the remaining budget; consecutive chunks
        # overlap by 10 tokens so content is not cut flush at a boundary.
        split_token_ids_list = [split_token_ids[i:i + remaining_token_len + 10] for i in range(0, len(split_token_ids), remaining_token_len)]
        split_input_list = [self.encoding.decode(token_ids) for token_ids in split_token_ids_list]
        # Build one message list per chunk from the fixed instruction,
        # few-shot examples, and input/output headers.
        return [self.gen_messages(fixed_instruction, few_shot_examples, split_input, input_header, output_header) for split_input in split_input_list]
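

# A minimal usage sketch, not part of the source: it assumes the placeholder
# endpoint/key above have been filled in and that a 'gpt-35-turbo-16k'
# deployment exists. The instruction, headers, and document text below are
# illustrative values only.
if __name__ == "__main__":
    model = GPT('gpt-35-turbo-16k')
    message_lists = model.split_input(
        fixed_instruction="Summarize the following passage.",
        few_shot_examples=[{'user': 'Example passage...', 'assistant': 'Example summary...'}],
        splittable_input="A long document that may exceed the context window...",
        input_header="PASSAGE:",
        output_header="SUMMARY:",
    )
    # infer() runs one chat completion per chunk, concurrently.
    for output in model.infer(message_lists):
        print(output)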