daviddrzik
commited on
Update README.md
Browse files
README.md
CHANGED
@@ -63,7 +63,7 @@ class SentimentClassifier:
|
|
63 |
self.model = RobertaForSequenceClassification.from_pretrained(model, num_labels=3)
|
64 |
self.tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer, max_length=256)
|
65 |
|
66 |
-
def
|
67 |
encoded_text = self.tokenizer.encode_plus(
|
68 |
text.lower(),
|
69 |
max_length=256,
|
@@ -73,7 +73,7 @@ class SentimentClassifier:
|
|
73 |
)
|
74 |
return encoded_text
|
75 |
|
76 |
-
def
|
77 |
with torch.no_grad():
|
78 |
output = self.model(**encoded_text)
|
79 |
logits = output.logits
|
@@ -91,10 +91,10 @@ text_to_classify = "Kábel dodaný k SSD je krátky a veľmi zle sa ohýba, ten
|
|
91 |
print("Text to classify: " + text_to_classify + "\n")
|
92 |
|
93 |
# Tokenize the input text
|
94 |
-
|
95 |
|
96 |
# Classify the sentiment of the tokenized text
|
97 |
-
predicted_class, predicted_class_text, logits = classifier.
|
98 |
|
99 |
# Print the predicted class label and index
|
100 |
print(f"Predicted class: {predicted_class_text} ({predicted_class})")
|
|
|
63 |
self.model = RobertaForSequenceClassification.from_pretrained(model, num_labels=3)
|
64 |
self.tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer, max_length=256)
|
65 |
|
66 |
+
def tokenize_text(self, text):
|
67 |
encoded_text = self.tokenizer.encode_plus(
|
68 |
text.lower(),
|
69 |
max_length=256,
|
|
|
73 |
)
|
74 |
return encoded_text
|
75 |
|
76 |
+
def classify_text(self, encoded_text):
|
77 |
with torch.no_grad():
|
78 |
output = self.model(**encoded_text)
|
79 |
logits = output.logits
|
|
|
91 |
print("Text to classify: " + text_to_classify + "\n")
|
92 |
|
93 |
# Tokenize the input text
|
94 |
+
encoded_text = classifier.tokenize_text(text_to_classify)
|
95 |
|
96 |
# Classify the sentiment of the tokenized text
|
97 |
+
predicted_class, predicted_class_text, logits = classifier.classify_text(encoded_text)
|
98 |
|
99 |
# Print the predicted class label and index
|
100 |
print(f"Predicted class: {predicted_class_text} ({predicted_class})")
|