Spaces:
Runtime error
Runtime error
File size: 902 Bytes
91e858d b760fd0 91e858d b760fd0 91e858d b760fd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
# Author: Ricardo Lisboa Santos
# Creation date: 2024-01-10
import torch
# import torch_directml
from transformers import pipeline
def getDevice(DEVICE):
    """Map a device name to a ``torch.device``.

    Args:
        DEVICE: Device name. ``"cpu"`` and ``"cuda"`` are recognised.

    Returns:
        The matching ``torch.device``, or ``None`` when ``DEVICE`` is not a
        recognised name. A ``"directml"`` option is currently disabled
        because the ``torch_directml`` import is commented out.
    """
    # torch.device accepts the name string directly, so one call covers both
    # supported backends. Constructing a "cuda" device does not itself require
    # a GPU; failures surface later when tensors are moved.
    if DEVICE in ("cpu", "cuda"):
        return torch.device(DEVICE)
    return None
def loadClassifier(device):
    """Build a Hugging Face ``"sentiment-analysis"`` pipeline on ``device``.

    Args:
        device: Where the pipeline should run, e.g. the value returned by
            ``getDevice``. ``None`` is the same as not specifying a device,
            so the pipeline keeps its default placement.

    Returns:
        The pipeline object created by ``transformers.pipeline``.
    """
    # Device placement is requested through the pipeline factory's `device`
    # argument, so the caller's choice is actually honoured.
    classifier = pipeline("sentiment-analysis", device=device)
    return classifier
def classify(classifier, text):
    """Apply ``classifier`` to ``text`` and hand back whatever it returns."""
    return classifier(text)
def clearCache(DEVICE, classifier, cache_dir="cache"):
    """Save the classifier's tokenizer and model, then flush cached device memory.

    Args:
        DEVICE: Device name the classifier was set up for (``"cpu"``,
            ``"cuda"``, ...). Only ``"cuda"`` triggers an allocator flush.
        classifier: Pipeline-like object exposing ``tokenizer`` and ``model``
            attributes, each with a ``save_pretrained(path)`` method.
        cache_dir: Directory both components are saved to. Defaults to
            ``"cache"``.

    Note:
        This function cannot free the classifier itself. Deleting a parameter
        only unbinds the local name, so callers must drop their own reference
        before the model's memory can be reclaimed.
    """
    classifier.tokenizer.save_pretrained(cache_dir)
    classifier.model.save_pretrained(cache_dir)
    if DEVICE == "cuda":
        # Releases unoccupied blocks held by PyTorch's CUDA caching allocator.
        # This is safe to call when CUDA has not been initialised.
        torch.cuda.empty_cache()
    # A "directml" equivalent would call torch_directml.empty_cache(), but the
    # torch_directml import is currently disabled.
|