Spaces:
Sleeping
Sleeping
import torch | |
class TextGen:
    """Thin wrapper that generates a text continuation for a prompt.

    Holds a tokenizer and a model exposing ``generate`` and moves the model
    to the requested device once at construction time.
    """

    def __init__(self, tokenizer, model, device):
        """Store the collaborators and move ``model`` onto ``device``.

        Args:
            tokenizer: Object providing ``encode``, ``decode`` and
                ``eos_token_id``.
            model: Object providing ``to(device)`` and ``generate(...)``.
            device: Device that the model and the encoded inputs are moved to.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.model.to(self.device)

    def generate_text(
        self,
        user_input,
        *,
        max_new_tokens=100,
        temperature=0.5,
        repetition_penalty=1.2,
        do_sample=True,
    ):
        """Generate text for ``user_input`` and return the decoded string.

        Args:
            user_input: Prompt text to encode and pass to the model.
            max_new_tokens: Number of tokens to generate beyond the prompt.
            temperature: Sampling temperature; only forwarded when
                ``do_sample`` is true.
            repetition_penalty: Penalty forwarded to ``generate``.
            do_sample: Whether to sample. Pass ``False`` for deterministic
                decoding, in which case ``temperature`` is not sent.

        Returns:
            The first returned sequence decoded with special tokens skipped.
        """
        inputs = self.tokenizer.encode(user_input, return_tensors="pt").to(
            self.device
        )
        # ones_like keeps the integer dtype and device of the token ids,
        # instead of the float mask that torch.ones(shape) would create.
        attention_mask = torch.ones_like(inputs)

        generation_kwargs = {
            "attention_mask": attention_mask,
            "num_return_sequences": 1,
            # Only max_new_tokens is given: also passing max_length is
            # contradictory (warned about, or rejected, by the library).
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            # Temperature has no effect unless sampling is enabled.
            generation_kwargs["temperature"] = temperature

        output = self.model.generate(inputs, **generation_kwargs)
        return self.tokenizer.decode(output[0], skip_special_tokens=True)