naveensp committed (verified)
Commit 9b5ebda · 1 Parent(s): 8aa1ccc

Delete util.py

Files changed (1)
util.py +0 -655
util.py DELETED
@@ -1,655 +0,0 @@
- import logging
- import os
- import re
- import socket
- import sys
- import time
- import warnings
- from datetime import datetime
- from enum import Enum
- from itertools import cycle, islice
- from pathlib import Path
- from queue import Queue
- from threading import Thread
- from typing import Any, Callable, Dict, Optional, Union
-
- import boto3
- import botocore.exceptions as boto_exceptions
- import rich
- from botocore.config import Config
- from rich.console import Console, ConsoleRenderable
- from rich.highlighter import NullHighlighter
- from rich.progress import Progress
- from rich.text import Text
- from rich.traceback import Traceback
-
- from .aliases import PathOrStr
- from .exceptions import (
-     OLMoCliError,
-     OLMoEnvironmentError,
-     OLMoError,
-     OLMoNetworkError,
-     OLMoThreadError,
- )
- from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed
-
- try:
-     from functools import cache
- except ImportError:
-     from functools import lru_cache as cache
-
-
- class StrEnum(str, Enum):
-     """
-     This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
-     We include this here for compatibility with older versions of Python.
-     """
-
-     def __str__(self) -> str:
-         return self.value
-
-     def __repr__(self) -> str:
-         return f"'{str(self)}'"
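For context (editor-added, not part of the deleted file): on Python versions before 3.11, a plain str-and-Enum mixin stringifies members as "ClassName.member", which is exactly what the overridden __str__ avoids. A minimal sketch:

    class Color(StrEnum):
        red = "red"

    assert Color.red == "red"           # str subclass: compares equal to its value
    assert str(Color.red) == "red"      # overridden __str__ returns the raw value
    assert repr(Color.red) == "'red'"   # matches the __repr__ above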
-
-
- _log_extra_fields: Dict[str, Any] = {}
- log = logging.getLogger(__name__)
-
-
- class LogFilterType(StrEnum):
-     rank0_only = "rank0_only"
-     local_rank0_only = "local_rank0_only"
-     all_ranks = "all_ranks"
-
-
- def log_extra_field(field_name: str, field_value: Any) -> None:
-     global _log_extra_fields
-     if field_value is None:
-         if field_name in _log_extra_fields:
-             del _log_extra_fields[field_name]
-     else:
-         _log_extra_fields[field_name] = field_value
-
-
- def setup_logging(log_filter_type: LogFilterType = LogFilterType.rank0_only) -> None:
-     """
-     :param log_filter_type: Determines which ranks emit INFO-and-below messages; with the
-         default ``rank0_only``, they are only emitted on the rank 0 process.
-     """
-     log_extra_field("hostname", socket.gethostname())
-     if is_distributed():
-         log_extra_field("node_rank", get_node_rank())
-         log_extra_field("local_rank", get_local_rank())
-         log_extra_field("global_rank", get_global_rank())
-     else:
-         log_extra_field("node_rank", 0)
-         log_extra_field("local_rank", 0)
-         log_extra_field("global_rank", 0)
-
-     old_log_record_factory = logging.getLogRecordFactory()
-
-     def log_record_factory(*args, **kwargs) -> logging.LogRecord:
-         record = old_log_record_factory(*args, **kwargs)
-         for field_name, field_value in _log_extra_fields.items():
-             setattr(record, field_name, field_value)
-         return record
-
-     logging.setLogRecordFactory(log_record_factory)
-
-     handler: logging.Handler
-     if (
-         os.environ.get("OLMo_NONINTERACTIVE", False)
-         or os.environ.get("DEBIAN_FRONTEND", None) == "noninteractive"
-         or not sys.stdout.isatty()
-     ):
-         handler = logging.StreamHandler(sys.stdout)
-         formatter = logging.Formatter(
-             "%(asctime)s\t%(hostname)s:%(local_rank)s\t%(name)s:%(lineno)s\t%(levelname)s\t%(message)s"
-         )
-         formatter.default_time_format = "%Y-%m-%d %H:%M:%S"
-         formatter.default_msec_format = "%s.%03d"
-         handler.setFormatter(formatter)
-     else:
-         handler = RichHandler()
-
-     def rank0_filter(record: logging.LogRecord) -> int:
-         if record.levelno > logging.INFO:
-             return 1
-         if getattr(record, "global_rank", 0) == 0:
-             return 1
-         else:
-             return 0
-
-     def local_rank0_filter(record: logging.LogRecord) -> int:
-         if record.levelno > logging.INFO:
-             return 1
-         if getattr(record, "local_rank", 0) == 0:
-             return 1
-         else:
-             return 0
-
-     if log_filter_type == LogFilterType.rank0_only:
-         filter = rank0_filter
-     elif log_filter_type == LogFilterType.local_rank0_only:
-         filter = local_rank0_filter  # type: ignore
-     elif log_filter_type == LogFilterType.all_ranks:
-         filter = None
-     else:
-         raise ValueError(log_filter_type)
-
-     if filter is not None:
-         handler.addFilter(filter)  # type: ignore
-     logging.basicConfig(handlers=[handler], level=logging.INFO)
-
-     logging.captureWarnings(True)
-     logging.getLogger("urllib3").setLevel(logging.ERROR)
-
-
- def excepthook(exctype, value, traceback):
-     """
-     Used to patch `sys.excepthook` in order to log exceptions.
-     """
-     if issubclass(exctype, KeyboardInterrupt):
-         sys.__excepthook__(exctype, value, traceback)
-     elif issubclass(exctype, OLMoCliError):
-         rich.get_console().print(f"[yellow]{value}[/]", highlight=False)
-     elif issubclass(exctype, OLMoError):
-         rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False)
-     else:
-         log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback))
-
-
- def install_excepthook():
-     sys.excepthook = excepthook
-
-
- def filter_warnings():
-     # Filter internal deprecation warnings from torch
-     warnings.filterwarnings(
-         action="ignore",
-         category=UserWarning,
-         message="torch.distributed.*_base is a private function and will be deprecated.*",
-     )
-     warnings.filterwarnings(
-         action="ignore",
-         category=UserWarning,
-         message="TypedStorage is deprecated.*",
-     )
-     warnings.filterwarnings(
-         action="ignore",
-         category=UserWarning,
-         message="Please use DTensor instead.*",
-     )
-     # Torchvision warnings. We don't actually use torchvision.
-     warnings.filterwarnings(
-         action="ignore",
-         message="failed to load.*",
-         module="torchvision.io.image",
-     )
-
-
- def set_env_variables():
-     os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-
- def prepare_cli_environment(log_filter_type: Optional[LogFilterType] = None):
-     if log_filter_type is None:
-         log_filter_type = LogFilterType(os.environ.get("LOG_FILTER_TYPE", "rank0_only"))
-     rich.reconfigure(width=max(rich.get_console().width, 180), soft_wrap=True)
-     setup_logging(log_filter_type=log_filter_type)
-     install_excepthook()
-     filter_warnings()
-     set_env_variables()
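A hypothetical entrypoint sketch (editor-added; the olmo.util module path is an assumption) showing the intended call order. Setting LOG_FILTER_TYPE=all_ranks in the environment would make every rank emit INFO messages:

    import logging

    from olmo.util import prepare_cli_environment

    if __name__ == "__main__":
        prepare_cli_environment()  # rich console, logging, excepthook, warning filters
        logging.getLogger(__name__).info("visible on rank 0 only under the default rank0_only filter")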
-
-
- def clean_opt(arg: str) -> str:
-     if "=" not in arg:
-         arg = f"{arg}=True"
-     name, val = arg.split("=", 1)
-     name = name.strip("-").replace("-", "_")
-     return f"{name}={val}"
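Behavior sketch (editor-added): clean_opt normalizes dotted CLI overrides, turning dashes into underscores and treating bare flags as =True:

    assert clean_opt("--model.n-layers=4") == "model.n_layers=4"
    assert clean_opt("--dry-run") == "dry_run=True"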
-
-
- class RichHandler(logging.Handler):
-     """
-     A simplified version of rich.logging.RichHandler from
-     https://github.com/Textualize/rich/blob/master/rich/logging.py
-     """
-
-     def __init__(
-         self,
-         *,
-         level: Union[int, str] = logging.NOTSET,
-         console: Optional[Console] = None,
-         markup: bool = False,
-     ) -> None:
-         super().__init__(level=level)
-         self.console = console or rich.get_console()
-         self.highlighter = NullHighlighter()
-         self.markup = markup
-
-     def emit(self, record: logging.LogRecord) -> None:
-         try:
-             if hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"):
-                 self.console.print(record.msg)
-             else:
-                 msg: Any = record.msg
-                 if isinstance(record.msg, str):
-                     msg = self.render_message(record=record, message=record.getMessage())
-                 renderables = [
-                     self.get_time_text(record),
-                     self.get_level_text(record),
-                     self.get_location_text(record),
-                     msg,
-                 ]
-                 if record.exc_info is not None:
-                     tb = Traceback.from_exception(*record.exc_info)  # type: ignore
-                     renderables.append(tb)
-                 self.console.print(*renderables)
-         except Exception:
-             self.handleError(record)
-
-     def render_message(self, *, record: logging.LogRecord, message: str) -> ConsoleRenderable:
-         use_markup = getattr(record, "markup", self.markup)
-         message_text = Text.from_markup(message) if use_markup else Text(message)
-
-         highlighter = getattr(record, "highlighter", self.highlighter)
-         if highlighter:
-             message_text = highlighter(message_text)
-
-         return message_text
-
-     def get_time_text(self, record: logging.LogRecord) -> Text:
-         log_time = datetime.fromtimestamp(record.created)
-         time_str = log_time.strftime("[%Y-%m-%d %X]")
-         return Text(time_str, style="log.time", end=" ")
-
-     def get_level_text(self, record: logging.LogRecord) -> Text:
-         level_name = record.levelname
-         level_text = Text.styled(level_name.ljust(8), f"logging.level.{level_name.lower()}")
-         level_text.style = "log.level"
-         level_text.end = " "
-         return level_text
-
-     def get_location_text(self, record: logging.LogRecord) -> Text:
-         name_and_line = f"{record.name}:{record.lineno}" if record.name != "root" else "root"
-         text = f"[{name_and_line}, rank={record.local_rank}]"  # type: ignore
-         return Text(text, style="log.path")
-
-
- def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0):
-     """Wait for the condition function to return True."""
-     start_time = time.monotonic()
-     while not condition():
-         time.sleep(0.5)
-         if time.monotonic() - start_time > timeout:
-             raise TimeoutError(f"{description} timed out")
-
-
- def is_url(path: PathOrStr) -> bool:
-     return re.match(r"[a-z0-9]+://.*", str(path)) is not None
-
-
- def dir_is_empty(dir: PathOrStr) -> bool:
-     dir = Path(dir)
-     if not dir.is_dir():
-         return True
-     try:
-         next(dir.glob("*"))
-         return False
-     except StopIteration:
-         return True
-
-
- def get_progress_bar() -> Progress:
-     from cached_path import get_download_progress
-
-     return get_download_progress()
-
-
- def resource_path(
-     folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None
- ) -> Path:
-     if local_cache is not None and (local_path := Path(local_cache) / fname).is_file():
-         log.info(f"Found local cache of {fname} at {local_path}")
-         return local_path
-     else:
-         from cached_path import cached_path
-
-         return cached_path(f"{str(folder).rstrip('/')}/{fname}", progress=progress)
-
-
- def file_size(path: PathOrStr) -> int:
-     """
-     Get the size of a local or remote file in bytes.
-     """
-     if is_url(path):
-         from urllib.parse import urlparse
-
-         parsed = urlparse(str(path))
-         if parsed.scheme == "gs":
-             return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
-         elif parsed.scheme in ("s3", "r2"):
-             return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
-         elif parsed.scheme == "file":
-             return file_size(str(path).replace("file://", "", 1))
-         else:
-             raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
-     else:
-         return os.stat(path).st_size
-
-
- def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
-     """Upload source file to a target location on GCS or S3."""
-     from urllib.parse import urlparse
-
-     source = Path(source)
-     assert source.is_file()
-     parsed = urlparse(target)
-     if parsed.scheme == "gs":
-         _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
-     elif parsed.scheme in ("s3", "r2"):
-         _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
-     else:
-         raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")
-
-
- def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
-     if is_url(source):
-         from urllib.parse import urlparse
-
-         parsed = urlparse(str(source))
-         if parsed.scheme == "gs":
-             return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
-         elif parsed.scheme in ("s3", "r2"):
-             return _s3_get_bytes_range(
-                 parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
-             )
-         elif parsed.scheme == "file":
-             return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes)
-         else:
-             raise NotImplementedError(f"get_bytes_range not implemented for '{parsed.scheme}' files")
-     else:
-         with open(source, "rb") as f:
-             f.seek(bytes_start)
-             return f.read(num_bytes)
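A local-path sketch (editor-added; the file name is arbitrary): URL sources dispatch to the GCS/S3 helpers below, while a plain path is an ordinary seek-and-read:

    from pathlib import Path

    scratch = Path("example.bin")
    scratch.write_bytes(b"0123456789")
    assert get_bytes_range(scratch, bytes_start=2, num_bytes=4) == b"2345"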
-
-
- def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
-     if is_url(dir):
-         from urllib.parse import urlparse
-
-         parsed = urlparse(str(dir))
-         if parsed.scheme == "gs":
-             raise NotImplementedError
-         elif parsed.scheme in ("s3", "r2"):
-             return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
-         elif parsed.scheme == "file":
-             return find_latest_checkpoint(str(dir).replace("file://", "", 1))
-         else:
-             raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")
-     else:
-         latest_step = 0
-         latest_checkpoint: Optional[Path] = None
-         for path in Path(dir).glob("step*"):
-             if path.is_dir():
-                 try:
-                     step = int(path.name.replace("step", "").replace("-unsharded", ""))
-                 except ValueError:
-                     continue
-                 # We prioritize sharded checkpoints over unsharded checkpoints.
-                 if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")):
-                     latest_step = step
-                     latest_checkpoint = path
-         return latest_checkpoint
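Naming-convention sketch (editor-added): checkpoint directories are named step<N> or step<N>-unsharded; the highest step wins, and at the same step a sharded directory beats its unsharded twin:

    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmp:
        for name in ("step900", "step1000", "step1000-unsharded"):
            (Path(tmp) / name).mkdir()
        latest = find_latest_checkpoint(tmp)
        assert latest is not None and latest.name == "step1000"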
-
-
- def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
-     from google.cloud import storage as gcs
-
-     storage_client = gcs.Client()
-     bucket = storage_client.bucket(bucket_name)
-     blob = bucket.blob(key)
-     if not save_overwrite and blob.exists():
-         raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
-     blob.upload_from_filename(source)
-
-
- def _gcs_file_size(bucket_name: str, key: str) -> int:
-     from google.api_core.exceptions import NotFound
-     from google.cloud import storage as gcs
-
-     storage_client = gcs.Client()
-     bucket = storage_client.bucket(bucket_name)
-     blob = bucket.blob(key)
-     try:
-         blob.reload()
-     except NotFound:
-         raise FileNotFoundError(f"gs://{bucket_name}/{key}")
-     assert blob.size is not None
-     return blob.size
-
-
- def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
-     from google.api_core.exceptions import NotFound
-     from google.cloud import storage as gcs
-
-     storage_client = gcs.Client()
-     bucket = storage_client.bucket(bucket_name)
-     blob = bucket.blob(key)
-     try:
-         blob.reload()
-     except NotFound:
-         raise FileNotFoundError(f"gs://{bucket_name}/{key}")
-     return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1)
-
-
- def _get_s3_profile_name(scheme: str) -> Optional[str]:
-     if scheme == "s3":
-         # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set.
-         return os.environ.get("S3_PROFILE")
-     if scheme == "r2":
-         profile_name = os.environ.get("R2_PROFILE")
-         if profile_name is None:
-             raise OLMoEnvironmentError(
-                 "R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?"
-             )
-
-         return profile_name
-
-     raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")
-
-
- def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
-     if scheme == "s3":
-         return None
-     if scheme == "r2":
-         r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL")
-         if r2_endpoint_url is None:
-             raise OLMoEnvironmentError(
-                 "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
-             )
-
-         return r2_endpoint_url
-
-     raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")
-
-
- @cache
- def _get_s3_client(scheme: str):
-     session = boto3.Session(profile_name=_get_s3_profile_name(scheme))
-     return session.client(
-         "s3",
-         endpoint_url=_get_s3_endpoint_url(scheme),
-         config=Config(retries={"max_attempts": 10, "mode": "standard"}),
-         use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
-     )
-
-
- def _wait_before_retry(attempt: int):
-     time.sleep(min(0.5 * 2**attempt, 3.0))
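Backoff sketch (editor-added): retries sleep on an exponential schedule capped at 3 seconds:

    # attempt:    1    2    3    4
    # sleep (s): 1.0  2.0  3.0  3.0
    assert [min(0.5 * 2**a, 3.0) for a in range(1, 5)] == [1.0, 2.0, 3.0, 3.0]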
-
-
- def _s3_upload(
-     source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3
- ):
-     err: Optional[Exception] = None
-     if not save_overwrite:
-         for attempt in range(1, max_attempts + 1):
-             try:
-                 _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
-                 raise FileExistsError(
-                     f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
-                 )
-             except boto_exceptions.ClientError as e:
-                 if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                     err = None
-                     break
-                 err = e
-
-             if attempt < max_attempts:
-                 log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err)
-                 _wait_before_retry(attempt)
-
-         if err is not None:
-             raise OLMoNetworkError(f"Failed to check object existence during {scheme} upload") from err
-
-     try:
-         _get_s3_client(scheme).upload_file(source, bucket_name, key)
-     except boto_exceptions.ClientError as e:
-         raise OLMoNetworkError(f"Failed to upload to {scheme}") from e
-
-
- def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int:
-     err: Optional[Exception] = None
-     for attempt in range(1, max_attempts + 1):
-         try:
-             return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
-         except boto_exceptions.ClientError as e:
-             if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                 raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
-             err = e
-
-         if attempt < max_attempts:
-             log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err)
-             _wait_before_retry(attempt)
-
-     raise OLMoNetworkError(f"Failed to get {scheme} file size") from err
-
-
- def _s3_get_bytes_range(
-     scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
- ) -> bytes:
-     err: Optional[Exception] = None
-     for attempt in range(1, max_attempts + 1):
-         try:
-             return (
-                 _get_s3_client(scheme)
-                 .get_object(
-                     Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
-                 )["Body"]
-                 .read()
-             )
-         except boto_exceptions.ClientError as e:
-             if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                 raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e
-             err = e
-         except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e:
-             # ResponseStreamingError (subclass of HTTPClientError) can happen as
-             # a result of a failed read from the stream (http.client.IncompleteRead).
-             # Retrying can help in this case.
-             err = e
-
-         if attempt < max_attempts:
-             log.warning(
-                 "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err
-             )
-             _wait_before_retry(attempt)
-
-     # When torch's DataLoader intercepts exceptions, it may try to re-raise them
-     # by recalling their constructor with a single message arg. Torch has some
-     # logic to deal with the absence of a single-parameter constructor, but it
-     # doesn't gracefully handle other possible failures in calling such a constructor.
-     # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
-     # in us losing the true exception info. To avoid this, we change the exception
-     # to a type that has a single-parameter constructor.
-     raise OLMoNetworkError(f"Failed to get bytes range from {scheme}") from err
-
-
- def _s3_find_latest_checkpoint(scheme: str, bucket_name: str, prefix: str) -> Optional[str]:
-     if not prefix.endswith("/"):
-         prefix = f"{prefix}/"
-     response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
-     assert not response["IsTruncated"]  # need to handle this if it happens
-     latest_step = 0
-     latest_checkpoint: Optional[str] = None
-     for item in response["CommonPrefixes"]:
-         prefix = item["Prefix"].strip("/")
-         checkpoint_name = os.path.split(prefix)[-1]
-         if not checkpoint_name.startswith("step"):
-             continue
-         try:
-             step = int(checkpoint_name.replace("step", "").replace("-unsharded", ""))
-         except ValueError:
-             continue
-         # Make sure the checkpoint dir contains a config, otherwise the checkpoint is incomplete
-         # (upload might have failed part way through).
-         try:
-             _s3_file_size(scheme, bucket_name, f"{prefix}/config.yaml")
-         except FileNotFoundError:
-             continue
-         # We prioritize sharded checkpoints over unsharded ones.
-         if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")):
-             latest_step = step
-             latest_checkpoint = f"{scheme}://{bucket_name}/{prefix}"
-     return latest_checkpoint
-
-
- def default_thread_count() -> int:
-     return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4))
-
-
- def pass_through_fn(fn, *args, **kwargs):
-     return fn(*args, **kwargs)
-
-
- def threaded_generator(g, maxsize: int = 16, thread_name: Optional[str] = None):
-     q: Queue = Queue(maxsize=maxsize)
-
-     sentinel = object()
-
-     def fill_queue():
-         try:
-             for value in g:
-                 q.put(value)
-         except Exception as e:
-             q.put(e)
-         finally:
-             q.put(sentinel)
-
-     thread_name = thread_name or repr(g)
-     thread = Thread(name=thread_name, target=fill_queue, daemon=True)
-     thread.start()
-
-     for x in iter(q.get, sentinel):
-         if isinstance(x, Exception):
-             raise OLMoThreadError(f"generator thread {thread_name} failed") from x
-         else:
-             yield x
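Usage sketch (editor-added): the wrapper overlaps production and consumption, e.g. prefetching data on a background thread while the main thread keeps working; exceptions raised in the worker resurface as OLMoThreadError:

    import time

    def slow_numbers():
        for i in range(3):
            time.sleep(0.1)  # stand-in for I/O-bound work
            yield i

    assert list(threaded_generator(slow_numbers(), maxsize=2)) == [0, 1, 2]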
-
-
- def roundrobin(*iterables):
-     """
-     Call the given iterables in a round-robin fashion. For example:
-     ``roundrobin('ABC', 'D', 'EF') --> A D E B F C``
-     """
-     # Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes
-     num_active = len(iterables)
-     nexts = cycle(iter(it).__next__ for it in iterables)
-     while num_active:
-         try:
-             for next in nexts:
-                 yield next()
-         except StopIteration:
-             # Remove the iterator we just exhausted from the cycle.
-             num_active -= 1
-             nexts = cycle(islice(nexts, num_active))
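Usage sketch (editor-added), matching the docstring example:

    assert "".join(roundrobin("ABC", "D", "EF")) == "ADEBFC"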