import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
# Cache of already-loaded (model, tokenizer) pairs, keyed by Hub model id,
# so repeated calls do not reload the weights from disk.
loaded_hf_models = {}


class StopAtSpecificTokenCriteria(StoppingCriteria):
    """Stop generation once the generated ids end with a given stop token-id sequence."""

    def __init__(self, stop_sequence):
        super().__init__()
        # stop_sequence is a list of token ids (not a string).
        self.stop_sequence = stop_sequence

    def __call__(self, input_ids, scores, **kwargs):
        # Create a tensor from the stop_sequence on the same device/dtype as input_ids
        stop_sequence_tensor = torch.tensor(
            self.stop_sequence,
            device=input_ids.device,
            dtype=input_ids.dtype,
        )
        # Check if the current sequence ends with the stop_sequence
        current_sequence = input_ids[:, -len(self.stop_sequence):]
        return bool(torch.all(current_sequence == stop_sequence_tensor).item())
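
# Usage note (illustration, not part of the original file): the criterion compares
# token ids, not text, so a textual stop string must be tokenized first, e.g.
#     stop_ids = tokenizer.encode("END", add_special_tokens=False)
#     criteria = StoppingCriteriaList([StopAtSpecificTokenCriteria(stop_ids)])
# The resulting `criteria` can then be forwarded to `generate()` through the
# **kwargs of complete_text_hf below as `stopping_criteria=criteria`.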


def complete_text_hf(
    message,
    model="huggingface/codellama/CodeLlama-7b-hf",
    max_tokens=2000,
    temperature=0.5,
    json_object=False,
    max_retry=1,
    sleep_time=0,
    stop_sequences=[],
    **kwargs,
):
    if json_object:
        message = "You are a helpful assistant designed to output in JSON format." + message
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Drop the "huggingface/" prefix to obtain the Hub model id.
    model = model.split("/", 1)[1]
    if model in loaded_hf_models:
        hf_model, tokenizer = loaded_hf_models[model]
    else:
        hf_model = AutoModelForCausalLM.from_pretrained(model).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model)
        loaded_hf_models[model] = (hf_model, tokenizer)
    encoded_input = tokenizer(
        message,
        return_tensors="pt",
        return_token_type_ids=False,
    ).to(device)

    for cnt in range(max_retry):
        try:
            output = hf_model.generate(
                **encoded_input,
                temperature=temperature,
                max_new_tokens=max_tokens,
                do_sample=True,
                return_dict_in_generate=True,
                output_scores=True,
                **kwargs,
            )
            sequences = output.sequences
            # Strip the prompt tokens so only the newly generated text is decoded.
            sequences = [sequence[len(encoded_input.input_ids[0]):] for sequence in sequences]
            all_decoded_text = tokenizer.batch_decode(sequences)
            completion = all_decoded_text[0]
            return completion
        except Exception as e:
            print(cnt, "=>", e)
            last_error = e
            time.sleep(sleep_time)
    # Re-raise the last error once all retries are exhausted.
    raise last_error
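

if __name__ == "__main__":
    # Minimal usage sketch (illustration, not part of the original file): the
    # prompt and sampling settings are placeholders, and running this will
    # download the CodeLlama-7b-hf weights from the Hugging Face Hub.
    example_prompt = "def fibonacci(n):"
    print(complete_text_hf(example_prompt, max_tokens=64, temperature=0.2))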