Spaces:
Running
Running
File size: 9,379 Bytes
8c3c188 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/scaling.py
"""
import logging
import random
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
def logaddexp_onnx(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    ONNX-exportable equivalent of ``torch.logaddexp(x, y)``.

    Computes ``log(exp(x) + exp(y))`` elementwise in the numerically stable
    form ``max(x, y) + log1p(exp(-|x - y|))``, using only ops that can be
    exported to ONNX.

    Args:
      x: first input tensor.
      y: second input tensor (broadcast against ``x``).

    Returns:
      A tensor with the broadcast shape of ``x`` and ``y``.

    Where ``x`` and ``y`` are the same infinity, ``x - y`` is NaN; those
    positions are treated as a zero difference so the result is that
    infinity (as ``torch.logaddexp`` returns) rather than NaN.
    """
    max_value = torch.max(x, y)
    diff = torch.abs(x - y)
    # (-inf) - (-inf) and inf - inf evaluate to NaN.  Wherever x == y the
    # true difference is 0, so substitute it explicitly; for finite equal
    # values this is a no-op.
    diff = torch.where(x == y, torch.zeros_like(diff), diff)
    return max_value + torch.log1p(torch.exp(-diff))
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
# 14 is not supported. Please feel free to request support or submit
# a pull request on PyTorch GitHub.
#
# The following function works around the above error when exporting
# models to ONNX (torch.onnx.export() traces the model, so this also
# concerns models exported via tracing).
def logaddexp(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute ``log(exp(x) + exp(y))`` elementwise, choosing an implementation
    that works in the current execution mode.

    - TorchScript (``torch.jit.is_scripting()``): ``torch.logaddexp``.
    - ONNX export (``torch.onnx.is_in_onnx_export()``): ``logaddexp_onnx``,
      which avoids the unsupported ONNX ``logaddexp`` operator.
    - Otherwise (eager mode, ``torch.jit.trace()``): ``torch.logaddexp``.

    Args:
      x: first input tensor.
      y: second input tensor.

    Returns:
      The elementwise log-sum-exp of ``x`` and ``y``.
    """
    # NOTE(review): presumably the scripting check must come first so the
    # ONNX-only branch is never taken under TorchScript -- keep this order.
    if torch.jit.is_scripting():
        # Note: We cannot use torch.jit.is_tracing() here as it also
        # matches torch.onnx.export().
        return torch.logaddexp(x, y)
    elif torch.onnx.is_in_onnx_export():
        return logaddexp_onnx(x, y)
    else:
        # for torch.jit.trace()
        return torch.logaddexp(x, y)
class PiecewiseLinear(object):
    """
    Piecewise linear function from float to float, specified as a nonempty
    list of (x, y) pairs whose x values are strictly increasing.  Inputs below
    the first x or above the last x are mapped to the first y and the last y,
    respectively.

    Instances support scaling (``f * alpha``), shifting/summing (``f + c``,
    ``f + g``) and pointwise ``max``/``min``; each returns a new object.

    Example:
        >>> f = PiecewiseLinear((0.0, 10.0), (100.0, 0.0))
        >>> f(50)
        5.0
    """

    def __init__(self, *args):
        """
        Args:
          *args: either a single PiecewiseLinear (which is copied) or one or
            more (x, y) pairs.  x and y are converted to float, and the x
            values must be strictly increasing.
        """
        assert len(args) >= 1, len(args)
        if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
            self.pairs = list(args[0].pairs)
        else:
            self.pairs = [(float(x), float(y)) for x, y in args]
        for x, y in self.pairs:
            assert isinstance(x, (float, int)), type(x)
            assert isinstance(y, (float, int)), type(y)
        for i, (cur, nxt) in enumerate(zip(self.pairs, self.pairs[1:])):
            assert nxt[0] > cur[0], (i, cur, nxt)

    def __str__(self):
        # e.g. 'PiecewiseLinear((0.0, 10.0), (100.0, 0.0))'
        return f'PiecewiseLinear({str(self.pairs)[1:-1]})'

    def __call__(self, x):
        """Evaluate the function at ``x``, clamping outside the breakpoints."""
        if x <= self.pairs[0][0]:
            return self.pairs[0][1]
        if x >= self.pairs[-1][0]:
            return self.pairs[-1][1]
        for (cur_x, cur_y), (next_x, next_y) in zip(self.pairs,
                                                    self.pairs[1:]):
            if cur_x <= x <= next_x:
                return cur_y + (next_y - cur_y) * (x - cur_x) / (
                    next_x - cur_x)
        # Only reachable for inputs that compare false against every
        # breakpoint (e.g. NaN).
        assert False, x

    def __mul__(self, alpha):
        """Return a copy with every y value multiplied by ``alpha``."""
        return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])

    def __add__(self, x):
        """Add a constant, or another PiecewiseLinear pointwise."""
        if isinstance(x, (float, int)):
            return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
        s, x = self.get_common_basis(x)
        return PiecewiseLinear(*[(sp[0], sp[1] + xp[1])
                                 for sp, xp in zip(s.pairs, x.pairs)])

    def max(self, x):
        """Pointwise maximum with a constant or another PiecewiseLinear."""
        if isinstance(x, (float, int)):
            x = PiecewiseLinear((0, x))
        s, x = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(*[(sp[0], max(sp[1], xp[1]))
                                 for sp, xp in zip(s.pairs, x.pairs)])

    def min(self, x):
        """Pointwise minimum with a constant or another PiecewiseLinear."""
        if isinstance(x, (float, int)):
            x = PiecewiseLinear((0, x))
        s, x = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(*[(sp[0], min(sp[1], xp[1]))
                                 for sp, xp in zip(s.pairs, x.pairs)])

    def __eq__(self, other):
        """
        Two functions are equal when their (x, y) pairs are identical.
        Comparison with any other type is delegated to Python (returns
        NotImplemented) instead of raising AttributeError.
        """
        if not isinstance(other, PiecewiseLinear):
            return NotImplemented
        return self.pairs == other.pairs

    def get_common_basis(self,
                         p: 'PiecewiseLinear',
                         include_crossings: bool = False):
        """
        Return (self_mod, p_mod): piecewise linear functions equivalent to
        ``self`` and ``p`` but defined on the same set of x values.

        Args:
          p: the other piecewise linear function.
          include_crossings: if true, also include the x positions where the
            functions represented by ``self`` and ``p`` cross, so that a
            pointwise max/min over the common basis is exact.
        """
        assert isinstance(p, PiecewiseLinear), type(p)
        # Sorted union of both functions' x values, without repetition.
        x_vals = sorted({x for x, _ in self.pairs} | {x for x, _ in p.pairs})
        y_vals1 = [self(x) for x in x_vals]
        y_vals2 = [p(x) for x in x_vals]

        if include_crossings:
            extra_x_vals = []
            for i in range(len(x_vals) - 1):
                # The two lines may cross inside this sub-segment when their
                # ordering differs at its two ends.
                if ((y_vals1[i] > y_vals2[i]) !=
                        (y_vals1[i + 1] > y_vals2[i + 1])):
                    diff_cur = abs(y_vals1[i] - y_vals2[i])
                    diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
                    # `pos`, between 0 and 1, gives the relative x position,
                    # with 0 being x_vals[i] and 1 being x_vals[i+1].  The
                    # sum is nonzero: differing comparisons imply at least
                    # one strictly positive gap.
                    pos = diff_cur / (diff_cur + diff_next)
                    extra_x_vals.append(
                        x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]))
            if extra_x_vals:
                x_vals = sorted(set(x_vals + extra_x_vals))
                y_vals1 = [self(x) for x in x_vals]
                y_vals2 = [p(x) for x in x_vals]

        return (
            PiecewiseLinear(*zip(x_vals, y_vals1)),
            PiecewiseLinear(*zip(x_vals, y_vals2)),
        )
class ScheduledFloat(torch.nn.Module):
    """
    This object is a torch.nn.Module only because we want it to show up in
    [top_level module].modules(); it does not have a working forward()
    function.  You are supposed to cast it to float, as in
    float(parent_module.whatever), and use it as something like a dropout
    prob.

    It is a floating point value whose value changes depending on the batch
    count of the training loop.  It is a piecewise linear function where you
    specify the (x, y) pairs in sorted order on x; x corresponds to the batch
    index.  For batch-index values before the first x or after the last x,
    the first or last y value is used.

    Example:
        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)

    `default` is used when self.batch_count is not set, when not in training
    mode, or under torch.jit scripting or tracing.
    """

    def __init__(self, *args, default: float = 0.0):
        super().__init__()
        # self.batch_count and self.name will be written to in the training loop.
        self.batch_count = None
        self.name = None
        self.default = default
        self.schedule = PiecewiseLinear(*args)

    def extra_repr(self) -> str:
        # Strip the brackets from the *string* form of the pairs; slicing the
        # list itself would drop the first and last breakpoints.
        return (
            f'batch_count={self.batch_count}, '
            f'schedule={str(self.schedule.pairs)[1:-1]}'
        )

    def __float__(self):
        batch_count = self.batch_count
        if (batch_count is None or not self.training
                or torch.jit.is_scripting() or torch.jit.is_tracing()):
            return float(self.default)
        ans = self.schedule(batch_count)
        # Occasionally (with probability 0.0002 per call) log the value.
        if random.random() < 0.0002:
            logging.info(
                f'ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}'
            )
        return ans

    def __add__(self, x):
        """
        Return a new ScheduledFloat: with a constant ``x``, both the schedule
        and the default are shifted by ``x``; with a ScheduledFloat, schedules
        and defaults are summed.
        """
        if isinstance(x, (float, int)):
            # Shift the default too, so the value used outside training
            # stays consistent with the shifted schedule.
            return ScheduledFloat(self.schedule + x, default=self.default + x)
        return ScheduledFloat(
            self.schedule + x.schedule, default=self.default + x.default)

    def max(self, x):
        """
        Return a new ScheduledFloat that is the pointwise max of this one and
        ``x`` (a constant or a ScheduledFloat); defaults are combined the
        same way.
        """
        if isinstance(x, (float, int)):
            return ScheduledFloat(self.schedule.max(x),
                                  default=max(self.default, x))
        return ScheduledFloat(
            self.schedule.max(x.schedule),
            default=max(self.default, x.default))


# Type alias for hyperparameters that may be either a plain float or a
# ScheduledFloat (the latter is meant to be read via float(...)).
FloatLike = Union[float, ScheduledFloat]
class SoftmaxFunction(torch.autograd.Function):
    """
    Softmax with a hand-written backward pass intended for mixed precision.

    Forward returns ``x.softmax(dim)``; when ``torch.is_autocast_enabled()``
    is true the result is cast to float16 before it is saved and returned.
    Backward upcasts the incoming gradient and the saved softmax output to
    float32 (with CUDA autocast disabled) and returns
    ``g * y - y * sum(g * y, dim)`` as the gradient for ``x``.

    NOTE(review): the forward cast always targets float16, even if autocast
    is configured for another dtype (e.g. bfloat16) -- confirm intended.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, dim: int):
        probs = x.softmax(dim=dim)
        # Presumably, under autocast softmax produces float32 even for
        # float16 input; bring the result back to half precision.
        if torch.is_autocast_enabled():
            probs = probs.to(torch.float16)
        ctx.save_for_backward(probs)
        ctx.x_dtype = x.dtype  # stored on ctx; backward does not read it
        ctx.dim = dim
        return probs

    @staticmethod
    def backward(ctx, ans_grad: torch.Tensor):
        saved_probs, = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            grad32 = ans_grad.to(torch.float32)
            probs32 = saved_probs.to(torch.float32)
            weighted = grad32 * probs32
            # No gradient for the integer `dim` argument.
            return (weighted - probs32 * weighted.sum(dim=ctx.dim, keepdim=True),
                    None)
if __name__ == "__main__":
    # No script behaviour is defined: running this module directly does nothing.
    pass
|