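"""Registry of conditioning methods for measurement-guided diffusion sampling.

Each method nudges the intermediate sample ``x_t`` toward consistency with a
measurement ``y = A(x_0) + noise``, given a forward operator (``operator``)
and a noise model (``noiser``). Methods register themselves by name and are
retrieved through ``get_conditioning_method``.
"""
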
from abc import ABC, abstractmethod
import torch
__CONDITIONING_METHOD__ = {}

def register_conditioning_method(name: str):
    def wrapper(cls):
        if __CONDITIONING_METHOD__.get(name, None):
            raise NameError(f"Name {name} is already registered!")
        __CONDITIONING_METHOD__[name] = cls
        return cls
    return wrapper

def get_conditioning_method(name: str, operator, noiser, **kwargs):
    if __CONDITIONING_METHOD__.get(name, None) is None:
        raise NameError(f"Name {name} is not defined!")
    return __CONDITIONING_METHOD__[name](operator=operator, noiser=noiser, **kwargs)
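
# Minimal usage sketch (illustrative only): `operator` and `noiser` stand in
# for whatever forward model and noise model the surrounding codebase
# provides; only the registry lookup itself is defined in this file.
#
#   cond_method = get_conditioning_method('ps', operator, noiser, scale=0.3)
#   x_t, norm = cond_method.conditioning(
#       x_prev=x_prev, x_t=x_t, x_0_hat=x_0_hat, measurement=y)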

class ConditioningMethod(ABC):
    def __init__(self, operator, noiser, **kwargs):
        self.operator = operator
        self.noiser = noiser

    def project(self, data, noisy_measurement, **kwargs):
        # Project `data` onto the set consistent with the noisy measurement.
        return self.operator.project(data=data, measurement=noisy_measurement, **kwargs)

    def grad_and_value(self, x_prev, x_0_hat, measurement, **kwargs):
        # Measurement-consistency residual ||y - A(x_0_hat)|| and its
        # gradient with respect to x_prev (through the denoised estimate).
        if self.noiser.__name__ == 'gaussian':
            difference = measurement - self.operator.forward(x_0_hat, **kwargs)
            norm = torch.linalg.norm(difference)
            norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
        elif self.noiser.__name__ == 'poisson':
            Ax = self.operator.forward(x_0_hat, **kwargs)
            difference = measurement - Ax
            # Normalize the residual norm by the elementwise measurement
            # magnitude and average (Poisson noise scales with intensity).
            norm = torch.linalg.norm(difference) / measurement.abs()
            norm = norm.mean()
            norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
        else:
            raise NotImplementedError

        return norm_grad, norm

    @abstractmethod
    def conditioning(self, x_t, measurement, noisy_measurement=None, **kwargs):
        pass

@register_conditioning_method(name='vanilla')
class Identity(ConditioningMethod):
    # Just pass the input through without conditioning. Accepts **kwargs so
    # the signature tolerates the extra arguments other methods receive.
    def conditioning(self, x_t, **kwargs):
        return x_t

@register_conditioning_method(name='projection')
class Projection(ConditioningMethod):
    def conditioning(self, x_t, noisy_measurement, **kwargs):
        x_t = self.project(data=x_t, noisy_measurement=noisy_measurement)
        return x_t

@register_conditioning_method(name='mcg')
class ManifoldConstraintGradient(ConditioningMethod):
    def __init__(self, operator, noiser, **kwargs):
        super().__init__(operator, noiser)
        self.scale = kwargs.get('scale', 1.0)

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, noisy_measurement, **kwargs):
        # Gradient step toward measurement consistency (posterior sampling),
        # followed by a hard projection onto the noisy measurement.
        norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
        x_t -= norm_grad * self.scale
        x_t = self.project(data=x_t, noisy_measurement=noisy_measurement, **kwargs)
        return x_t, norm

@register_conditioning_method(name='ps')
class PosteriorSampling(ConditioningMethod):
    def __init__(self, operator, noiser, **kwargs):
        super().__init__(operator, noiser)
        self.scale = kwargs.get('scale', 1.0)

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
        norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
        x_t -= norm_grad * self.scale
        return x_t, norm

@register_conditioning_method(name='ps+')
class PosteriorSamplingPlus(ConditioningMethod):
    def __init__(self, operator, noiser, **kwargs):
        super().__init__(operator, noiser)
        self.num_sampling = kwargs.get('num_sampling', 5)
        self.scale = kwargs.get('scale', 1.0)

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
        # Average the residual norm over several randomly perturbed copies of
        # x_0_hat to smooth the measurement-consistency gradient.
        norm = 0
        for _ in range(self.num_sampling):
            # TODO: use noiser?
            x_0_hat_noise = x_0_hat + 0.05 * torch.rand_like(x_0_hat)
            difference = measurement - self.operator.forward(x_0_hat_noise)
            norm += torch.linalg.norm(difference) / self.num_sampling

        norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
        x_t -= norm_grad * self.scale
        return x_t, norm
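
# A hedged smoke test with toy stand-ins. `DummyOperator` and `DummyNoiser`
# are illustrative assumptions, not part of this module's real interface:
# the real operator and noiser come from elsewhere in the codebase.
if __name__ == "__main__":
    class DummyOperator:
        def forward(self, x, **kwargs):
            return x  # identity forward model A(x) = x

        def project(self, data, measurement, **kwargs):
            return measurement

    class DummyNoiser:
        __name__ = 'gaussian'  # grad_and_value branches on this attribute

    y = torch.randn(1, 3, 8, 8)
    x_prev = torch.randn(1, 3, 8, 8, requires_grad=True)
    x_0_hat = 0.9 * x_prev  # stand-in denoised estimate that depends on x_prev
    x_t = torch.randn(1, 3, 8, 8)

    method = get_conditioning_method('ps', DummyOperator(), DummyNoiser(), scale=0.3)
    x_t, norm = method.conditioning(x_prev=x_prev, x_t=x_t, x_0_hat=x_0_hat, measurement=y)
    print(f"residual norm: {norm.item():.4f}")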