# Author: Ricardo Lisboa Santos
# Creation date: 2024-01-10

import torch

# import torch_directml
from transformers import pipeline

# NOTE: DirectML support (torch_directml) is sketched in commented-out code but
# currently disabled, so "directml" is not a recognized DEVICE value below.


def getDevice(DEVICE):
    """Map a device name to a ``torch.device``.

    Args:
        DEVICE: Device name, ``"cpu"`` or ``"cuda"``.

    Returns:
        The matching ``torch.device``, or ``None`` for any other name
        (including ``"directml"``, whose backend is disabled). No check is
        made here that CUDA is actually available.
    """
    if DEVICE == "cpu":
        return torch.device("cpu")
    if DEVICE == "cuda":
        return torch.device("cuda")
    # elif DEVICE == "directml":
    #     return torch_directml.device()
    return None


def getDtype(DEVICE):
    """Return the floating-point dtype intended for *DEVICE*.

    ``torch.float32`` for ``"cpu"``, ``torch.float16`` for ``"cuda"``, and
    ``None`` for any other name. Not applied automatically by
    ``loadSummarizer``; callers opt in explicitly.
    """
    if DEVICE == "cpu":
        return torch.float32
    if DEVICE == "cuda":
        return torch.float16
    return None


def loadSummarizer(device):
    """Create a ``transformers`` summarization pipeline placed on *device*.

    Args:
        device: A ``torch.device`` (e.g. from ``getDevice``) or ``None``.
            When ``None``, no ``device`` argument is passed, so the pipeline
            uses transformers' default placement (as before).

    Returns:
        The summarization pipeline object.
    """
    # The device used to be ignored (``.to(device)`` was commented out, and
    # pipelines have no ``.to``); the pipeline factory takes it directly.
    if device is None:
        return pipeline("summarization")
    return pipeline("summarization", device=device)


def summarize(summarizer, text):
    """Run *summarizer* on *text* and return the raw pipeline output unchanged."""
    return summarizer(text)


def clearCache(DEVICE, summarizer, cache_dir="cache"):
    """Save the pipeline's tokenizer and model, then release cached GPU memory.

    Args:
        DEVICE: Device name the model was loaded for (``"cpu"`` or ``"cuda"``).
        summarizer: Pipeline returned by ``loadSummarizer``.
        cache_dir: Directory the tokenizer and model are saved into; defaults
            to ``"cache"`` relative to the current working directory.
    """
    summarizer.tokenizer.save_pretrained(cache_dir)
    summarizer.model.save_pretrained(cache_dir)
    # ``del`` only unbinds this local name; the model's memory is freed only
    # once the caller drops its own reference as well.
    del summarizer
    if DEVICE == "cuda" and torch.cuda.is_available():
        # Return unused cached blocks held by PyTorch's CUDA allocator.
        torch.cuda.empty_cache()
    # if DEVICE == "directml":
    #     torch_directml.empty_cache()