import torch


class TextGen:
    """Wrap a tokenizer and a language model to turn a text prompt into generated text.

    NOTE(review): ``tokenizer`` and ``model`` are assumed to follow the Hugging Face
    ``transformers`` interfaces used below (``encode``/``decode``/``eos_token_id`` and
    ``to``/``generate``); nothing here checks that.
    """

    def __init__(self, tokenizer, model, device):
        """Store the components and move ``model`` onto ``device``.

        Args:
            tokenizer: Object providing ``encode``, ``decode`` and ``eos_token_id``.
            model: Object providing ``to`` and ``generate``.
            device: Device (or device string) for the model and the encoded inputs.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.model.to(self.device)

    def generate_text(
        self,
        user_input,
        *,
        max_new_tokens=100,
        temperature=0.5,
        repetition_penalty=1.2,
        do_sample=None,
    ):
        """Generate a continuation of ``user_input`` and return it as decoded text.

        Args:
            user_input: Prompt text to encode and feed to the model.
            max_new_tokens: Upper bound on tokens generated beyond the prompt.
            temperature: Sampling temperature; it only has an effect when sampling is
                enabled (``do_sample=True`` here, or in the model's generation config).
            repetition_penalty: Penalty forwarded to ``model.generate``.
            do_sample: When ``None`` (default) the argument is not passed, so the
                model's own generation config decides; otherwise it is forwarded as-is.

        Returns:
            The first generated sequence (prompt included, as returned by
            ``generate``), decoded with special tokens skipped.
        """
        input_ids = self.tokenizer.encode(user_input, return_tensors="pt").to(self.device)
        # A single, unpadded prompt: every position is a real token. ones_like keeps
        # the integer dtype of input_ids (torch.ones(shape) would give float32).
        attention_mask = torch.ones_like(input_ids)

        gen_kwargs = {
            "attention_mask": attention_mask,
            "num_return_sequences": 1,
            # Only max_new_tokens is passed: supplying max_length as well conflicts
            # with it (transformers warns and lets max_new_tokens win).
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            # Reuse EOS as the pad id, as the original code did.
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample is not None:
            gen_kwargs["do_sample"] = do_sample

        output = self.model.generate(input_ids, **gen_kwargs)
        return self.tokenizer.decode(output[0], skip_special_tokens=True)