Spaces:
Paused
Paused
import torch | |
from torch import nn | |
from sentence_transformers import SentenceTransformer | |
from datasets import load_dataset | |
from sklearn.utils.class_weight import compute_class_weight | |
from safetensors.torch import load_model | |
from setfit.__init__ import SetFitModel | |
# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
class MLP(nn.Module):
    """Linear classification head applied to sentence embeddings.

    Despite the name there is no hidden layer: the head is a single
    ``nn.Linear`` followed by dropout on its output. It is intended to be
    plugged into a ``SetFitModel`` as ``model_head``.

    Args:
        input_size: Dimension of the incoming embeddings.
        output_size: Number of target classes.
        dropout_rate: Dropout probability applied to the logits in ``forward``
            (only active in training mode).
        class_weights: Optional per-class weight tensor forwarded to
            ``nn.CrossEntropyLoss`` by ``get_loss_fn``.
    """

    def __init__(self, input_size=768, output_size=3, dropout_rate=.2, class_weights=None):
        super().__init__()
        self.class_weights = class_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return raw logits for ``x``; dropout is applied to the logits."""
        return self.dropout(self.linear(x))

    @torch.no_grad()
    def _eval_logits(self, x):
        """Compute logits with dropout disabled and autograd off, restoring the previous mode."""
        was_training = self.training
        self.eval()
        try:
            return self.forward(x)
        finally:
            # Don't silently leave a module that was training in eval mode.
            self.train(was_training)

    def predict(self, x):
        """Return the index of the highest-scoring class along the last dim of ``x``.

        Deterministic: dropout is disabled for the call. Works for a batch
        (2-D input) as well as a single embedding (1-D input).
        """
        return self._eval_logits(x).argmax(dim=-1)

    def predict_proba(self, x):
        """Return class probabilities (softmax over the last dim) for ``x``.

        Dropout is disabled for the call and the result does not require grad.
        """
        return torch.softmax(self._eval_logits(x), dim=-1)

    def get_loss_fn(self):
        """Return a mean-reduced ``CrossEntropyLoss`` using ``class_weights`` if set."""
        weight = self.class_weights
        if weight is not None:
            # The loss weight must live on the same device as the logits.
            weight = weight.to(self.linear.weight.device)
        return nn.CrossEntropyLoss(weight=weight, reduction='mean')
# The dataset is loaded to derive the label names and the class weights.
dataset = load_dataset("CabraVC/vector_dataset_roberta-fine-tuned")
label_names = dataset['train'].features['labels'].names
# 'balanced' weights scale inversely with class frequency; the square root
# softens them so minority classes are up-weighted less aggressively.
class_weights = torch.tensor(
    compute_class_weight(
        'balanced',
        classes=list(range(len(label_names))),
        y=dataset['train']['labels'],
    ),
    dtype=torch.float,
) ** .5
model_head = MLP(output_size=len(label_names), class_weights=class_weights)

# Paths are relative to the current working directory and depend on the name
# this module was run/imported under.
if __name__ in ('__main__', 'create_setfit_model'):
    model_body = SentenceTransformer('financial-roberta')
    load_model(model_head, 'models/linear_head.pth')
elif __name__ == 'test_models.create_setfit_model':
    model_body = SentenceTransformer('test_models/financial-roberta')
    # Relative, like the body path above (an absolute '/test_models/...' path
    # would point at the filesystem root).
    load_model(model_head, 'test_models/models/linear_head.pth')
else:
    raise ImportError(
        f'create_setfit_model: no model locations configured for module name {__name__!r}'
    )

# The head is only used for inference here, so make sure dropout is off.
model_head.eval()
model = SetFitModel(
    model_body=model_body,
    model_head=model_head,
    labels=label_names,
).to(DEVICE)
if __name__ == '__main__':
    # Smoke test: run the SetFit model on three sample paragraphs and report
    # how long inference took.
    from time import perf_counter

    start = perf_counter()
    test_sentences = [
        """Two thousand and six was a very good year for The Coca-Cola Company. We achieved our 52nd
consecutive year of unit case volume growth. Volume reached a record high of 2.4 billion unit cases.
Net operating revenues grew 4 percent to $24.billion, and operating income grew
4 percent to $6.3 billion. Our total return to shareowners was 23 percent, outperforming the Dow
Jones Industrial Average and the S&P 500. By virtually every measure, we met or exceeded our
objectives—a strong ending for the year with great momentum for entering 2007.""",
        """
The secret formula to our success in 2006? There is no one answer. Our inspiration comes from
many sources—our bottling partners, retail customers and consumers, as well as our critics. And the
men and women of The Coca-Cola Company have a passion for what they do that ignites this
inspiration every day, everywhere we do business. We remain fresh, relevant and original by knowing
what
to change without changing what we know. We are asking more questions, listening more closely and
collaborating more effectively with our bottling partners, suppliers and retail customers to give
consumers what they want.
""",
        """
And we continue to strengthen our bench, nurturing leaders and promoting from within our
organization. As 2006 came to a close, our Board of Directors elected Muhtar Kent as president and
chief operating officer of our Company. Muhtar is a 28-year veteran of the Coca-Cola system (the
Company and our bottling partners). Muhtar’s close working relationships with our bottling partners
will enable us to continue capturing marketplace opportunities and improving our business. Other
system veterans promoted and now leading operating groups include Ahmet Bozer, Eurasia; Sandy
Douglas, North America; and Glenn Jordan, Pacific. Combined, these leaders have 65 years of Coca-
Cola system experience.
""",
    ]
    print(model(test_sentences))
    # Read the clock once so minutes and seconds describe the same instant, and
    # round before splitting so the seconds can never display as 60.
    minutes, seconds = divmod(round(perf_counter() - start), 60)
    print(f'It took me: {minutes} mins {seconds} secs')