File size: 1,415 Bytes
e53edb8 75bf717 e53edb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import streamlit as st
import torch
from normflows import nflow
import numpy as np
import seaborn as sns
import pandas as pd
# Dataset the flow is trained on; the app only trains/generates once a file
# has been provided (see the `uploaded_file is not None` check below).
uploaded_file = st.file_uploader(label="Choose original dataset")
# Scale value forwarded as `bw=` to api.compile() inside compute().
bw = st.number_input(label='Scale', value=3.05)
def compute(iters=10000, n_samples=1000):
    """Train a normalizing flow on the uploaded dataset and return synthetic rows.

    Builds an ``nflow`` model on the module-level ``uploaded_file`` with the
    module-level ``bw`` scale. It trains the model while updating a Streamlit
    progress bar. It then renders a diagnostic KDE/scatter plot of the first
    two features. Finally it draws ``n_samples`` new rows from the trained
    model.

    Args:
        iters: Number of training iterations passed to ``api.train``
            (defaults to the previously hard-coded 10000).
        n_samples: Number of synthetic rows to generate
            (defaults to the previously hard-coded 1000).

    Returns:
        The generated samples mapped back to the original feature scale via
        ``api.scaler.inverse_transform``.
    """
    # NOTE(review): Streamlit re-executes the script on widget interaction, so
    # this retrains from scratch on every rerun -- consider caching the result.
    api = nflow(dim=8, latent=16, dataset=uploaded_file)
    api.compile(optim=torch.optim.ASGD, bw=bw, lr=0.0001, wd=None)

    # Training loop: each update's [0] is used as the step counter and [1] as
    # the status text. Clamp because st.progress only accepts values in [0, 1].
    my_bar = st.progress(0, text='Currently in progress')
    for update in api.train(iters=iters):
        my_bar.progress(min(update[0] / iters, 1.0), text=str(update[1]))

    # Diagnostic plot over the first two features: KDE of model output, with the
    # scaled real data overlaid as a scatter.
    # NOTE(review): this feeds the scaled *training data* to model.sample, while
    # the export below feeds N(0, I) noise -- confirm which input is intended.
    plot_samples = (
        api.model.sample(torch.tensor(api.scaled).float()).detach().cpu().numpy()
    )
    g = sns.jointplot(
        x=plot_samples[:, 0],
        y=plot_samples[:, 1],
        kind='kde',
        cmap=sns.color_palette("Blues", as_cmap=True),
        fill=True,
        label='Gaussian KDE',
        levels=50,
    )
    w = sns.scatterplot(
        x=api.scaled[:, 0],
        y=api.scaled[:, 1],
        ax=g.ax_joint,
        c='orange',
        marker='+',
        s=100,
        label='Real',
    )
    st.pyplot(w.get_figure())

    # Generate fresh rows by pushing standard-normal noise (one column per data
    # feature) through the trained model. torch.randn already yields a tensor,
    # so no torch.tensor() re-wrap (which only copies and warns) is needed.
    noise = torch.randn(n_samples, api.scaled.shape[-1], dtype=torch.float32)
    # .cpu() so the conversion also works if the model lives on a GPU.
    samples = api.model.sample(noise).detach().cpu().numpy()
    return api.scaler.inverse_transform(samples)
# Run training/generation only once a dataset has been uploaded, then offer the
# generated rows as a CSV download.
if uploaded_file is not None:
    samples = compute()
    # file_name/mime are passed by keyword: the third positional parameter of
    # st.download_button is file_name, so passing 'text/csv' positionally made
    # it the download's file name instead of its MIME type.
    st.download_button(
        'Download generated CSV',
        pd.DataFrame(samples).to_csv(),
        file_name='generated.csv',
        mime='text/csv',
    )