Hammad712 commited on
Commit
50f1229
·
verified ·
1 Parent(s): 39283bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -87
app.py CHANGED
@@ -1,127 +1,178 @@
1
  import streamlit as st
2
- from fastai.vision import open_image, load_learner, show_image
3
- import PIL.Image
4
  from PIL import Image
5
- from io import BytesIO
6
  import requests
 
 
7
  import torch
8
  import torch.nn as nn
9
- import os
10
- import tempfile
11
- import shutil
 
12
 
13
- # Define the FeatureLoss class
14
  class FeatureLoss(nn.Module):
15
  def __init__(self, m_feat, layer_ids, layer_wgts):
16
  super().__init__()
17
  self.m_feat = m_feat
18
  self.loss_features = [self.m_feat[i] for i in layer_ids]
19
- self.hooks = hook_outputs(self.loss_features, detach=False)
20
  self.wgts = layer_wgts
21
- self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))] + [f'gram_{i}' for i in range(len(layer_ids))]
22
 
23
- def make_features(self, x, clone=False):
24
- self.m_feat(x)
25
- return [(o.clone() if clone else o) for o in self.hooks.stored]
26
 
27
  def forward(self, input, target):
28
- out_feat = self.make_features(target, clone=True)
29
- in_feat = self.make_features(input)
30
- self.feat_losses = [base_loss(input, target)]
31
- self.feat_losses += [base_loss(f_in, f_out) * w for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
32
- self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out)) * w**2 * 5e3 for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
 
 
33
  self.metrics = dict(zip(self.metric_names, self.feat_losses))
34
  return sum(self.feat_losses)
35
 
36
- def __del__(self): self.hooks.remove()
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- def add_margin(pil_img, top, right, bottom, left, color):
39
- width, height = pil_img.size
40
- new_width = width + right + left
41
- new_height = height + top + bottom
42
- result = Image.new(pil_img.mode, (new_width, new_height), color)
43
- result.paste(pil_img, (left, top))
44
- return result
 
 
45
 
46
  def tensor_to_pil(tensor):
47
- """
48
- Convert a tensor to a PIL Image.
49
- """
50
  tensor = tensor.cpu().clamp(0, 1)
51
  array = tensor.numpy().transpose(1, 2, 0)
52
  return Image.fromarray((array * 255).astype('uint8'))
53
 
54
- def inference(image_path_or_url, learn):
55
- """
56
- Perform inference on an image from a local path or a URL.
57
-
58
- Parameters:
59
- image_path_or_url (str): Path to the local image or URL of the image.
60
- learn (Learner): The trained model.
61
-
62
- Returns:
63
- PIL.Image.Image: The high-resolution image generated by the model.
64
- """
65
- if image_path_or_url.startswith('http://') or image_path_or_url.startswith('https://'):
66
- response = requests.get(image_path_or_url)
67
- img = PIL.Image.open(BytesIO(response.content)).convert("RGB")
68
- else:
69
- img = PIL.Image.open(image_path_or_url).convert("RGB")
70
-
71
- im_new = add_margin(img, 250, 250, 250, 250, (255, 255, 255))
72
- im_new.save("test.jpg", quality=95)
73
- img = open_image("test.jpg")
74
- p, img_hr, b = learn.predict(img)
75
- return tensor_to_pil(img_hr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # Streamlit application
78
- st.title("Image to Sketch")
79
- st.write("Convert any image to its sketch version using a Pix2Pix GAN Model.")
80
-
81
- # Download the model file from the Hugging Face repository
82
- model_url = "https://huggingface.co/Hammad712/image2sketch/resolve/main/image2sketch.pkl"
83
- model_file_path = 'image2sketch.pkl'
84
 
85
- if not os.path.exists(model_file_path):
86
- with st.spinner('Downloading model...'):
87
- response = requests.get(model_url)
88
- with open(model_file_path, 'wb') as f:
89
- f.write(response.content)
90
- st.success('Model downloaded successfully!')
91
 
92
- # Create a temporary directory for the model
93
- with tempfile.TemporaryDirectory() as tmpdirname:
94
- shutil.move(model_file_path, os.path.join(tmpdirname, 'export.pkl'))
95
- learn = load_learner(tmpdirname)
96
 
97
- # Sidebar to display previous generated images
98
- st.sidebar.title("Previous Results")
99
- if 'generated_images' not in st.session_state:
100
- st.session_state['generated_images'] = []
101
 
102
- for img in st.session_state['generated_images']:
103
- st.sidebar.image(img, use_column_width=True)
 
104
 
105
  # Input for image URL or path
106
- image_path_or_url = st.text_input("Enter image URL", "")
 
 
 
 
 
107
 
108
  # Run inference button
109
  if st.button("Convert"):
110
  if image_path_or_url:
111
  with st.spinner('Processing...'):
112
- high_res_image = inference(image_path_or_url, learn)
113
- original_image = PIL.Image.open(BytesIO(requests.get(image_path_or_url).content)) if image_path_or_url.startswith('http') else PIL.Image.open(image_path_or_url)
114
-
115
- # Display original and high-res images side by side
116
- col1, col2 = st.columns(2)
117
-
118
- with col1:
119
- st.image(original_image, caption='Original Image', use_column_width=True)
120
- with col2:
121
- st.image(high_res_image, caption='Sketch Image', use_column_width=True)
122
-
123
- # Save the generated image to session state for sidebar display
124
- st.session_state['generated_images'].append(high_res_image)
125
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  else:
127
  st.error("Please enter a valid image path or URL.")
 
1
  import streamlit as st
2
+ from fastai.vision import open_image, load_learner
 
3
  from PIL import Image
 
4
  import requests
5
+ import os
6
+ import logging
7
  import torch
8
  import torch.nn as nn
9
+ from io import BytesIO
10
+
11
+ # Setup logging
12
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13
 
 
14
  class FeatureLoss(nn.Module):
15
  def __init__(self, m_feat, layer_ids, layer_wgts):
16
  super().__init__()
17
  self.m_feat = m_feat
18
  self.loss_features = [self.m_feat[i] for i in layer_ids]
19
+ self.hooks = [module.register_forward_hook(self.hook_fn) for module in self.loss_features]
20
  self.wgts = layer_wgts
21
+ self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))] + [f'gram_{i}' for i in range(len(layer_ids))]
22
 
