awacke1 committed on
Commit d999fe5 · verified · 1 Parent(s): 811cdf9

Update app.py

Files changed (1): app.py +228 -101

app.py CHANGED
@@ -1,80 +1,74 @@
-import numpy as np
 import torch
 import torch.nn as nn
-import gradio as gr
 from PIL import Image
 import torchvision.transforms as transforms
-import os # 📁 For file operations
 
-# 🧠 Neural network layers
 norm_layer = nn.InstanceNorm2d
 
-# 🧱 Building block for the generator
 class ResidualBlock(nn.Module):
     def __init__(self, in_features):
         super(ResidualBlock, self).__init__()
-
-        conv_block = [ nn.ReflectionPad2d(1),
-                       nn.Conv2d(in_features, in_features, 3),
-                       norm_layer(in_features),
-                       nn.ReLU(inplace=True),
-                       nn.ReflectionPad2d(1),
-                       nn.Conv2d(in_features, in_features, 3),
-                       norm_layer(in_features)
-                     ]
-
         self.conv_block = nn.Sequential(*conv_block)
 
     def forward(self, x):
         return x + self.conv_block(x)
 
-# 🎨 Generator model for creating line drawings
 class Generator(nn.Module):
     def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
         super(Generator, self).__init__()
-
-        # 🏁 Initial convolution block
-        model0 = [ nn.ReflectionPad2d(3),
-                   nn.Conv2d(input_nc, 64, 7),
-                   norm_layer(64),
-                   nn.ReLU(inplace=True) ]
         self.model0 = nn.Sequential(*model0)
-
-        # 🔽 Downsampling
         model1 = []
-        in_features = 64
-        out_features = in_features*2
         for _ in range(2):
-            model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
-                        norm_layer(out_features),
-                        nn.ReLU(inplace=True) ]
-            in_features = out_features
-            out_features = in_features*2
         self.model1 = nn.Sequential(*model1)
-
-        # 🔁 Residual blocks
-        model2 = []
-        for _ in range(n_residual_blocks):
-            model2 += [ResidualBlock(in_features)]
         self.model2 = nn.Sequential(*model2)
-
-        # 🔼 Upsampling
         model3 = []
-        out_features = in_features//2
         for _ in range(2):
-            model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
-                        norm_layer(out_features),
-                        nn.ReLU(inplace=True) ]
-            in_features = out_features
-            out_features = in_features//2
         self.model3 = nn.Sequential(*model3)
-
-        # 🎭 Output layer
-        model4 = [ nn.ReflectionPad2d(3),
-                   nn.Conv2d(64, output_nc, 7)]
         if sigmoid:
             model4 += [nn.Sigmoid()]
-
         self.model4 = nn.Sequential(*model4)
 
     def forward(self, x, cond=None):
@@ -83,71 +77,204 @@ class Generator(nn.Module):
         out = self.model2(out)
         out = self.model3(out)
         out = self.model4(out)
-
         return out
 
-# 🔧 Load the models
 model1 = Generator(3, 1, 3)
-model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu'), weights_only=True))
-model1.eval()
-
 model2 = Generator(3, 1, 3)
-model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu'), weights_only=True))
 model2.eval()
 
-# 🖼️ Function to process the image and create line drawing
-def predict(input_img, ver):
-    # Open the image and get its original size
-    original_img = Image.open(input_img)
-    original_size = original_img.size
 
-    # Define the transformation pipeline
     transform = transforms.Compose([
         transforms.Resize(256, Image.BICUBIC),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
 
-    # Apply the transformation
-    input_tensor = transform(original_img)
-    input_tensor = input_tensor.unsqueeze(0)
 
-    # Process the image through the model
-    with torch.no_grad():
-        if ver == 'Simple Lines':
-            output = model2(input_tensor)
-        else:
-            output = model1(input_tensor)
 
-    # Convert the output tensor to an image
-    output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
 
-    # Resize the output image back to the original size
-    output_img = output_img.resize(original_size, Image.BICUBIC)
-
-    return output_img
-
-# 📝 Title for the Gradio interface
-title="🖌️ Image to Line Drawings - Complex and Simple Portraits and Landscapes"
-
-# 🖼️ Dynamically generate examples from images in the directory
-examples = []
-image_dir = '.' # Assuming images are in the current directory
-for file in os.listdir(image_dir):
-    if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
-        examples.append([file, 'Simple Lines'])
-        examples.append([file, 'Complex Lines'])
-
-# 🚀 Create and launch the Gradio interface
-iface = gr.Interface(
-    fn=predict,
-    inputs=[
-        gr.Image(type='filepath'),
-        gr.Radio(['Complex Lines', 'Simple Lines'], label='version', value='Simple Lines')
-    ],
-    outputs=gr.Image(type="pil"),
-    title=title,
-    examples=examples
-)
-
-iface.launch()
 
+#!/usr/bin/env python3
+import os
+import base64
+import time
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
 from PIL import Image
+import gradio as gr
 import torchvision.transforms as transforms
+from transformers import AutoModel, AutoTokenizer
+from diffusers import StableDiffusionPipeline
+from torch.utils.data import Dataset, DataLoader
+import asyncio
+import aiofiles
+import fitz  # PyMuPDF
+import requests
+import logging
+from io import BytesIO
+from dataclasses import dataclass
+from typing import Optional
+
+# Logging setup
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
 
+# Neural network layers for line drawing
 norm_layer = nn.InstanceNorm2d
 
+# Residual Block for Generator
 class ResidualBlock(nn.Module):
     def __init__(self, in_features):
         super(ResidualBlock, self).__init__()
+        conv_block = [
+            nn.ReflectionPad2d(1),
+            nn.Conv2d(in_features, in_features, 3),
+            norm_layer(in_features),
+            nn.ReLU(inplace=True),
+            nn.ReflectionPad2d(1),
+            nn.Conv2d(in_features, in_features, 3),
+            norm_layer(in_features)
+        ]
         self.conv_block = nn.Sequential(*conv_block)
 
     def forward(self, x):
         return x + self.conv_block(x)
 
+# Generator for Line Drawings
 class Generator(nn.Module):
     def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
         super(Generator, self).__init__()
+        model0 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), norm_layer(64), nn.ReLU(inplace=True)]
         self.model0 = nn.Sequential(*model0)
         model1 = []
+        in_features, out_features = 64, 128
         for _ in range(2):
+            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
+            in_features, out_features = out_features, out_features * 2
         self.model1 = nn.Sequential(*model1)
+        model2 = [ResidualBlock(in_features) for _ in range(n_residual_blocks)]
         self.model2 = nn.Sequential(*model2)
         model3 = []
+        out_features = in_features // 2
         for _ in range(2):
+            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
+            in_features, out_features = out_features, out_features // 2
         self.model3 = nn.Sequential(*model3)
+        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
         if sigmoid:
             model4 += [nn.Sigmoid()]
         self.model4 = nn.Sequential(*model4)
 
     def forward(self, x, cond=None):
         out = self.model2(out)
         out = self.model3(out)
         out = self.model4(out)
         return out
 
+# Load Line Drawing Models
 model1 = Generator(3, 1, 3)
 model2 = Generator(3, 1, 3)
+try:
+    model1.load_state_dict(torch.load('model.pth', map_location='cpu', weights_only=True))
+    model2.load_state_dict(torch.load('model2.pth', map_location='cpu', weights_only=True))
+except FileNotFoundError:
+    logger.warning("Model files not found. Please ensure 'model.pth' and 'model2.pth' are available.")
+model1.eval()
 model2.eval()
 
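A quick wiring check for the refactored Generator (an illustrative sketch, not part of the commit; assumes the classes and checkpoints above):

    # Two stride-2 downsamples and two matching upsamples should return
    # the input resolution, with a single-channel line map as output.
    with torch.no_grad():
        probe = torch.randn(1, 3, 256, 256)   # dummy RGB batch
        assert model1(probe).shape == (1, 1, 256, 256)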
+# Tiny Diffusion Model
+class TinyUNet(nn.Module):
+    def __init__(self, in_channels=3, out_channels=3):
+        super(TinyUNet, self).__init__()
+        self.down1 = nn.Conv2d(in_channels, 32, 3, padding=1)
+        self.down2 = nn.Conv2d(32, 64, 3, padding=1, stride=2)
+        self.mid = nn.Conv2d(64, 128, 3, padding=1)
+        self.up1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
+        self.up2 = nn.Conv2d(64 + 32, 32, 3, padding=1)
+        self.out = nn.Conv2d(32, out_channels, 3, padding=1)
+        self.time_embed = nn.Linear(1, 128)  # sized to the 128-channel mid block it is added to
 
+    def forward(self, x, t):
+        t_embed = F.relu(self.time_embed(t.unsqueeze(-1)))
+        t_embed = t_embed.view(t_embed.size(0), t_embed.size(1), 1, 1)
+        x1 = F.relu(self.down1(x))
+        x2 = F.relu(self.down2(x1))
+        x_mid = F.relu(self.mid(x2)) + t_embed
+        x_up1 = F.relu(self.up1(x_mid))
+        x_up2 = F.relu(self.up2(torch.cat([x_up1, x1], dim=1)))
+        return self.out(x_up2)
+
+class TinyDiffusion:
+    def __init__(self, model, timesteps=100):
+        self.model = model
+        self.timesteps = timesteps
+        self.beta = torch.linspace(0.0001, 0.02, timesteps)
+        self.alpha = 1 - self.beta
+        self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)
+
+    def train(self, images, epochs=10):
+        dataset = [torch.tensor(np.array(img.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32) / 255.0 for img in images]
+        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
+        device = torch.device("cpu")
+        self.model.to(device)
+        for epoch in range(epochs):
+            total_loss = 0
+            for x in dataloader:
+                x = x.to(device)
+                t = torch.randint(0, self.timesteps, (x.size(0),), device=device).float()
+                noise = torch.randn_like(x)
+                alpha_t = self.alpha_cumprod[t.long()].view(-1, 1, 1, 1)
+                x_noisy = torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise
+                pred_noise = self.model(x_noisy, t)
+                loss = F.mse_loss(pred_noise, noise)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                total_loss += loss.item()
+            logger.info(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}")
+        return self
+
+    def generate(self, size=(64, 64), steps=100):
+        device = torch.device("cpu")
+        x = torch.randn(1, 3, size[0], size[1], device=device)
+        for t in reversed(range(steps)):
+            t_tensor = torch.full((1,), t, device=device, dtype=torch.float32)
+            alpha_t = self.alpha_cumprod[t].view(-1, 1, 1, 1)
+            pred_noise = self.model(x, t_tensor)
+            x = (x - (1 - self.alpha[t]) / torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(self.alpha[t])
+            if t > 0:
+                x += torch.sqrt(self.beta[t]) * torch.randn_like(x)
+        x = torch.clamp(x * 255, 0, 255).byte()
+        return Image.fromarray(x.squeeze(0).permute(1, 2, 0).cpu().numpy())
+
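The train() loop above applies the closed-form DDPM forward process x_t = √(ᾱ_t)·x_0 + √(1−ᾱ_t)·ε; a self-contained sketch of that single noising step (all values are placeholders):

    import torch
    betas = torch.linspace(1e-4, 0.02, 100)          # same schedule as TinyDiffusion
    alpha_bar = torch.cumprod(1 - betas, dim=0)
    x0 = torch.rand(1, 3, 64, 64)                    # clean image scaled to [0, 1]
    t = torch.tensor([50])                           # arbitrary timestep
    eps = torch.randn_like(x0)
    a_t = alpha_bar[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(a_t) * x0 + torch.sqrt(1 - a_t) * eps   # noisy sample fed to the UNet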
+# Utility Functions
+def generate_filename(sequence, ext="png"):
+    timestamp = time.strftime("%d%m%Y%H%M%S")
+    return f"{sequence}_{timestamp}.{ext}"
+
+def predict_line_drawing(input_img, ver):
+    original_img = Image.open(input_img) if isinstance(input_img, str) else input_img
+    original_size = original_img.size
     transform = transforms.Compose([
         transforms.Resize(256, Image.BICUBIC),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
+    input_tensor = transform(original_img).unsqueeze(0)
+    with torch.no_grad():
+        output = model2(input_tensor) if ver == 'Simple Lines' else model1(input_tensor)
+    output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
+    return output_img.resize(original_size, Image.BICUBIC)
 
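For reference, the helper also works outside the UI; a minimal sketch (sample.jpg is a hypothetical local file):

    img = Image.open("sample.jpg")                   # placeholder path
    drawing = predict_line_drawing(img, "Simple Lines")
    drawing.save(generate_filename("line_drawing"))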
+async def process_ocr(image):
+    tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
+    model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
+    # GOT-OCR2_0's chat() takes an image file path, so persist the PIL image first
+    temp_path = generate_filename("ocr_input", "png")
+    image.save(temp_path)
+    result = model.chat(tokenizer, temp_path, ocr_type='ocr')
+    output_file = generate_filename("ocr_output", "txt")
+    async with aiofiles.open(output_file, "w") as f:
+        await f.write(result)
+    return result, output_file
 
+async def process_diffusion(images):
+    unet = TinyUNet()
+    diffusion = TinyDiffusion(unet)
+    diffusion.train(images)
+    gen_image = diffusion.generate()
+    output_file = generate_filename("diffusion_output", "png")
+    gen_image.save(output_file)
+    return gen_image, output_file
 
+def download_pdf(url):
+    output_path = f"pdf_{int(time.time())}.pdf"
+    response = requests.get(url, stream=True, timeout=10)
+    if response.status_code == 200:
+        with open(output_path, "wb") as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+        return output_path
+    return None
+
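A hypothetical call, which would also exercise the otherwise-unused fitz import (the URL is a placeholder):

    path = download_pdf("https://example.com/sample.pdf")
    if path:
        with fitz.open(path) as doc:                 # PyMuPDF document handle
            logger.info(f"Downloaded PDF with {doc.page_count} pages")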
+# Gradio Blocks UI
+with gr.Blocks(title="Mystical AI Vision Studio 🌌", css="""
+.gr-button {background-color: #4CAF50; color: white;}
+.gr-tab {border: 2px solid #2196F3; border-radius: 5px;}
+#gallery img {border: 1px solid #ddd; border-radius: 4px;}
+""") as demo:
+    gr.Markdown("<h1 style='text-align: center; color: #2196F3;'>Mystical AI Vision Studio 🌌</h1>")
+    gr.Markdown("<p style='text-align: center;'>Transform images into line drawings, extract text with OCR, and craft unique art with diffusion!</p>")
+
+    with gr.Tab("Image to Line Drawings 🎨"):
+        with gr.Row():
+            with gr.Column():
+                img_input = gr.Image(type="pil", label="Upload Image")
+                version = gr.Radio(['Complex Lines', 'Simple Lines'], label='Style', value='Simple Lines')
+                submit_btn = gr.Button("Generate Line Drawing")
+            with gr.Column():
+                line_output = gr.Image(type="pil", label="Line Drawing")
+                download_btn = gr.Button("Download Output")
+        submit_btn.click(predict_line_drawing, inputs=[img_input, version], outputs=line_output)
+        line_file = gr.File(label="Download Line Drawing")
+        def save_line_drawing(img):
+            # Persist the PIL image so the File component has a path to serve
+            path = generate_filename("line_drawing", "png")
+            img.save(path)
+            return path
+        download_btn.click(save_line_drawing, inputs=line_output, outputs=line_file)
+
+    with gr.Tab("OCR Vision 🔍"):
+        with gr.Row():
+            with gr.Column():
+                ocr_input = gr.Image(type="pil", label="Upload Image or PDF Snapshot")
+                ocr_btn = gr.Button("Extract Text")
+            with gr.Column():
+                ocr_text = gr.Textbox(label="Extracted Text", interactive=False)
+                ocr_file = gr.File(label="Download OCR Result")
+        async def run_ocr(img):
+            result, file_path = await process_ocr(img)
+            return result, file_path
+        ocr_btn.click(run_ocr, inputs=ocr_input, outputs=[ocr_text, ocr_file])
+
+    with gr.Tab("Custom Diffusion 🎨🤓"):
+        with gr.Row():
+            with gr.Column():
+                diffusion_input = gr.File(label="Upload Images for Training", file_count="multiple")
+                diffusion_btn = gr.Button("Train & Generate")
+            with gr.Column():
+                diffusion_output = gr.Image(type="pil", label="Generated Art")
+                diffusion_file = gr.File(label="Download Art")
+        async def run_diffusion(files):
+            # Gradio delivers uploads as temp-file objects; open them by path
+            images = [Image.open(f.name) for f in files]
+            img, file_path = await process_diffusion(images)
+            return img, file_path
+        diffusion_btn.click(run_diffusion, inputs=diffusion_input, outputs=[diffusion_output, diffusion_file])
+
+    with gr.Tab("PDF Downloader 📥"):
+        with gr.Row():
+            pdf_url = gr.Textbox(label="Enter PDF URL")
+            pdf_btn = gr.Button("Download PDF")
+        pdf_output = gr.File(label="Downloaded PDF")
+        pdf_btn.click(download_pdf, inputs=pdf_url, outputs=pdf_output)
+
+    with gr.Tab("Gallery 📸"):
+        gallery = gr.Gallery(label="Processed Outputs", elem_id="gallery")
+        def update_gallery():
+            # gr.Gallery renders images, so list only the generated .png files
+            return [f for f in os.listdir('.') if f.endswith('.png')]
+        gr.Button("Refresh Gallery").click(update_gallery, outputs=gallery)
+
+    # JavaScript for dynamic UI enhancements
+    gr.HTML("""
+    <script>
+    document.addEventListener('DOMContentLoaded', () => {
+        const buttons = document.querySelectorAll('.gr-button');
+        buttons.forEach(btn => {
+            btn.addEventListener('mouseover', () => btn.style.backgroundColor = '#45a049');
+            btn.addEventListener('mouseout', () => btn.style.backgroundColor = '#4CAF50');
+        });
+    });
+    </script>
+    """)
 
+demo.launch()
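Because the OCR and diffusion handlers are long-running async calls, routing them through Gradio's request queue is the usual pattern; a sketch of an alternative launch line:

    demo.queue().launch()   # queue() lets slow jobs run without blocking the UI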