import gradio as gr
import matplotlib.pyplot as plt
import torch
import seaborn as sns
import pandas as pd
import os
import os.path as osp
import ffmpeg
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from torch.utils.data import Dataset, DataLoader
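# Demo setup: imbalanced linear regression on synthetic data.
# Ground truth is y = K * x + B with Gaussian label noise of std NOISE_SIGMA.
# The label range [Y_LB, Y_UB) is split into NUM_SEG buckets, each holding at most
# NUM_PER_BUCKET training samples; the sliders control how full each bucket is.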
NUM_PER_BUCKET = 1000
NOISE_SIGMA = 1
Y_UB = 10
Y_LB = 0
K = 1
B = 0
NUM_SEG = 5
NUM_EPOCHS = 100
PRINT_FREQ = NUM_EPOCHS // 20
NUM_TRAIN_SAMPLES = NUM_PER_BUCKET * NUM_SEG
BATCH_SIZE = 256
def make_dataframe(x, y, method=None):
    x = list(x[:, 0].detach().numpy())
    y = list(y[:, 0].detach().numpy())
    if method is not None:
        method = [method for _ in range(len(x))]
        df = pd.DataFrame({'x': x, 'y': y, 'Method': method})
    else:
        df = pd.DataFrame({'x': x, 'y': y})
    return df
Y_demo = torch.linspace(Y_LB, Y_UB, 2).unsqueeze(-1)
X_demo = (Y_demo - B) / K
df_oracle = make_dataframe(X_demo, Y_demo, 'Oracle')
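# X_demo, Y_demo and df_oracle above define the ground-truth line y = K * x + B,
# drawn as the 'Oracle' curve in every plot.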
def prepare_data(sel_num):
    interval = (Y_UB - Y_LB) / NUM_SEG
    all_x, all_y = [], []
    prob = []
    for i in range(NUM_SEG):
        # Sample labels uniformly within the i-th bucket and keep sel_num[i] of them.
        uniform_y_distribution = torch.distributions.Uniform(Y_UB - (i + 1) * interval, Y_UB - i * interval)
        y_uniform = uniform_y_distribution.sample((NUM_TRAIN_SAMPLES, 1))[:sel_num[i]]
        noise_distribution = torch.distributions.Normal(loc=0, scale=NOISE_SIGMA)
        noise = noise_distribution.sample((NUM_TRAIN_SAMPLES, 1))[:sel_num[i]]
        y_uniform_oracle = y_uniform - noise
        x_uniform = (y_uniform_oracle - B) / K
        all_x += x_uniform
        all_y += y_uniform
        # prob stores, for every sample, the size of its label bucket; ReweightL2 uses it.
        prob += [torch.tensor(sel_num[i]).float() for _ in range(sel_num[i])]
    all_x = torch.stack(all_x)
    all_y = torch.stack(all_y)
    prob = torch.stack(prob)
    return all_x, all_y, prob
def unzip_dataloader(training_loader):
    all_x = []
    all_y = []
    for data, label, _ in training_loader:
        all_x.append(data)
        all_y.append(label)
    all_x = torch.cat(all_x)
    all_y = torch.cat(all_y)
    return all_x, all_y
def train(train_loader, training_df, training_bundle, num_epochs):
    visualize_training_process(training_df, training_bundle, -1)
    for epoch in range(num_epochs):
        for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
            model.train()
            for data, target, prob in train_loader:
                optimizer.zero_grad()
                pred = model(data)
                if criterion_name == 'Reweight':
                    loss = criterion(pred, target, prob)
                else:
                    loss = criterion(pred, target)
                loss.backward()
                optimizer.step()
            scheduler.step()
        if (epoch + 1) % PRINT_FREQ == 0:
            visualize_training_process(training_df, training_bundle, epoch)
    visualize_training_process(training_df, training_bundle, num_epochs - 1, final=True)
def visualize_training_process(training_df, training_bundle, epoch, final=False):
    df = df_oracle
    for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
        model.eval()
        y = model(X_demo)
        # DataFrame.append was removed in pandas 2.x; build the frame with pd.concat instead.
        df = pd.concat([df, make_dataframe(X_demo, y, criterion_name)], ignore_index=True)
    visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True, epoch=epoch)
    if final:
        visualize(training_df, df, 'regression_result.png', fast=False)
def make_video():
    (
        ffmpeg
        .input('train_log/*.png', pattern_type='glob', framerate=3)
        .output('movie.mp4')
        .run()
    )
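# Reweighted MSE baseline: each sample's squared error is weighted by the inverse
# (or inverse square root) of its bucket size, with the weights normalized to sum to 1.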
class ReweightL2(_Loss):
    def __init__(self, reweight='inverse'):
        super(ReweightL2, self).__init__()
        self.reweight = reweight

    def forward(self, pred, target, prob):
        reweight = self.reweight
        if reweight == 'inverse':
            inv_prob = prob.pow(-1)
        elif reweight == 'sqrt_inv':
            inv_prob = prob.pow(-0.5)
        else:
            raise NotImplementedError
        inv_prob = inv_prob / inv_prob.sum()
        loss = F.mse_loss(pred, target, reduction='none').sum(-1) * inv_prob
        loss = loss.sum()
        return loss
class LinearModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearModel, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, output_dim),
        )

    def forward(self, x):
        x = self.mlp(x)
        return x
def prepare_model():
    model = LinearModel(input_dim=1, output_dim=1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
    return model, optimizer, scheduler
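# Balanced MSE (batch-based Monte-Carlo, BMC): treat the targets in the batch as
# samples of the label distribution. For each prediction, compute the logits
# -||pred_i - target_j||^2 / (2 * sigma^2) against every target in the batch and apply
# cross-entropy with the matching target (the diagonal) as the correct class.
# The final scaling by 2 * sigma^2 keeps the loss on a scale comparable to plain MSE.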
class BMCLoss(_Loss):
    def __init__(self):
        super(BMCLoss, self).__init__()
        self.noise_sigma = NOISE_SIGMA

    def forward(self, pred, target):
        pred = pred.reshape(-1, 1)
        target = target.reshape(-1, 1)
        noise_var = self.noise_sigma ** 2
        loss = bmc_loss(pred, target, noise_var)
        return loss


def bmc_loss(pred, target, noise_var):
    logits = -0.5 * (pred - target.T).pow(2) / noise_var
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))
    return loss * (2 * noise_var)
def regress(train_loader, training_df):
    training_bundle = []
    criterions = {
        'MSE': torch.nn.MSELoss(),
        'Reweight': ReweightL2(),
        'Balanced MSE': BMCLoss(),
    }
    for criterion_name in criterions:
        criterion = criterions[criterion_name]
        model, optimizer, scheduler = prepare_model()
        training_bundle.append((model, optimizer, scheduler, criterion, criterion_name))
    train(train_loader, training_df, training_bundle, NUM_EPOCHS)
class DummyDataset(Dataset):
    def __init__(self, inputs, targets, prob):
        self.inputs = inputs
        self.targets = targets
        self.prob = prob

    def __getitem__(self, index):
        return self.inputs[index], self.targets[index], self.prob[index]

    def __len__(self):
        return len(self.inputs)
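# visualize() has two modes: fast=True draws only the current regression lines
# (used for the per-epoch video frames), while fast=False additionally scatters the
# training data with a marginal label histogram for the final result image.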
def visualize(training_df, df, save_path, fast=False, epoch=None):
    if fast:
        f = plt.figure(figsize=(3, 3))
        g = f.add_subplot(111)
        g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g, estimator=None, ci=None)
        plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
        plt.ylim(Y_LB, Y_UB)
    else:
        g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=50,
                          marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
                          xlim=((Y_LB - B) / K, (Y_UB - B) / K),
                          ylim=(Y_LB, Y_UB),
                          space=0.1,
                          height=5,
                          ratio=2,
                          estimator=None, ci=None,
                          legend=False,
                          )
        g.ax_marg_x.remove()
        g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g.ax_joint, estimator=None, ci=None)
    if epoch is not None:
        g_line.legend(loc='upper left', title="Epoch {:03d}".format(epoch + 1))
    else:
        g_line.legend(loc='upper left')
    plt.gca().axes.set_xlabel(r'$x$')
    plt.gca().axes.set_ylabel(r'$y$')
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
def clean_up_logs():
    if not osp.exists('train_log'):
        os.mkdir('train_log')
    for f in os.listdir('train_log'):
        os.remove(osp.join('train_log', f))
    for f in ['regression_result.png', 'training_data.png', 'movie.mp4']:
        if osp.isfile(f):
            os.remove(f)
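# run() is the Gradio callback. The sliders give per-bucket percentages, which are
# converted to sample counts out of NUM_PER_BUCKET. Mode 0 ('Prepare Training Data')
# only plots the training set; mode 1 ('Start Regressing!') trains all three losses
# and renders the training process as a video.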
def run(num1, num2, num3, num4, num5, random_seed, mode):
    sel_num = [num1, num2, num3, num4, num5]
    sel_num = [int(num / 100 * NUM_PER_BUCKET) for num in sel_num]
    torch.manual_seed(int(random_seed))
    all_x, all_y, prob = prepare_data(sel_num)
    train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
    training_df = make_dataframe(all_x, all_y)
    clean_up_logs()
    if mode == 0:
        visualize(training_df, df_oracle, 'training_data.png')
    if mode == 1:
        regress(train_loader, training_df)
        make_video()
    if mode == 0:
        text = "Press \"Start Regressing\" if you are happy with the training data. Regression takes ~30s."
    else:
        text = "Press \"Prepare Training Data\" before moving the sliders. You may also change the random seed."
    training_data_plot = 'training_data.png' if mode == 0 else None
    output = 'regression_result.png' if mode == 1 else None
    video = 'movie.mp4' if mode == 1 else None
    return training_data_plot, output, video, text
if __name__ == '__main__':
    # Note: this uses the legacy gr.inputs / gr.outputs component API from older Gradio releases.
    iface = gr.Interface(
        fn=run,
        inputs=[
            gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [8, 10)'),
            gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [6, 8)'),
            gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [4, 6)'),
            gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [2, 4)'),
            gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [0, 2)'),
            gr.inputs.Number(default=0, label='Random Seed', optional=False),
            gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
                            type="index", default=None, label='Mode', optional=False),
        ],
        outputs=[
            gr.outputs.Image(type="file", label="Training data"),
            gr.outputs.Image(type="file", label="Regression result"),
            gr.outputs.Video(type='mp4', label='Training process'),
            gr.outputs.Textbox(type="auto", label="What's next?"),
        ],
        live=True,
        allow_flagging='never',
        title="Balanced MSE for Imbalanced Visual Regression [CVPR 2022]",
        description="Welcome to the demo of Balanced MSE ⚖. In this demo, we will work on a simple task: imbalanced <i>linear</i> regression. <br>"
                    "To get started, move the sliders 🎚 to create your training data "
                    "or click the examples 📕 at the bottom of the page 👇👇",
        examples=[
            [0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
            [1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
        ],
        css=".output-image, .image-preview {height: 500px !important}",
        article="<p style='text-align: center'><a href='https://github.com/jiawei-ren/BalancedMSE' target='_blank'>Balanced MSE @ GitHub</a></p>",
    )
    iface.launch()