naveensp committed
Commit 2178aa3 · verified · Parent: 80fa3bb

Upload util.py with huggingface_hub

Files changed (1): util.py (+655 -0)
util.py ADDED
import logging
import os
import re
import socket
import sys
import time
import warnings
from datetime import datetime
from enum import Enum
from itertools import cycle, islice
from pathlib import Path
from queue import Queue
from threading import Thread
from typing import Any, Callable, Dict, Optional, Union

import boto3
import botocore.exceptions as boto_exceptions
import rich
from botocore.config import Config
from rich.console import Console, ConsoleRenderable
from rich.highlighter import NullHighlighter
from rich.progress import Progress
from rich.text import Text
from rich.traceback import Traceback

from .aliases import PathOrStr
from .exceptions import (
    OLMoCliError,
    OLMoEnvironmentError,
    OLMoError,
    OLMoNetworkError,
    OLMoThreadError,
)
from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed

try:
    from functools import cache
except ImportError:
    from functools import lru_cache as cache


class StrEnum(str, Enum):
    """
    This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
    We include it here for compatibility with older versions of Python.
    """

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        return f"'{str(self)}'"


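# Illustrative sketch (not exercised anywhere in this module): members of a
# StrEnum subclass behave like plain strings.
def _strenum_example() -> None:
    class Color(StrEnum):
        red = "red"

    assert Color.red == "red"  # compares equal to its string value
    assert repr(Color.red) == "'red'"  # matches __repr__ above

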
_log_extra_fields: Dict[str, Any] = {}
log = logging.getLogger(__name__)


class LogFilterType(StrEnum):
    rank0_only = "rank0_only"
    local_rank0_only = "local_rank0_only"
    all_ranks = "all_ranks"


def log_extra_field(field_name: str, field_value: Any) -> None:
    global _log_extra_fields
    if field_value is None:
        if field_name in _log_extra_fields:
            del _log_extra_fields[field_name]
    else:
        _log_extra_fields[field_name] = field_value


def setup_logging(log_filter_type: LogFilterType = LogFilterType.rank0_only) -> None:
    """
    :param log_filter_type: When set to ``rank0_only``, INFO and below messages are only
        emitted on the rank 0 process.
    """
    log_extra_field("hostname", socket.gethostname())
    if is_distributed():
        log_extra_field("node_rank", get_node_rank())
        log_extra_field("local_rank", get_local_rank())
        log_extra_field("global_rank", get_global_rank())
    else:
        log_extra_field("node_rank", 0)
        log_extra_field("local_rank", 0)
        log_extra_field("global_rank", 0)

    old_log_record_factory = logging.getLogRecordFactory()

    def log_record_factory(*args, **kwargs) -> logging.LogRecord:
        record = old_log_record_factory(*args, **kwargs)
        for field_name, field_value in _log_extra_fields.items():
            setattr(record, field_name, field_value)
        return record

    logging.setLogRecordFactory(log_record_factory)

    handler: logging.Handler
    if (
        os.environ.get("OLMo_NONINTERACTIVE", False)
        or os.environ.get("DEBIAN_FRONTEND", None) == "noninteractive"
        or not sys.stdout.isatty()
    ):
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter(
            "%(asctime)s\t%(hostname)s:%(local_rank)s\t%(name)s:%(lineno)s\t%(levelname)s\t%(message)s"
        )
        formatter.default_time_format = "%Y-%m-%d %H:%M:%S"
        formatter.default_msec_format = "%s.%03d"
        handler.setFormatter(formatter)
    else:
        handler = RichHandler()

    def rank0_filter(record: logging.LogRecord) -> int:
        if record.levelno > logging.INFO:
            return 1
        if getattr(record, "global_rank", 0) == 0:
            return 1
        else:
            return 0

    def local_rank0_filter(record: logging.LogRecord) -> int:
        if record.levelno > logging.INFO:
            return 1
        if getattr(record, "local_rank", 0) == 0:
            return 1
        else:
            return 0

    if log_filter_type == LogFilterType.rank0_only:
        filter = rank0_filter
    elif log_filter_type == LogFilterType.local_rank0_only:
        filter = local_rank0_filter  # type: ignore
    elif log_filter_type == LogFilterType.all_ranks:
        filter = None
    else:
        raise ValueError(log_filter_type)

    if filter is not None:
        handler.addFilter(filter)  # type: ignore
    logging.basicConfig(handlers=[handler], level=logging.INFO)

    logging.captureWarnings(True)
    logging.getLogger("urllib3").setLevel(logging.ERROR)


def excepthook(exctype, value, traceback):
    """
    Used to patch `sys.excepthook` in order to log exceptions.
    """
    if issubclass(exctype, KeyboardInterrupt):
        sys.__excepthook__(exctype, value, traceback)
    elif issubclass(exctype, OLMoCliError):
        rich.get_console().print(f"[yellow]{value}[/]", highlight=False)
    elif issubclass(exctype, OLMoError):
        rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False)
    else:
        log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback))


def install_excepthook():
    sys.excepthook = excepthook


def filter_warnings():
    # Filter internal deprecation warnings from torch
    warnings.filterwarnings(
        action="ignore",
        category=UserWarning,
        message="torch.distributed.*_base is a private function and will be deprecated.*",
    )
    warnings.filterwarnings(
        action="ignore",
        category=UserWarning,
        message="TypedStorage is deprecated.*",
    )
    warnings.filterwarnings(
        action="ignore",
        category=UserWarning,
        message="Please use DTensor instead.*",
    )
    # Torchvision warnings. We don't actually use torchvision.
    warnings.filterwarnings(
        action="ignore",
        message="failed to load.*",
        module="torchvision.io.image",
    )


def set_env_variables():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"


def prepare_cli_environment(log_filter_type: Optional[LogFilterType] = None):
    if log_filter_type is None:
        log_filter_type = LogFilterType(os.environ.get("LOG_FILTER_TYPE", "rank0_only"))
    rich.reconfigure(width=max(rich.get_console().width, 180), soft_wrap=True)
    setup_logging(log_filter_type=log_filter_type)
    install_excepthook()
    filter_warnings()
    set_env_variables()


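# Illustrative sketch: a typical script entry point calls prepare_cli_environment()
# once at startup so the logging setup, excepthook, and warning filters above are
# all in place before any work happens.
def _prepare_cli_example() -> None:
    prepare_cli_environment()
    log.info("Ready.")  # with the default filter, emitted on rank 0 only

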
def clean_opt(arg: str) -> str:
    if "=" not in arg:
        arg = f"{arg}=True"
    name, val = arg.split("=", 1)
    name = name.strip("-").replace("-", "_")
    return f"{name}={val}"


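# Illustrative sketch: clean_opt() normalizes "--dotted.key=value" overrides, and
# bare flags become "=True".
def _clean_opt_example() -> None:
    assert clean_opt("--model.d-model=1024") == "model.d_model=1024"
    assert clean_opt("--dry-run") == "dry_run=True"

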
class RichHandler(logging.Handler):
    """
    A simplified version of rich.logging.RichHandler from
    https://github.com/Textualize/rich/blob/master/rich/logging.py
    """

    def __init__(
        self,
        *,
        level: Union[int, str] = logging.NOTSET,
        console: Optional[Console] = None,
        markup: bool = False,
    ) -> None:
        super().__init__(level=level)
        self.console = console or rich.get_console()
        self.highlighter = NullHighlighter()
        self.markup = markup

    def emit(self, record: logging.LogRecord) -> None:
        try:
            if hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"):
                self.console.print(record.msg)
            else:
                msg: Any = record.msg
                if isinstance(record.msg, str):
                    msg = self.render_message(record=record, message=record.getMessage())
                renderables = [
                    self.get_time_text(record),
                    self.get_level_text(record),
                    self.get_location_text(record),
                    msg,
                ]
                if record.exc_info is not None:
                    tb = Traceback.from_exception(*record.exc_info)  # type: ignore
                    renderables.append(tb)
                self.console.print(*renderables)
        except Exception:
            self.handleError(record)

    def render_message(self, *, record: logging.LogRecord, message: str) -> ConsoleRenderable:
        use_markup = getattr(record, "markup", self.markup)
        message_text = Text.from_markup(message) if use_markup else Text(message)

        highlighter = getattr(record, "highlighter", self.highlighter)
        if highlighter:
            message_text = highlighter(message_text)

        return message_text

    def get_time_text(self, record: logging.LogRecord) -> Text:
        log_time = datetime.fromtimestamp(record.created)
        time_str = log_time.strftime("[%Y-%m-%d %X]")
        return Text(time_str, style="log.time", end=" ")

    def get_level_text(self, record: logging.LogRecord) -> Text:
        level_name = record.levelname
        level_text = Text.styled(level_name.ljust(8), f"logging.level.{level_name.lower()}")
        level_text.style = "log.level"
        level_text.end = " "
        return level_text

    def get_location_text(self, record: logging.LogRecord) -> Text:
        name_and_line = f"{record.name}:{record.lineno}" if record.name != "root" else "root"
        text = f"[{name_and_line}, rank={record.local_rank}]"  # type: ignore
        return Text(text, style="log.path")


def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0):
    """Wait for the condition function to return True."""
    start_time = time.monotonic()
    while not condition():
        time.sleep(0.5)
        if time.monotonic() - start_time > timeout:
            raise TimeoutError(f"{description} timed out")


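# Illustrative sketch: block until a condition becomes true, here a deadline
# computed from the clock, failing with TimeoutError if it takes too long.
def _wait_for_example() -> None:
    deadline = time.monotonic() + 1.0
    wait_for(lambda: time.monotonic() >= deadline, "waiting for deadline", timeout=10.0)

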
def is_url(path: PathOrStr) -> bool:
    return re.match(r"[a-z0-9]+://.*", str(path)) is not None


def dir_is_empty(dir: PathOrStr) -> bool:
    dir = Path(dir)
    if not dir.is_dir():
        return True
    try:
        next(dir.glob("*"))
        return False
    except StopIteration:
        return True


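# Illustrative sketch of the two small predicates above: a lowercase "scheme://"
# prefix counts as a URL, and a non-existent directory counts as empty.
def _path_predicates_example() -> None:
    assert is_url("s3://my-bucket/some/key")
    assert not is_url("/local/path")
    assert dir_is_empty("/no/such/dir")

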
def get_progress_bar() -> Progress:
    from cached_path import get_download_progress

    return get_download_progress()


def resource_path(
    folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None
) -> Path:
    if local_cache is not None and (local_path := Path(local_cache) / fname).is_file():
        log.info(f"Found local cache of {fname} at {local_path}")
        return local_path
    else:
        from cached_path import cached_path

        return cached_path(f"{str(folder).rstrip('/')}/{fname}", progress=progress)


def file_size(path: PathOrStr) -> int:
    """
    Get the size of a local or remote file in bytes.
    """
    if is_url(path):
        from urllib.parse import urlparse

        parsed = urlparse(str(path))
        if parsed.scheme == "gs":
            return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
        elif parsed.scheme in ("s3", "r2"):
            return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
        elif parsed.scheme == "file":
            return file_size(str(path).replace("file://", "", 1))
        else:
            raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
    else:
        return os.stat(path).st_size


def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
    """Upload source file to a target location on GCS or S3."""
    from urllib.parse import urlparse

    source = Path(source)
    assert source.is_file()
    parsed = urlparse(target)
    if parsed.scheme == "gs":
        _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
    elif parsed.scheme in ("s3", "r2"):
        _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
    else:
        raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")


def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
    if is_url(source):
        from urllib.parse import urlparse

        parsed = urlparse(str(source))
        if parsed.scheme == "gs":
            return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
        elif parsed.scheme in ("s3", "r2"):
            return _s3_get_bytes_range(
                parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
            )
        elif parsed.scheme == "file":
            return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes)
        else:
            raise NotImplementedError(f"get_bytes_range not implemented for '{parsed.scheme}' files")
    else:
        with open(source, "rb") as f:
            f.seek(bytes_start)
            return f.read(num_bytes)


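# Illustrative sketch: the same call reads a byte range from a local path or from
# a gs/s3/r2/file URL. Here we read back bytes we just wrote to a hypothetical
# scratch path.
def _get_bytes_range_example() -> None:
    p = Path("/tmp/olmo_util_example.bin")
    p.write_bytes(b"0123456789abcdef")
    assert get_bytes_range(p, 4, 4) == b"4567"

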
def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
    if is_url(dir):
        from urllib.parse import urlparse

        parsed = urlparse(str(dir))
        if parsed.scheme == "gs":
            raise NotImplementedError
        elif parsed.scheme in ("s3", "r2"):
            return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
        elif parsed.scheme == "file":
            return find_latest_checkpoint(str(dir).replace("file://", "", 1))
        else:
            raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")
    else:
        latest_step = 0
        latest_checkpoint: Optional[Path] = None
        for path in Path(dir).glob("step*"):
            if path.is_dir():
                try:
                    step = int(path.name.replace("step", "").replace("-unsharded", ""))
                except ValueError:
                    continue
                # We prioritize sharded checkpoints over unsharded checkpoints.
                if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")):
                    latest_step = step
                    latest_checkpoint = path
        return latest_checkpoint


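# Illustrative sketch: with "step100/", "step200/", and "step200-unsharded/" in a
# (hypothetical) directory, the sharded "step200" wins the tie per the rule above.
def _find_latest_checkpoint_example() -> None:
    latest = find_latest_checkpoint("/tmp/checkpoints")
    if latest is not None:
        log.info(f"Resuming from {latest}")

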
def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
    from google.cloud import storage as gcs

    storage_client = gcs.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(key)
    if not save_overwrite and blob.exists():
        raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
    blob.upload_from_filename(source)


def _gcs_file_size(bucket_name: str, key: str) -> int:
    from google.api_core.exceptions import NotFound
    from google.cloud import storage as gcs

    storage_client = gcs.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(key)
    try:
        blob.reload()
    except NotFound:
        raise FileNotFoundError(f"gs://{bucket_name}/{key}")
    assert blob.size is not None
    return blob.size


def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
    from google.api_core.exceptions import NotFound
    from google.cloud import storage as gcs

    storage_client = gcs.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(key)
    try:
        blob.reload()
    except NotFound:
        raise FileNotFoundError(f"gs://{bucket_name}/{key}")
    return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1)


def _get_s3_profile_name(scheme: str) -> Optional[str]:
    if scheme == "s3":
        # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set.
        return os.environ.get("S3_PROFILE")
    if scheme == "r2":
        profile_name = os.environ.get("R2_PROFILE")
        if profile_name is None:
            raise OLMoEnvironmentError(
                "R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?"
            )

        return profile_name

    raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")


def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
    if scheme == "s3":
        return None
    if scheme == "r2":
        r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL")
        if r2_endpoint_url is None:
            raise OLMoEnvironmentError(
                "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
            )

        return r2_endpoint_url

    raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")


@cache
def _get_s3_client(scheme: str):
    session = boto3.Session(profile_name=_get_s3_profile_name(scheme))
    return session.client(
        "s3",
        endpoint_url=_get_s3_endpoint_url(scheme),
        config=Config(retries={"max_attempts": 10, "mode": "standard"}),
        use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
    )


def _wait_before_retry(attempt: int):
    time.sleep(min(0.5 * 2**attempt, 3.0))


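# Illustrative sketch: the retry delay above doubles per attempt and is capped at
# 3 seconds, i.e. 1.0s, 2.0s, 3.0s, 3.0s, ... for attempts 1, 2, 3, 4, ...
def _backoff_schedule_example() -> None:
    delays = [min(0.5 * 2**attempt, 3.0) for attempt in range(1, 5)]
    assert delays == [1.0, 2.0, 3.0, 3.0]

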
def _s3_upload(
    source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3
):
    err: Optional[Exception] = None
    if not save_overwrite:
        for attempt in range(1, max_attempts + 1):
            try:
                _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
                raise FileExistsError(
                    f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
                )
            except boto_exceptions.ClientError as e:
                if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                    err = None
                    break
                err = e

            if attempt < max_attempts:
                log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err)
                _wait_before_retry(attempt)

        if err is not None:
            raise OLMoNetworkError(f"Failed to check object existence during {scheme} upload") from err

    try:
        _get_s3_client(scheme).upload_file(source, bucket_name, key)
    except boto_exceptions.ClientError as e:
        raise OLMoNetworkError(f"Failed to upload to {scheme}") from e


def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int:
    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
        except boto_exceptions.ClientError as e:
            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
            err = e

        if attempt < max_attempts:
            log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err)
            _wait_before_retry(attempt)

    raise OLMoNetworkError(f"Failed to get {scheme} file size") from err


def _s3_get_bytes_range(
    scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
) -> bytes:
    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return (
                _get_s3_client(scheme)
                .get_object(
                    Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
                )["Body"]
                .read()
            )
        except boto_exceptions.ClientError as e:
            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e
            err = e
        except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e:
            # ResponseStreamingError (subclass of HTTPClientError) can happen as
            # a result of a failed read from the stream (http.client.IncompleteRead).
            # Retrying can help in this case.
            err = e

        if attempt < max_attempts:
            log.warning(
                "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err
            )
            _wait_before_retry(attempt)

    # When torch's DataLoader intercepts exceptions, it may try to re-raise them
    # by recalling their constructor with a single message arg. Torch has some
    # logic to deal with the absence of a single-parameter constructor, but it
    # doesn't gracefully handle other possible failures in calling such a constructor.
    # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
    # in us losing the true exception info. To avoid this, we change the exception
    # to a type that has a single-parameter constructor.
    raise OLMoNetworkError(f"Failed to get bytes range from {scheme}") from err


def _s3_find_latest_checkpoint(scheme: str, bucket_name: str, prefix: str) -> Optional[str]:
    if not prefix.endswith("/"):
        prefix = f"{prefix}/"
    response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
    assert not response["IsTruncated"]  # need to handle this if it happens
    latest_step = 0
    latest_checkpoint: Optional[str] = None
    for item in response["CommonPrefixes"]:
        prefix = item["Prefix"].strip("/")
        checkpoint_name = os.path.split(prefix)[-1]
        if not checkpoint_name.startswith("step"):
            continue
        try:
            step = int(checkpoint_name.replace("step", "").replace("-unsharded", ""))
        except ValueError:
            continue
        # Make sure the checkpoint dir contains a config, otherwise the checkpoint is incomplete
        # (the upload might have failed part way through).
        try:
            _s3_file_size(scheme, bucket_name, f"{prefix}/config.yaml")
        except FileNotFoundError:
            continue
        # We prioritize sharded checkpoints over unsharded ones.
        if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")):
            latest_step = step
            latest_checkpoint = f"{scheme}://{bucket_name}/{prefix}"
    return latest_checkpoint


def default_thread_count() -> int:
    return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4))


def pass_through_fn(fn, *args, **kwargs):
    return fn(*args, **kwargs)


def threaded_generator(g, maxsize: int = 16, thread_name: Optional[str] = None):
    q: Queue = Queue(maxsize=maxsize)

    sentinel = object()

    def fill_queue():
        try:
            for value in g:
                q.put(value)
        except Exception as e:
            q.put(e)
        finally:
            q.put(sentinel)

    thread_name = thread_name or repr(g)
    thread = Thread(name=thread_name, target=fill_queue, daemon=True)
    thread.start()

    for x in iter(q.get, sentinel):
        if isinstance(x, Exception):
            raise OLMoThreadError(f"generator thread {thread_name} failed") from x
        else:
            yield x


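# Illustrative sketch: wrap a slow generator so items are prefetched on a
# background thread while the consumer works.
def _threaded_generator_example() -> None:
    def slow_numbers():
        for i in range(3):
            time.sleep(0.1)  # simulate I/O latency
            yield i

    assert list(threaded_generator(slow_numbers())) == [0, 1, 2]

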
def roundrobin(*iterables):
    """
    Call the given iterables in a round-robin fashion. For example:
    ``roundrobin('ABC', 'D', 'EF') --> A D E B F C``
    """
    # Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes
    num_active = len(iterables)
    nexts = cycle(iter(it).__next__ for it in iterables)
    while num_active:
        try:
            for nxt in nexts:
                yield nxt()
        except StopIteration:
            # Remove the iterator we just exhausted from the cycle.
            num_active -= 1
            nexts = cycle(islice(nexts, num_active))
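

# Illustrative sketch matching the docstring above: iterables of different
# lengths are interleaved until all are exhausted.
def _roundrobin_example() -> None:
    assert "".join(roundrobin("ABC", "D", "EF")) == "ADEBFC"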