Spaces:
Running
Running
Manually calculate dataloader len
Browse files- translate.py +20 -6
translate.py
CHANGED
|
@@ -1,15 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from transformers import (
|
| 2 |
AutoModelForSeq2SeqLM,
|
| 3 |
AutoTokenizer,
|
| 4 |
PreTrainedTokenizerBase,
|
| 5 |
DataCollatorForSeq2Seq,
|
| 6 |
)
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
from torch.utils.data import DataLoader
|
| 11 |
from dataset import DatasetReader, count_lines
|
| 12 |
-
|
| 13 |
from accelerate import Accelerator, DistributedType
|
| 14 |
from accelerate.memory_utils import find_executable_batch_size
|
| 15 |
|
|
@@ -183,7 +190,14 @@ def main(
|
|
| 183 |
generated_tokens, skip_special_tokens=True
|
| 184 |
)
|
| 185 |
if accelerator.is_main_process:
|
| 186 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
tgt_text = tgt_text[
|
| 188 |
: (total_lines * num_return_sequences) - samples_seen
|
| 189 |
]
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
from transformers import (
|
| 11 |
AutoModelForSeq2SeqLM,
|
| 12 |
AutoTokenizer,
|
| 13 |
PreTrainedTokenizerBase,
|
| 14 |
DataCollatorForSeq2Seq,
|
| 15 |
)
|
| 16 |
+
|
| 17 |
+
|
|
|
|
|
|
|
| 18 |
from dataset import DatasetReader, count_lines
|
| 19 |
+
|
| 20 |
from accelerate import Accelerator, DistributedType
|
| 21 |
from accelerate.memory_utils import find_executable_batch_size
|
| 22 |
|
|
|
|
| 190 |
generated_tokens, skip_special_tokens=True
|
| 191 |
)
|
| 192 |
if accelerator.is_main_process:
|
| 193 |
+
if (
|
| 194 |
+
step
|
| 195 |
+
== math.ceil(
|
| 196 |
+
math.ceil(total_lines / batch_size)
|
| 197 |
+
/ accelerator.num_processes
|
| 198 |
+
)
|
| 199 |
+
- 1
|
| 200 |
+
):
|
| 201 |
tgt_text = tgt_text[
|
| 202 |
: (total_lines * num_return_sequences) - samples_seen
|
| 203 |
]
|