#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/scaling.py
"""
import logging
import random
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


def logaddexp_onnx(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    max_value = torch.max(x, y)
    diff = torch.abs(x - y)
    return max_value + torch.log1p(torch.exp(-diff))


# RuntimeError: Exporting the operator logaddexp to ONNX opset version
# 14 is not supported. Please feel free to request support or submit
# a pull request on PyTorch GitHub.
#
# The following function is to solve the above error when exporting
# models to ONNX via torch.jit.trace()
def logaddexp(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        # Note: We cannot use torch.jit.is_tracing() here as it also
        # matches torch.onnx.export().
        return torch.logaddexp(x, y)
    elif torch.onnx.is_in_onnx_export():
        return logaddexp_onnx(x, y)
    else:
        # for torch.jit.trace()
        return torch.logaddexp(x, y)
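
# Illustrative usage (a sketch added for exposition, not part of the upstream file):
# both code paths compute the same quantity, since
# logaddexp(x, y) = max(x, y) + log1p(exp(-|x - y|)).
#
#   >>> a = torch.tensor([0.0, 1.0])
#   >>> b = torch.tensor([2.0, -1.0])
#   >>> torch.allclose(logaddexp(a, b), logaddexp_onnx(a, b))
#   True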


class PiecewiseLinear(object):
    """
    Piecewise linear function from float to float, specified as a nonempty list of
    (x, y) pairs with the x values in increasing order. x values below [initial x] or
    above [final x] are mapped to [initial y] and [final y], respectively.
    """

    def __init__(self, *args):
        assert len(args) >= 1, len(args)
        if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
            self.pairs = list(args[0].pairs)
        else:
            self.pairs = [(float(x), float(y)) for x, y in args]
        for x, y in self.pairs:
            assert isinstance(x, (float, int)), type(x)
            assert isinstance(y, (float, int)), type(y)
        for i in range(len(self.pairs) - 1):
            assert self.pairs[i + 1][0] > self.pairs[i][0], (
                i,
                self.pairs[i],
                self.pairs[i + 1],
            )

    def __str__(self):
        # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
        return f'PiecewiseLinear({str(self.pairs)[1:-1]})'

    def __call__(self, x):
        if x <= self.pairs[0][0]:
            return self.pairs[0][1]
        elif x >= self.pairs[-1][0]:
            return self.pairs[-1][1]
        else:
            cur_x, cur_y = self.pairs[0]
            for i in range(1, len(self.pairs)):
                next_x, next_y = self.pairs[i]
                if cur_x <= x <= next_x:
                    return cur_y + (next_y - cur_y) * (x - cur_x) / (
                        next_x - cur_x)
                cur_x, cur_y = next_x, next_y
            assert False

    def __mul__(self, alpha):
        return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])

    def __add__(self, x):
        if isinstance(x, (float, int)):
            return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
        s, x = self.get_common_basis(x)
        return PiecewiseLinear(*[(sp[0], sp[1] + xp[1])
                                 for sp, xp in zip(s.pairs, x.pairs)])

    def max(self, x):
        if isinstance(x, (float, int)):
            x = PiecewiseLinear((0, x))
        s, x = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(*[(sp[0], max(sp[1], xp[1]))
                                 for sp, xp in zip(s.pairs, x.pairs)])

    def min(self, x):
        if isinstance(x, (float, int)):
            x = PiecewiseLinear((0, x))
        s, x = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(*[(sp[0], min(sp[1], xp[1]))
                                 for sp, xp in zip(s.pairs, x.pairs)])

    def __eq__(self, other):
        return self.pairs == other.pairs

    def get_common_basis(self,
                         p: 'PiecewiseLinear',
                         include_crossings: bool = False):
        """
        Returns (self_mod, p_mod): piecewise linear functions equivalent to self and p
        respectively, but defined over the same set of x values.

        p: the other piecewise linear function
        include_crossings: if true, also include x values at the positions where the
            functions represented by self and p cross.
        """
        assert isinstance(p, PiecewiseLinear), type(p)
        # get sorted x-values without repetition.
        x_vals = sorted(
            set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
        y_vals1 = [self(x) for x in x_vals]
        y_vals2 = [p(x) for x in x_vals]
        if include_crossings:
            extra_x_vals = []
            for i in range(len(x_vals) - 1):
                _compare_results1 = (y_vals1[i] > y_vals2[i])
                _compare_results2 = (y_vals1[i + 1] > y_vals2[i + 1])
                if _compare_results1 != _compare_results2:
                    # the two lines cross somewhere inside this subsegment.
                    diff_cur = abs(y_vals1[i] - y_vals2[i])
                    diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
                    # `pos`, between 0 and 1, gives the relative x position,
                    # with 0 being x_vals[i] and 1 being x_vals[i+1].
                    pos = diff_cur / (diff_cur + diff_next)
                    extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
                    extra_x_vals.append(extra_x_val)
            if len(extra_x_vals) > 0:
                x_vals = sorted(set(x_vals + extra_x_vals))
                y_vals1 = [self(x) for x in x_vals]
                y_vals2 = [p(x) for x in x_vals]
        return (
            PiecewiseLinear(*zip(x_vals, y_vals1)),
            PiecewiseLinear(*zip(x_vals, y_vals2)),
        )
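
# Illustrative usage (a sketch added for exposition, not part of the upstream file):
# a schedule decaying linearly from 10 to 0 over x in [0, 100], clamped outside
# that range.
#
#   >>> f = PiecewiseLinear((0.0, 10.0), (100.0, 0.0))
#   >>> f(-5.0), f(50.0), f(200.0)
#   (10.0, 5.0, 0.0)
#   >>> g = f.max(4.0)          # pointwise maximum with the constant function y = 4
#   >>> g(0.0), g(100.0)
#   (10.0, 4.0)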


class ScheduledFloat(torch.nn.Module):
    """
    This object is a torch.nn.Module only because we want it to show up in
    [top_level module].modules(); it does not have a working forward() function.
    You are supposed to cast it to float, as in float(parent_module.whatever), and use
    it as something like a dropout probability.

    It is a floating point value whose value changes depending on the batch count of the
    training loop. It is a piecewise linear function where you specify the (x, y) pairs
    in sorted order on x; x corresponds to the batch index. For batch-index values before
    the first x or after the last x, we just use the first or last y value.

    Example:
        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)

    `default` is used when self.batch_count is not set, when the module is not in
    training mode, or in torch.jit scripting or tracing mode.
    """

    def __init__(self, *args, default: float = 0.0):
        super().__init__()
        # self.batch_count and self.name will be written to in the training loop.
        self.batch_count = None
        self.name = None
        self.default = default
        self.schedule = PiecewiseLinear(*args)

    def extra_repr(self) -> str:
        return (
            f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}'
        )

    def __float__(self):
        batch_count = self.batch_count
        if (batch_count is None or not self.training
                or torch.jit.is_scripting() or torch.jit.is_tracing()):
            return float(self.default)
        else:
            ans = self.schedule(self.batch_count)
            if random.random() < 0.0002:
                logging.info(
                    f'ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}'
                )
            return ans

    def __add__(self, x):
        if isinstance(x, (float, int)):
            return ScheduledFloat(self.schedule + x, default=self.default)
        else:
            return ScheduledFloat(
                self.schedule + x.schedule, default=self.default + x.default)

    def max(self, x):
        if isinstance(x, (float, int)):
            return ScheduledFloat(self.schedule.max(x), default=self.default)
        else:
            return ScheduledFloat(
                self.schedule.max(x.schedule),
                default=max(self.default, x.default))
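
# Illustrative usage (a sketch added for exposition, not part of the upstream file):
# inside a model, a dropout probability that decays from 0.2 to 0.0 over the first
# 4000 batches. The training loop is assumed to write `batch_count` onto each
# ScheduledFloat module it finds via model.modules(), as the docstring above implies.
#
#   self.dropout_p = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
#   ...
#   p = float(self.dropout_p)   # evaluates the schedule at the current batch_count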


FloatLike = Union[float, ScheduledFloat]


class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: torch.Tensor):
        (ans,) = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None
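
# Illustrative usage (a sketch added for exposition, not part of the upstream file):
# like any torch.autograd.Function, it is invoked through .apply(); the softmax
# gradient is computed in float32 even when autocast produced float16 activations.
#
#   x = torch.randn(4, 8, requires_grad=True)
#   y = SoftmaxFunction.apply(x, -1)   # softmax over the last dimension
#   y.sum().backward()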


if __name__ == "__main__":
    pass
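
    # Quick sanity check (illustrative, added for exposition; not part of the
    # upstream file): exercise the scheduling and softmax helpers defined above.
    sched = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
    print(float(sched))  # prints the default (0.0) since batch_count is unset

    t = torch.randn(2, 5, requires_grad=True)
    out = SoftmaxFunction.apply(t, -1)
    out.sum().backward()
    print(t.grad.shape)  # torch.Size([2, 5])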