HAMIM-ML committed
Commit dd5df30 · 1 Parent(s): 9bd2271

prediction added

app.py CHANGED
@@ -0,0 +1,35 @@
+ import streamlit as st
+ from src.imagecolorization.pipeline.prediction import ImageColorizationSystem
+ from PIL import Image
+ from io import BytesIO
+
+
+
+ # Streamlit app
+ st.title("Image Colorization App")
+ st.write("Upload a black-and-white image, and this app will colorize it.")
+
+ # Load the model
+ colorization_system = ImageColorizationSystem("C:\\mlops project\\image-colorization-mlops\\artifacts\\trained_model\\cwgan_generator_final.pt", "C:\\mlops project\\image-colorization-mlops\\artifacts\\trained_model\\cwgan_critic_final.pt")
+
+ # Upload image
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file is not None:
+     # Load and display the image
+     image = Image.open(uploaded_file)
+     st.image(image, caption='Uploaded Image', use_column_width=True)
+
+     # Convert image to grayscale and colorize it
+     grayscale_image = colorization_system.load_image(image)
+     colorized_image = colorization_system.colorize(grayscale_image)
+
+     # Convert to Image and display
+     colorized_image_pil = Image.fromarray((colorized_image * 255).astype('uint8'))
+     st.image(colorized_image_pil, caption='Colorized Image', use_column_width=True)
+
+     # Option to download the colorized image
+     buf = BytesIO()
+     colorized_image_pil.save(buf, format="PNG")
+     byte_im = buf.getvalue()
+     st.download_button("Download Colorized Image", byte_im, file_name="colorized_image.png")
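With Streamlit installed and the hard-coded checkpoint paths present on the machine, the app added here can be started from the project root with "streamlit run app.py".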
src/imagecolorization/pipeline/prediction.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ from torch import nn, optim
+ from torchvision import transforms
+ from torch.utils.data import Dataset, DataLoader
+ from torch.autograd import Variable
+ from torchvision import models
+ from torch.nn import functional as F
+ import numpy as np
+ from skimage.color import rgb2lab, lab2rgb
+ import streamlit as st
+ from PIL import Image
+ from io import BytesIO
+ from src.imagecolorization.conponents.model_building import Generator, Critic
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+
+ class ImageColorizationSystem:
+     def __init__(self, generator_path, critic_path):
+         self.generator = Generator(1, 2)  # Expecting 1 channel input, 2 channel output
+         self.critic = Critic()  # Initialize your critic model
+         self.generator.load_state_dict(torch.load(generator_path, map_location=device), strict=False)
+         self.critic.load_state_dict(torch.load(critic_path, map_location=device), strict=False)
+         self.generator.to(device)
+         self.critic.to(device)
+         self.generator.eval()
+         self.critic.eval()
+
+     def load_image(self, image):
+         image = image.convert("L")  # Convert to grayscale (1 channel)
+         image = image.resize((224, 224))  # Resize to the expected input size
+         return image
+
+     def colorize(self, bw_image):
+         bw_tensor = transforms.ToTensor()(bw_image).unsqueeze(0).to(device)  # Move tensor to the correct device
+         with torch.no_grad():
+             colorized = self.generator(bw_tensor)
+         colorized = colorized.cpu()  # Move tensor back to CPU for processing
+         return self.lab_to_rgb(bw_tensor.squeeze(), colorized.squeeze())
+
+     def lab_to_rgb(self, L, ab):
+         # Ensure both tensors are on CPU
+         L = L.cpu() * 100
+         ab = (ab.cpu() * 2 - 1) * 128
+         # Concatenate on CPU
+         Lab = torch.cat([L.unsqueeze(0), ab], dim=0).numpy()  # Move to numpy for conversion
+         Lab = np.moveaxis(Lab, 0, -1)
+         rgb_img = lab2rgb(Lab)
+         return rgb_img
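
For reference, a minimal sketch of calling the new prediction class outside Streamlit, mirroring how app.py uses it; the relative checkpoint paths and sample.jpg are placeholders for illustration, not files introduced by this commit:

from PIL import Image
from src.imagecolorization.pipeline.prediction import ImageColorizationSystem

# Placeholder checkpoint paths - substitute the real artifact locations.
system = ImageColorizationSystem(
    "artifacts/trained_model/cwgan_generator_final.pt",
    "artifacts/trained_model/cwgan_critic_final.pt",
)

bw_image = system.load_image(Image.open("sample.jpg"))  # grayscale PIL image, resized to 224x224
rgb = system.colorize(bw_image)                         # numpy float array in [0, 1]
Image.fromarray((rgb * 255).astype("uint8")).save("colorized.png")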