Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,6 @@ import numpy as np
|
|
8 |
import requests
|
9 |
from io import BytesIO
|
10 |
|
11 |
-
# Download the model from Hugging Face Hub
|
12 |
repo_id = "Hammad712/GAN-Colorization-Model"
|
13 |
model_filename = "generator.pt"
|
14 |
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
|
@@ -29,7 +28,6 @@ G_net = build_generator(n_input=1, n_output=2, size=256)
|
|
29 |
G_net.load_state_dict(torch.load(model_path, map_location=device))
|
30 |
G_net.eval()
|
31 |
|
32 |
-
# Preprocessing function
|
33 |
def preprocess_image(img):
|
34 |
img = img.convert("RGB")
|
35 |
img = transforms.Resize((256, 256), Image.BICUBIC)(img)
|
@@ -39,7 +37,6 @@ def preprocess_image(img):
|
|
39 |
L = img_to_lab[[0], ...] / 50. - 1.
|
40 |
return L.unsqueeze(0).to(device)
|
41 |
|
42 |
-
# Inference function
|
43 |
def colorize_image(img, model):
|
44 |
L = preprocess_image(img)
|
45 |
with torch.no_grad():
|
@@ -53,14 +50,12 @@ def colorize_image(img, model):
|
|
53 |
rgb_imgs.append(img_rgb)
|
54 |
return np.stack(rgb_imgs, axis=0)
|
55 |
|
56 |
-
# Gradio interface
|
57 |
def colorize(img):
|
58 |
colorized_images = colorize_image(img, G_net)
|
59 |
colorized_image = colorized_images[0]
|
60 |
return Image.fromarray((colorized_image * 255).astype(np.uint8))
|
61 |
|
62 |
-
|
63 |
-
iface = gr.Interface(
|
64 |
fn=colorize,
|
65 |
inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
|
66 |
outputs=gr.Image(type="pil", label="Colorized Image"),
|
@@ -69,5 +64,4 @@ iface = gr.Interface(
|
|
69 |
allow_flagging="never"
|
70 |
)
|
71 |
|
72 |
-
|
73 |
-
iface.launch()
|
|
|
8 |
import requests
|
9 |
from io import BytesIO
|
10 |
|
|
|
11 |
repo_id = "Hammad712/GAN-Colorization-Model"
|
12 |
model_filename = "generator.pt"
|
13 |
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
|
|
|
28 |
G_net.load_state_dict(torch.load(model_path, map_location=device))
|
29 |
G_net.eval()
|
30 |
|
|
|
31 |
def preprocess_image(img):
|
32 |
img = img.convert("RGB")
|
33 |
img = transforms.Resize((256, 256), Image.BICUBIC)(img)
|
|
|
37 |
L = img_to_lab[[0], ...] / 50. - 1.
|
38 |
return L.unsqueeze(0).to(device)
|
39 |
|
|
|
40 |
def colorize_image(img, model):
|
41 |
L = preprocess_image(img)
|
42 |
with torch.no_grad():
|
|
|
50 |
rgb_imgs.append(img_rgb)
|
51 |
return np.stack(rgb_imgs, axis=0)
|
52 |
|
|
|
53 |
def colorize(img):
|
54 |
colorized_images = colorize_image(img, G_net)
|
55 |
colorized_image = colorized_images[0]
|
56 |
return Image.fromarray((colorized_image * 255).astype(np.uint8))
|
57 |
|
58 |
+
app = gr.Interface(
|
|
|
59 |
fn=colorize,
|
60 |
inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
|
61 |
outputs=gr.Image(type="pil", label="Colorized Image"),
|
|
|
64 |
allow_flagging="never"
|
65 |
)
|
66 |
|
67 |
+
app.launch()
|
|