---
library_name: transformers
tags: []
---

# T5 for Search Query Generation

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer

# NOTE: `YourDataSetClass` is not defined in this snippet. It must be provided by
# the caller. It is constructed as
# YourDataSetClass(dataframe, tokenizer, source_len, target_len, source_col, target_col),
# and each batch it yields must contain 'source_ids' and 'source_mask' tensors.


class T5ForSQG:
    """Generate search queries for a topic with a fine-tuned T5 model."""

    def __init__(self, model_path):
        """Load the T5 model and tokenizer from ``model_path``."""
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)

    def make_queries(self, topic, n=1, device='cpu', batch_size=16):
        """Sample up to ``n`` distinct search queries for ``topic``.

        Args:
            topic: Free-text topic. The model input is ``'make queries: ' + topic``.
            n: Number of generations to sample. Duplicates are dropped, so fewer
                than ``n`` queries may be returned. ``n < 1`` returns ``[]``.
            device: Torch device to generate on (e.g. ``'cpu'`` or ``'cuda'``).
                The model is moved to this device.
            batch_size: Maximum number of prompts per generation batch.

        Returns:
            A list of unique generated query strings, in no particular order.
        """
        if n < 1:
            return []

        frame = pd.DataFrame(
            {
                'topic': ['make queries: ' + topic] * n,
                # One (empty) target per row. The previous `[[]*n]` evaluated
                # to `[[]]`, a single row, which breaks the frame for n > 1.
                'queries': [[] for _ in range(n)],
            },
            index=range(n),
        )
        ds = YourDataSetClass(frame, self.tokenizer, 64, 64, 'topic', 'queries')
        loader = DataLoader(ds, batch_size=min(n, batch_size), shuffle=False, num_workers=0)

        # Inputs are sent to `device` below, so the model has to live there too.
        self.model.to(device)
        self.model.eval()

        predictions = set()
        with torch.no_grad():
            for data in loader:
                ids = data['source_ids'].to(device, dtype=torch.long)
                mask = data['source_mask'].to(device, dtype=torch.long)
                generated_ids = self.model.generate(
                    input_ids=ids,
                    attention_mask=mask,
                    max_length=64,
                    num_beams=1,
                    repetition_penalty=2.5,
                    length_penalty=1.0,
                    do_sample=True,
                    temperature=1.5,
                    top_k=10,
                    top_p=0.95,
                )
                predictions.update(
                    self.tokenizer.batch_decode(
                        generated_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=True,
                    )
                )
        return list(predictions)


# Example usage:
# sqg = T5ForSQG('path/to/model')
# queries = sqg.make_queries('climate change', n=8)
```