ar5entum committed on
Commit
8f517ec
·
verified ·
1 Parent(s): 24949b5

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +79 -6
README.md CHANGED
@@ -28,17 +28,90 @@ It achieves the following results on the evaluation set:
28
 
29
  ## Model description
30
 
31
- More information needed
32
 
33
- ## Intended uses & limitations
34
 
35
- More information needed
 
 
 
36
 
37
- ## Training and evaluation data
 
 
 
 
 
 
 
38
 
39
- More information needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- ## Training procedure
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  ### Training hyperparameters
44
 
 
28
 
29
  ## Model description
30
 
31
+ A machine translation model from English to Hindi, built on a small BART model.
32
 
33
+ ## Inference and Evaluation
34
 
35
+ ```python
36
+ import torch
37
+ import evaluate
38
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
39
 
40
class BartSmall():
    """English-to-Hindi translator backed by a seq2seq checkpoint.

    The tokenizer and model are loaded from ``model_path`` and the model is
    moved to ``device``. When ``device`` is falsy, CUDA is used if it is
    available and the CPU otherwise.
    """

    # Token limit applied both when truncating inputs and when generating.
    _MAX_LENGTH = 512
    # Beam width used for beam-search decoding.
    _NUM_BEAMS = 4

    def __init__(self, model_path = 'ar5entum/bart_eng_hin_mt', device = None):
        """Load the tokenizer and model and place the model on ``device``.

        Args:
            model_path: Hub id or local path passed to ``from_pretrained``.
            device: Target device (e.g. ``"cuda"``, ``"cpu"`` or a
                ``torch.device``). Auto-detected when falsy.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        if not device:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model.to(device)

    def predict(self, input_text: str) -> str:
        """Translate a single sentence and return the decoded prediction."""
        # Reuse the batch path so both entry points share one pipeline.
        return self.predict_batch([input_text], batch_size=1)[0]

    def predict_batch(self, input_texts, batch_size: int = 32) -> list:
        """Translate several sentences, ``batch_size`` at a time.

        Args:
            input_texts: Iterable of source sentences.
            batch_size: Number of sentences tokenized and generated together.

        Returns:
            A list of decoded predictions, in the same order as the inputs.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        # Materialise once so generators and other iterables can be sliced.
        input_texts = list(input_texts)

        all_predictions = []
        for start in range(0, len(input_texts), batch_size):
            batch_texts = input_texts[start:start + batch_size]
            inputs = self.tokenizer(
                batch_texts,
                return_tensors="pt",
                max_length=self._MAX_LENGTH,
                truncation=True,
                padding=True,
            ).to(self.device)

            with torch.no_grad():
                # Pass the attention mask explicitly: the batch is padded, and
                # the mask tells the model which positions are padding.
                pred_ids = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=self._MAX_LENGTH,
                    num_beams=self._NUM_BEAMS,
                    early_stopping=True,
                )

            all_predictions.extend(
                self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            )

        return all_predictions
72
+
73
import time

# Let the wrapper pick the device (CUDA when available, otherwise CPU) so the
# example also runs on machines without a GPU.
model = BartSmall()

# English source sentences and their reference Hindi translations.
input_texts = [
    "This is a repayable amount.",
    "Watch this video to find out.",
    "He was a father of two daughters and a son."
]
ground_truths = [
    "यह शोध्य रकम है।",
    "जानने के लिए देखें ये वीडियो.",
    "वह दो बेटियों व एक बेटे का पिता था।"
]

# Time one batched translation of all sentences; perf_counter is the clock
# intended for measuring short intervals.
start = time.perf_counter()
predictions = model.predict_batch(input_texts, batch_size=len(input_texts))
end = time.perf_counter()
print("TIME: ", end-start)

# Show each source sentence next to its prediction and reference.
for source, predicted, reference in zip(input_texts, predictions, ground_truths):
    print("‾‾‾‾‾‾‾‾‾‾‾‾")
    print("Input text:\t", source)
    print("Prediction:\t", predicted)
    print("Ground Truth:\t", reference)

# Corpus-level BLEU of the predictions against the single references.
bleu = evaluate.load("bleu")
results = bleu.compute(predictions=predictions, references=ground_truths)
print(results)
99
+
100
+ # TIME: 3.65848970413208
101
+ # ‾‾‾‾‾‾‾‾‾‾‾‾
102
+ # Input text: This is a repayable amount.
103
+ # Prediction: यह एक चुकौती राशि है।
104
+ # Ground Truth: यह शोध्य रकम है।
105
+ # ‾‾‾‾‾‾‾‾‾‾‾‾
106
+ # Input text: Watch this video to find out.
107
+ # Prediction: इस वीडियो को बाहर ढूंढने के लिए इस वीडियो को देख�
108
+ # Ground Truth: जानने के लिए देखें ये वीडियो.
109
+ # ‾‾‾‾‾‾‾‾‾‾‾‾
110
+ # Input text: He was a father of two daughters and a son.
111
+ # Prediction: वह दो बेटियों और एक पुत्र के पिता थे।
112
+ # Ground Truth: वह दो बेटियों व एक बेटे का पिता था।
113
+ # {'bleu': 0.0, 'precisions': [0.4, 0.13636363636363635, 0.05263157894736842, 0.0], 'brevity_penalty': 1.0, 'length_ratio': 1.25, 'translation_length': 25, 'reference_length': 20}
114
+ ```
115
 
116
  ### Training hyperparameters
117