daviddrzik commited on
Commit
dc0cf12
·
verified ·
1 Parent(s): 62906d4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -4
README.md CHANGED
@@ -63,7 +63,7 @@ class SentimentClassifier:
63
  self.model = RobertaForSequenceClassification.from_pretrained(model, num_labels=3)
64
  self.tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer, max_length=256)
65
 
66
- def tokenize_text_bpe(self, text):
67
  encoded_text = self.tokenizer.encode_plus(
68
  text.lower(),
69
  max_length=256,
@@ -73,7 +73,7 @@ class SentimentClassifier:
73
  )
74
  return encoded_text
75
 
76
- def classify_text_with_bpe(self, encoded_text):
77
  with torch.no_grad():
78
  output = self.model(**encoded_text)
79
  logits = output.logits
@@ -91,10 +91,10 @@ text_to_classify = "Kábel dodaný k SSD je krátky a veľmi zle sa ohýba, ten
91
  print("Text to classify: " + text_to_classify + "\n")
92
 
93
  # Tokenize the input text
94
- encoded_text_bpe = classifier.tokenize_text_bpe(text_to_classify)
95
 
96
  # Classify the sentiment of the tokenized text
97
- predicted_class, predicted_class_text, logits = classifier.classify_text_with_bpe(encoded_text_bpe)
98
 
99
  # Print the predicted class label and index
100
  print(f"Predicted class: {predicted_class_text} ({predicted_class})")
 
63
  self.model = RobertaForSequenceClassification.from_pretrained(model, num_labels=3)
64
  self.tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer, max_length=256)
65
 
66
+ def tokenize_text(self, text):
67
  encoded_text = self.tokenizer.encode_plus(
68
  text.lower(),
69
  max_length=256,
 
73
  )
74
  return encoded_text
75
 
76
+ def classify_text(self, encoded_text):
77
  with torch.no_grad():
78
  output = self.model(**encoded_text)
79
  logits = output.logits
 
91
  print("Text to classify: " + text_to_classify + "\n")
92
 
93
  # Tokenize the input text
94
+ encoded_text = classifier.tokenize_text(text_to_classify)
95
 
96
  # Classify the sentiment of the tokenized text
97
+ predicted_class, predicted_class_text, logits = classifier.classify_text(encoded_text)
98
 
99
  # Print the predicted class label and index
100
  print(f"Predicted class: {predicted_class_text} ({predicted_class})")