File size: 2,049 Bytes
59e7820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c10e08c
59e7820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
from model import model

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model-side preprocessing: resize to the 128x128 input the model is fed,
# convert to a [0, 1] float tensor, then shift/scale to [-1, 1] via
# (x - 0.5) / 0.5. The single-element mean/std broadcast across channels.
transform1 = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Display-side postprocessing: upscale the reconstruction to 512x512.
transform2 = transforms.Compose([transforms.Resize((512, 512))])

def load_image(image):
    """Convert a NumPy image into a normalised, batched tensor on ``device``.

    The array is wrapped in a PIL image and forced to three RGB channels.
    It is then passed through ``transform1`` (128x128 resize, ToTensor,
    Normalize) and given a leading batch axis of size 1. The result is moved
    to ``device``.
    """
    rgb = Image.fromarray(image).convert('RGB')
    tensor = transform1(rgb)
    batch = tensor.unsqueeze(0)
    return batch.to(device)

def infer_image(image, noise_level):
    """Encode *image* with the VAE, perturb the latent, and decode it for display.

    Args:
        image: Image as a NumPy array (as delivered by
            ``gr.Image(type="numpy")``), or ``None`` when the input
            component is empty or has just been cleared.
        noise_level: Multiplier on the standard-normal noise before it is
            scaled by the posterior standard deviation. ``0`` decodes the
            posterior mean ``mu`` directly.

    Returns:
        The reconstruction resized to 512x512 as a ``uint8`` NumPy array, or
        ``None`` when no image was provided (which clears the output image).
    """
    # Gradio fires ``change`` with ``None`` when the image is cleared, and the
    # noise slider can be moved before anything is uploaded. Without this
    # guard ``Image.fromarray(None)`` raises and the UI shows an error.
    if image is None:
        return None

    batch = load_image(image)
    # NOTE(review): assumes ``model`` is already in eval mode (set in
    # model.py) — confirm, otherwise dropout/batch-norm behave as in training.
    with torch.no_grad():
        mu, logvar = model.encode(batch)
        std = torch.exp(0.5 * logvar)
        # Reparameterisation with a tunable noise scale:
        # z = mu + noise_level * eps * std, eps ~ N(0, I).
        eps = torch.randn_like(std) * noise_level
        z = mu + eps * std
        decoded = model.decode(z)

    # Take the single batch element explicitly (``squeeze()`` would also drop
    # any other size-1 axis), move channels last, and undo
    # Normalize((0.5,), (0.5,)) to get back to [0, 1].
    # Assumes decoded is (1, C, H, W) — TODO confirm against model.decode.
    array = decoded[0].permute(1, 2, 0).cpu().numpy().astype(np.float32) * 0.5 + 0.5
    array = np.clip(array, 0.0, 1.0)

    # Round rather than truncate so pixel values are not biased downward.
    pil_image = Image.fromarray(np.round(array * 255).astype(np.uint8))
    return np.array(transform2(pil_image))

# Example rows for ``gr.Examples``: [image path, noise level], matching the
# order of ``inputs=[input_image, noise_slider]`` below.
examples = [
    ["example_images/image5.png", 0.98],
    ["example_images/image1.jpg", 0.1],
    ["example_images/image2.png", 0.5],
    ["example_images/image3.jpg", 1.0],
]

with gr.Blocks() as vae:
    noise_slider = gr.Slider(0, 10, value=0.01, step=0.01, label="Noise Level")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload an image", type="numpy")
        with gr.Column():
            output_image = gr.Image(label="Reconstructed Image")

    # Re-run the reconstruction whenever either the image or the noise level
    # changes; both events feed the same inputs to ``infer_image``.
    input_image.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)
    noise_slider.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)

    # Selecting an example fills the inputs, which in turn triggers the
    # ``change`` handlers above.
    gr.Examples(examples=examples, inputs=[input_image, noise_slider])


# The Blocks app was previously built but never started, so running this file
# as a script exited immediately. Launch it only when run directly so that
# importing the module stays side-effect free.
if __name__ == "__main__":
    vae.launch()