Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
# Load the trained A->B generator once at import time so every request reuses it.
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): `Generator` is neither defined nor imported anywhere in this
# file, so the next line raises NameError at startup (consistent with the
# Space's "Runtime error" status). Define or import the exact architecture that
# produced generator_A2B.pth before this point.
generator_A2B = Generator().to(device)
# The checkpoint is a state_dict, which needs only tensors and plain containers.
# weights_only=True therefore restricts unpickling to those types instead of
# running arbitrary pickled code. It is also the default from PyTorch 2.6 on;
# passing it explicitly keeps older releases (>= 1.13) consistent.
generator_A2B.load_state_dict(
    torch.load("generator_A2B.pth", map_location=device, weights_only=True)
)
# Put the module in evaluation mode (affects layers such as dropout or batch
# norm, if the architecture contains any) before serving requests.
generator_A2B.eval()
# Preprocessing pipeline applied to every uploaded image. It is built once at
# import time instead of on every request, since the transforms hold no
# per-call state.
_PREPROCESS = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # Map each channel from [0, 1] to [-1, 1]: (x - 0.5) / 0.5.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def transform_image(image):
    """Convert a PIL image into a batched input tensor for the generator.

    The image is resized to 256x256, converted to a CHW float tensor,
    normalized per channel to [-1, 1], given a leading batch dimension, and
    moved to the module-level ``device``.

    Args:
        image: A PIL image. It is expected to be RGB, because the normalization
            uses three-channel mean/std values.

    Returns:
        A tensor of shape (1, C, 256, 256) on ``device``.
    """
    return _PREPROCESS(image).unsqueeze(0).to(device)
def generate(image):
    """Translate an uploaded image with the A->B generator.

    Args:
        image: Filesystem path of the uploaded image, as supplied by
            ``gr.Image(type="filepath")``. It may be ``None`` or empty when
            nothing was uploaded.

    Returns:
        A NumPy float array in HWC layout with values in [0, 1], suitable for
        a ``gr.Image`` output.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of a traceback.
    """
    if not image:
        raise gr.Error("Please upload an image first.")

    # Open inside a context manager so the file handle is released promptly.
    # convert() returns a new in-memory image that stays valid after the
    # file is closed.
    with Image.open(image) as img:
        rgb = img.convert("RGB")

    input_tensor = transform_image(rgb)
    with torch.no_grad():
        output_tensor = generator_A2B(input_tensor)

    # Drop the batch dimension, clamp to [-1, 1] so any overshoot cannot yield
    # out-of-range pixels, then rescale from [-1, 1] to [0, 1].
    output = (output_tensor.squeeze(0).clamp(-1, 1) + 1) / 2
    # Reorder CHW -> HWC for display. There is deliberately no matplotlib
    # display here: plt.show() is useless inside a web server, and each
    # plt.imshow() call would leave an unclosed figure behind per request.
    return output.permute(1, 2, 0).cpu().numpy()
# Gradio UI: one image in, the translated image out.
demo = gr.Interface(
    fn=generate,
    # "filepath" makes Gradio pass generate() a path on disk.
    inputs=gr.Image(type="filepath"),
    outputs=gr.Image(),
    title="CycleGAN Image Translation",
    description="Upload an image and get the translated output from the CycleGAN model.",
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests or another script) builds `demo` without
# blocking on launch().
if __name__ == "__main__":
    demo.launch()