Spaces:
Running
Running
File size: 2,049 Bytes
59e7820 c10e08c 59e7820 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
from model import model
# Spatial size fed to the model, and size the reconstruction is shown at.
_MODEL_INPUT_SIZE = (128, 128)
_DISPLAY_SIZE = (512, 512)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model preprocessing: resize, convert the PIL image to a tensor, then
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 on every channel.
transform1 = transforms.Compose([
    transforms.Resize(_MODEL_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Display post-processing only: resize the decoded PIL image for the UI.
transform2 = transforms.Compose([transforms.Resize(_DISPLAY_SIZE)])
def load_image(image):
    """Turn a Gradio numpy image into a single-image batch for the model.

    The array is wrapped as a PIL image and forced to RGB. It is then passed
    through ``transform1`` (resize, to-tensor, normalize), given a leading
    batch axis of size 1, and moved to ``device``.

    Args:
        image: numpy array as delivered by ``gr.Image(type="numpy")``.

    Returns:
        A torch tensor with a leading batch dimension of 1, on ``device``.
    """
    rgb_picture = Image.fromarray(image).convert('RGB')
    tensor = transform1(rgb_picture)
    batch = tensor.unsqueeze(0)  # add the batch axis the model expects
    return batch.to(device)
def infer_image(image, noise_level):
    """Reconstruct *image* through the VAE with a noise-scaled latent sample.

    The image is encoded to ``(mu, logvar)``. A latent is then drawn as
    ``z = mu + noise_level * eps * exp(0.5 * logvar)`` with
    ``eps ~ N(0, 1)``, so ``noise_level == 0`` decodes ``mu`` exactly.
    The decoded tensor is mapped back to 8-bit pixels and resized by
    ``transform2`` for display.

    Args:
        image: numpy array from ``gr.Image(type="numpy")``, or ``None`` when
            no image is loaded (cleared input, or the slider was moved before
            any upload).
        noise_level: multiplier applied to the sampled latent noise.

    Returns:
        A ``uint8`` numpy array of the resized reconstruction, or ``None``
        when *image* is ``None``. ``None`` clears the output image in Gradio.
    """
    # Both .change listeners can fire with no image loaded; decoding is
    # impossible then, and Image.fromarray(None) would raise.
    if image is None:
        return None

    batch = load_image(image)
    # NOTE(review): this file never moves `model` to `device` or calls
    # model.eval(); confirm model.py does both, or CUDA runs will hit a
    # device mismatch.
    with torch.no_grad():
        mu, logvar = model.encode(batch)
        std = torch.exp(0.5 * logvar)
        # Reparameterisation-style sample with slider-scaled noise.
        eps = torch.randn_like(std) * noise_level
        z = mu + eps * std
        decoded = model.decode(z)

    # Drop only the batch axis (a bare squeeze() could also drop a size-1
    # channel/spatial axis), then go channels-last for PIL.
    pixels = decoded.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.float32)
    # Invert transform1's Normalize((0.5,), (0.5,)), then clamp to [0, 1].
    # Assumes the decoder output lives in that normalized space - TODO confirm.
    pixels = np.clip(pixels * 0.5 + 0.5, 0, 1)
    picture = Image.fromarray((pixels * 255).astype(np.uint8))
    picture = transform2(picture)
    return np.array(picture)
# One-click examples for the UI: each row is [image path, noise level],
# in the same order as the inputs passed to gr.Examples.
examples = [
    [path, level]
    for path, level in (
        ("example_images/image5.png", 0.98),
        ("example_images/image1.jpg", 0.1),
        ("example_images/image2.png", 0.5),
        ("example_images/image3.jpg", 1.0),
    )
]
# Gradio UI: a noise slider above a row with the input and the reconstruction.
with gr.Blocks() as vae:
    noise_slider = gr.Slider(0, 10, value=0.01, step=0.01, label="Noise Level")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload an image", type="numpy")
        with gr.Column():
            output_image = gr.Image(label="Reconstructed Image")

    # Re-run inference whenever the image or the noise level changes;
    # listeners are registered image-first, then slider.
    for trigger in (input_image, noise_slider):
        trigger.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)

    gr.Examples(examples=examples, inputs=[input_image, noise_slider])
# NOTE(review): no vae.launch() is visible in this view; confirm the app is
# started elsewhere.
|