# basicTransformersExample / AI / summarization.py
# Page residue from the hosting site: uploaded by ricardo-lsantos,
# commit b760fd0 ("Commented torch_directml"), 898 bytes.
# Author: Ricardo Lisboa Santos
# Creation date: 2024-01-10
import torch
# import torch_directml
from transformers import pipeline
def getDevice(DEVICE, with_dtype=False):
    """Map a device name to a ``torch.device``.

    Args:
        DEVICE: Device name. ``"cpu"`` and ``"cuda"`` are recognized; any
            other value (including ``"directml"``, whose branch is currently
            disabled) yields ``None``.
        with_dtype: When True, also return the floating-point dtype paired
            with the device here (``torch.float32`` for CPU,
            ``torch.float16`` for CUDA, ``None`` otherwise). Defaults to
            False so existing callers keep receiving a bare device.

    Returns:
        The ``torch.device`` (or ``None`` for an unrecognized name), or a
        ``(device, dtype)`` tuple when ``with_dtype`` is True.
    """
    device = None
    # Initialized so it is never unbound for unrecognized names.
    dtype = None
    if DEVICE == "cpu":
        device = torch.device("cpu")
        dtype = torch.float32
    elif DEVICE == "cuda":
        device = torch.device("cuda")
        # Half precision on GPU (presumably for memory/speed).
        dtype = torch.float16
    # DirectML support is disabled; re-enable together with the
    # `torch_directml` import at the top of the file:
    # elif DEVICE == "directml":
    #     device = torch_directml.device()
    #     dtype = torch.float16
    if with_dtype:
        return device, dtype
    return device
def loadSummarizer(device, model=None):
    """Build a Hugging Face ``summarization`` pipeline placed on ``device``.

    Args:
        device: Target device, e.g. the value returned by ``getDevice``.
            ``None`` is the same as not passing a device to ``pipeline``.
        model: Optional model id or local path. ``None`` (the default) keeps
            the library's default summarization checkpoint, as before.

    Returns:
        The ``transformers`` pipeline object.
    """
    # `device` used to be ignored (the `.to(device)` call was commented out);
    # `pipeline` accepts the placement directly through its `device` argument.
    summarizer = pipeline("summarization", model=model, device=device)
    return summarizer
def summarize(summarizer, text, **generate_kwargs):
    """Run ``summarizer`` on ``text`` and return its raw output unchanged.

    Args:
        summarizer: A callable, such as the pipeline from ``loadSummarizer``.
        text: Input passed straight through to the summarizer.
        **generate_kwargs: Optional keyword arguments forwarded to the
            summarizer call (for a pipeline, e.g. ``max_length`` or
            ``min_length``). Omitting them reproduces the plain
            ``summarizer(text)`` call.

    Returns:
        Whatever ``summarizer`` returns.
    """
    output = summarizer(text, **generate_kwargs)
    return output
def clearCache(DEVICE, summarizer, cache_dir="cache"):
    """Save the summarizer's tokenizer and model to disk, then release GPU cache.

    Args:
        DEVICE: Device name as passed to ``getDevice``. On ``"cuda"`` the
            PyTorch CUDA caching allocator is asked to release unused blocks.
        summarizer: Object exposing ``tokenizer`` and ``model`` attributes,
            each with a ``save_pretrained(path)`` method.
        cache_dir: Directory both are saved into. Defaults to ``"cache"``.

    Note:
        Deleting the local ``summarizer`` name inside this function would
        free nothing. The caller must drop its own reference for the model's
        memory to become reclaimable.
    """
    summarizer.tokenizer.save_pretrained(cache_dir)
    summarizer.model.save_pretrained(cache_dir)
    if DEVICE == "cuda":
        # Releases only unused cached blocks; no-op if CUDA was never initialized.
        torch.cuda.empty_cache()
    # DirectML equivalent, disabled together with the torch_directml import:
    # elif DEVICE == "directml":
    #     torch_directml.empty_cache()