23
+ def hook_fn(self, module, input, output):
24
+ self.stored = output.detach().clone()
 
25
 
26
  def forward(self, input, target):
27
+ self.m_feat(target)
28
+ out_feat = [self.stored.clone()]
29
+ self.m_feat(input)
30
+ in_feat = [self.stored]
31
+ self.feat_losses = [torch.nn.functional.mse_loss(input, target)]
32
+ self.feat_losses += [torch.nn.functional.mse_loss(f_in, f_out) * w for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
33
+ self.feat_losses += [torch.nn.functional.mse_loss(self.gram_matrix(f_in), self.gram_matrix(f_out)) * w**2 * 5e3 for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
34
  self.metrics = dict(zip(self.metric_names, self.feat_losses))
35
  return sum(self.feat_losses)
36
 
37
+ @staticmethod
38
+ def gram_matrix(input):
39
+ b, c, h, w = input.size()
40
+ features = input.view(b, c, h * w)
41
+ G = torch.bmm(features, features.transpose(1, 2))
42
+ return G.div(c * h * w)
43
+
44
+ def fetch_image(image_path_or_url):
45
+ if isinstance(image_path_or_url, str) and image_path_or_url.startswith(('http://', 'https://')):
46
+ response = requests.get(image_path_or_url)
47
+ img = Image.open(BytesIO(response.content)).convert("RGB")
48
+ else:
49
+ img = Image.open(image_path_or_url).convert("RGB")
50
+ return img
51
 
52
+ def inference(image_path_or_url, learn):
53
+ img = fetch_image(image_path_or_url)
54
+ img_with_margin = Image.new('RGB', (img.width + 500, img.height + 500), (255, 255, 255))
55
+ img_with_margin.paste(img, (250, 250))
56
+ temp_image_path = "temp_image.jpg"
57
+ img_with_margin.save(temp_image_path, quality=95)
58
+ img_fastai = open_image(temp_image_path)
59
+ _, img_hr, _ = learn.predict(img_fastai)
60
+ return tensor_to_pil(img_hr)
61
 
62
  def tensor_to_pil(tensor):
 
 
 
63
  tensor = tensor.cpu().clamp(0, 1)
64
  array = tensor.numpy().transpose(1, 2, 0)
65
  return Image.fromarray((array * 255).astype('uint8'))
66
 
67
+ def load_model(model_url, model_file_path):
68
+ if not os.path.exists(model_file_path):
69
+ with st.spinner('Downloading model...'):
70
+ response = requests.get(model_url)
71
+ with open(model_file_path, 'wb') as f:
72
+ f.write(response.content)
73
+ st.success('Model downloaded successfully!')
74
+ learn = load_learner(os.path.dirname(model_file_path), model_file_path)
75
+ return learn
76
+
77
+ # Custom CSS
78
+ def set_css(style):
79
+ st.markdown(f"<style>{style}</style>", unsafe_allow_html=True)
80
+
81
+ # Combined dark mode styles
82
+ combined_css = """
83
+ .main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
84
+ .block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
85
+ .stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
86
+ .stSpinner { color: #4CAF50; }
87
+ .title {
88
+ font-size: 3rem;
89
+ font-weight: bold;
90
+ display: flex;
91
+ align-items: center;
92
+ justify-content: center;
93
+ }
94
+ .colorful-text {
95
+ background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
96
+ -webkit-background-clip: text;
97
+ -webkit-text-fill-color: transparent;
98
+ }
99
+ .black-white-text {
100
+ color: black;
101
+ }
102
+ .small-input .stTextInput>div>input {
103
+ height: 2rem;
104
+ font-size: 0.9rem;
105
+ }
106
+ .small-file-uploader .stFileUploader>div>div {
107
+ height: 2rem;
108
+ font-size: 0.9rem;
109
+ }
110
+ .custom-text {
111
+ font-size: 1.2rem;
112
+ color: #feb47b;
113
+ text-align: center;
114
+ margin-top: -20px;
115
+ margin-bottom: 20px;
116
+ }
117
+ """
118
 
119
  # Streamlit application
120
+ st.set_page_config(layout="wide")
 
 
 
 
 
121
 
122
+ st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)
 
 
 
 
 
123
 
124
+ st.markdown('<div class="title"><span class="colorful-text">Image</span> <span class="black-white-text">to Sketch</span></div>', unsafe_allow_html=True)
125
+ st.markdown('<div class="custom-text">Jana\'s embroidery studio. Convert Photo\'s to Drawings using AI</div>', unsafe_allow_html=True)
 
 
126
 
127
+ # Download and load the model
128
+ MODEL_URL = "https://huggingface.co/Hammad712/image2sketch/resolve/main/image2sketch.pkl"
129
+ MODEL_FILE_PATH = 'image2sketch.pkl'
 
130
 
131
+ if 'learn' not in st.session_state:
132
+ st.session_state['learn'] = load_model(MODEL_URL, MODEL_FILE_PATH)
133
+ learn = st.session_state['learn']
134
 
135
  # Input for image URL or path
136
+ with st.expander("Input Options", expanded=True):
137
+ image_path_or_url = st.text_input("Enter image URL", "", key="image_url", placeholder="Enter image URL", help="Enter the URL of the image to convert")
138
+ uploaded_file = st.file_uploader("Or upload an image", type=["jpg", "jpeg", "png", "webp"], key="upload_file", help="Upload an image file to convert")
139
+
140
+ if uploaded_file is not None:
141
+ image_path_or_url = uploaded_file
142
 
143
  # Run inference button
144
  if st.button("Convert"):
145
  if image_path_or_url:
146
  with st.spinner('Processing...'):
147
+ try:
148
+ high_res_image = inference(image_path_or_url, learn)
149
+ original_image = fetch_image(image_path_or_url)
150
+
151
+ # Display original and high-res images side by side
152
+ st.markdown("### Result")
153
+ col1, col2 = st.columns(2)
154
+
155
+ with col1:
156
+ st.image(original_image, caption='Original Image', use_column_width=True)
157
+ with col2:
158
+ st.image(high_res_image, caption='Sketch Image', use_column_width=True)
159
+
160
+ # Provide a download button for the generated image
161
+ img_byte_arr = BytesIO()
162
+ high_res_image.save(img_byte_arr, format='JPEG')
163
+ img_byte_arr = img_byte_arr.getvalue()
164
+
165
+ st.download_button(
166
+ label="Download Sketch Image",
167
+ data=img_byte_arr,
168
+ file_name="sketch_image.jpg",
169
+ mime="image/jpeg"
170
+ )
171
+
172
+ st.success("Image processed successfully!")
173
+
174
+ except Exception as e:
175
+ st.error(f"An error occurred: {e}")
176
+ logging.error("Error during inference", exc_info=True)
177
  else:
178
  st.error("Please enter a valid image path or URL.")