ItsJATAYU commited on
Commit
8ec48a4
·
verified ·
1 Parent(s): e5285b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -8
app.py CHANGED
@@ -8,7 +8,6 @@ import numpy as np
8
  import requests
9
  from io import BytesIO
10
 
11
- # Download the model from Hugging Face Hub
12
  repo_id = "Hammad712/GAN-Colorization-Model"
13
  model_filename = "generator.pt"
14
  model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
@@ -29,7 +28,6 @@ G_net = build_generator(n_input=1, n_output=2, size=256)
29
  G_net.load_state_dict(torch.load(model_path, map_location=device))
30
  G_net.eval()
31
 
32
- # Preprocessing function
33
  def preprocess_image(img):
34
  img = img.convert("RGB")
35
  img = transforms.Resize((256, 256), Image.BICUBIC)(img)
@@ -39,7 +37,6 @@ def preprocess_image(img):
39
  L = img_to_lab[[0], ...] / 50. - 1.
40
  return L.unsqueeze(0).to(device)
41
 
42
- # Inference function
43
  def colorize_image(img, model):
44
  L = preprocess_image(img)
45
  with torch.no_grad():
@@ -53,14 +50,12 @@ def colorize_image(img, model):
53
  rgb_imgs.append(img_rgb)
54
  return np.stack(rgb_imgs, axis=0)
55
 
56
- # Gradio interface
57
  def colorize(img):
58
  colorized_images = colorize_image(img, G_net)
59
  colorized_image = colorized_images[0]
60
  return Image.fromarray((colorized_image * 255).astype(np.uint8))
61
 
62
- # Create the Gradio interface
63
- iface = gr.Interface(
64
  fn=colorize,
65
  inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
66
  outputs=gr.Image(type="pil", label="Colorized Image"),
@@ -69,5 +64,4 @@ iface = gr.Interface(
69
  allow_flagging="never"
70
  )
71
 
72
- # Launch the interface
73
- iface.launch()
 
8
  import requests
9
  from io import BytesIO
10
 
 
11
  repo_id = "Hammad712/GAN-Colorization-Model"
12
  model_filename = "generator.pt"
13
  model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
 
28
  G_net.load_state_dict(torch.load(model_path, map_location=device))
29
  G_net.eval()
30
 
 
31
  def preprocess_image(img):
32
  img = img.convert("RGB")
33
  img = transforms.Resize((256, 256), Image.BICUBIC)(img)
 
37
  L = img_to_lab[[0], ...] / 50. - 1.
38
  return L.unsqueeze(0).to(device)
39
 
 
40
  def colorize_image(img, model):
41
  L = preprocess_image(img)
42
  with torch.no_grad():
 
50
  rgb_imgs.append(img_rgb)
51
  return np.stack(rgb_imgs, axis=0)
52
 
 
53
  def colorize(img):
54
  colorized_images = colorize_image(img, G_net)
55
  colorized_image = colorized_images[0]
56
  return Image.fromarray((colorized_image * 255).astype(np.uint8))
57
 
58
+ app = gr.Interface(
 
59
  fn=colorize,
60
  inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
61
  outputs=gr.Image(type="pil", label="Colorized Image"),
 
64
  allow_flagging="never"
65
  )
66
 
67
+ app.launch()