|
|
|
|
|
""" |
|
Python polyfills for common builtins. |
|
""" |
|
import math |
|
from typing import Any, Callable, Sequence |
|
|
|
import torch |
|
|
|
|
|
def all(iterator):
    """Polyfill for the builtin ``all``.

    Returns False as soon as a falsy element is seen (the remaining elements
    are not consumed). Returns True if every element is truthy, including
    when *iterator* is empty.
    """
    for elem in iterator:
        if elem:
            continue
        # First falsy element decides the answer; stop consuming here.
        return False
    return True
|
|
|
|
|
def any(iterator):
    """Polyfill for the builtin ``any``.

    Returns True as soon as a truthy element is seen (the remaining elements
    are not consumed). Returns False if no element is truthy, including when
    *iterator* is empty.
    """
    for elem in iterator:
        if not elem:
            continue
        # First truthy element decides the answer; stop consuming here.
        return True
    return False
|
|
|
|
|
def index(iterator, item, start=0, end=None):
    """Return the position of the first element equal to *item*.

    Polyfill for ``list.index`` / ``tuple.index``: only positions inside the
    slice ``[start:end]`` are searched, but the returned index is relative to
    the start of the whole sequence. *iterator* is materialized into a list
    first, so any iterable (including a one-shot iterator) is accepted.

    Raises:
        ValueError: if no element in the searched range compares equal.
    """
    seq = list(iterator)
    # An ``enumerate`` object cannot be sliced, so slice a ``range`` of
    # positions instead. This keeps the original indices and applies normal
    # slice rules to negative or None bounds, as ``list.index`` does.
    for i in range(len(seq))[start:end]:
        if item == seq[i]:
            return i

    raise ValueError(f"{item} is not in {type(iterator)}")
|
|
|
|
|
def repeat(item, count):
    """Lazily yield the same *item* object *count* times.

    A non-positive *count* yields nothing. *count* is passed to ``range``, so
    it must be an integer.
    """
    for _ in range(count):
        yield item
|
|
|
|
|
def radians(x):
    """Convert an angle *x* from degrees to radians.

    The code uses only ``*``, so it works for plain numbers and for any
    object that supports multiplication by a float (e.g. tensors).
    """
    degrees_to_radians = math.pi / 180.0
    # Keep the constant as the left operand, matching the original evaluation.
    return degrees_to_radians * x
|
|
|
|
|
def accumulate_grad(x, new_grad):
    """Accumulate *new_grad* into ``x.grad``.

    A copy of *new_grad* is taken first, so the caller's tensor is never
    aliased or mutated. If ``x.grad`` is unset, the copy becomes the
    gradient. Otherwise the copy is added into the existing gradient in place.
    """
    grad_copy = torch.clone(new_grad)
    if x.grad is not None:
        x.grad.add_(grad_copy)
    else:
        x.grad = grad_copy
|
|
|
|
|
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
    """Emulate lexicographic sequence comparison such as ``(1,2,3) > (1,2)``.

    Corresponding elements are scanned pairwise. At the first pair that
    differs (``!=``), *op* is applied to that pair. If the shared prefix is
    identical, *op* is applied to the two lengths instead.
    """
    for lhs, rhs in zip(left, right):
        if lhs != rhs:
            break
    else:
        # No differing pair in the common prefix: the shorter sequence orders first.
        return op(len(left), len(right))
    return op(lhs, rhs)
|
|
|
|
|
def set_isdisjoint(set1, set2):
    """Return True when no element of *set1* is a member of *set2*.

    Iterates *set1* and probes *set2* with ``in``, stopping at the first
    shared element.
    """
    for member in set1:
        if member not in set2:
            continue
        return False
    return True
|
|
|
|
|
def dropwhile(predicate, iterable):
    """Skip leading elements while *predicate* holds, then yield the rest.

    Once *predicate* first returns a falsy value, that element and every
    later one are yielded without calling *predicate* again.
    """
    source = iter(iterable)
    for elem in source:
        if predicate(elem):
            continue
        yield elem
        break
    # Hand off the remaining elements of the shared iterator unfiltered.
    yield from source
|
|