Msaqibsharif committed on
Commit f8f6f3e · verified · 1 Parent(s): f40795b

Update app.py

Files changed (1)
  1. app.py +20 -118
app.py CHANGED
@@ -4,14 +4,13 @@ from PIL import Image
 import numpy as np
 import traceback
 import gradio as gr
-from transformers import DetrImageProcessor, DetrForObjectDetection, LayoutLMTokenizer, LayoutLMForTokenClassification
-from diffusers import StableDiffusionPipeline, StableDiffusionUpscalePipeline
+from transformers import DetrImageProcessor, DetrForObjectDetection
+from diffusers import StableDiffusionPipeline
 from huggingface_hub import login
 import torchvision.transforms as T
-import torchvision.models as models
-from dotenv import load_dotenv
 
 # Load environment variables from .env file
+from dotenv import load_dotenv
 load_dotenv()
 
 # Retrieve Hugging Face token from environment variable
@@ -19,12 +18,12 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 if HF_TOKEN is None:
     raise ValueError("Hugging Face token not found in environment variables.")
 
-## 2.1 Image Analysis with DETR
+# Load DETR model for object detection
 def load_detr_model():
     try:
-        detr_model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
-        detr_processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
-        return detr_model, detr_processor, None
+        model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
+        processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
+        return model, processor, None
     except Exception as e:
         return None, None, f"Error loading DETR model: {str(e)}"
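Only the DETR loader is touched in this hunk; the body of `detect_objects` (named in the next hunk header) is unchanged and not shown. For orientation, here is a minimal sketch of how such a helper could post-process DETR outputs with the standard `transformers` API — the module-level variable names and the 0.9 confidence threshold are assumptions for illustration, not code from this commit:

```python
import torch

# Assumed module-level load (names are illustrative; the real assignment is outside this diff)
detr_model, detr_processor, detr_error = load_detr_model()

def detect_objects(image):
    if detr_model is not None and detr_processor is not None:
        try:
            # Preprocess the PIL image and run the detector
            inputs = detr_processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = detr_model(**inputs)
            # Convert logits/boxes into labeled detections above an illustrative threshold
            target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
            results = detr_processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=0.9
            )[0]
            labels = [detr_model.config.id2label[idx.item()] for idx in results["labels"]]
            return labels, None
        except Exception as e:
            return None, f"Error in detect_objects: {str(e)}"
    else:
        return None, "DETR models not loaded. Skipping object detection."
```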
@@ -43,72 +42,12 @@ def detect_objects(image):
     else:
         return None, "DETR models not loaded. Skipping object detection."
 
-## 2.2 Style Transfer with Deep Image Prior
-def style_transfer(content_image, style_image):
-    try:
-        transform = T.Compose([
-            T.Resize((512, 512)),
-            T.ToTensor(),
-            T.Lambda(lambda x: x.mul(255))
-        ])
-
-        content = transform(content_image).unsqueeze(0).requires_grad_(False)
-        style = transform(style_image).unsqueeze(0).requires_grad_(False)
-
-        vgg = models.vgg19(pretrained=True).features.eval()
-        for param in vgg.parameters():
-            param.requires_grad_(False)
-
-        generated = content.clone().requires_grad_(True)
-        optimizer = torch.optim.Adam([generated], lr=0.003)
-
-        for i in range(300):
-            generated_features = vgg(generated)
-            content_features = vgg(content)
-            style_features = vgg(style)
-
-            content_loss = torch.mean((generated_features - content_features)**2)
-            style_loss = torch.mean((generated_features - style_features)**2)
-
-            total_loss = content_loss + style_loss
-            optimizer.zero_grad()
-            total_loss.backward()
-            optimizer.step()
-
-        generated_image = generated.squeeze().clamp(0, 255).cpu().detach().numpy().transpose(1, 2, 0)
-        return Image.fromarray(np.uint8(generated_image)), None
-    except Exception as e:
-        return content_image, f"Error in style_transfer: {str(e)}"
-
-## 2.3 Layout Generation with LayoutLM
-def load_layoutlm_model():
-    try:
-        layoutlm_tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
-        layoutlm_model = LayoutLMForTokenClassification.from_pretrained('microsoft/layoutlm-base-uncased')
-        return layoutlm_tokenizer, layoutlm_model, None
-    except Exception as e:
-        return None, None, f"Error loading LayoutLM model: {str(e)}"
-
-layoutlm_tokenizer, layoutlm_model, layoutlm_error = load_layoutlm_model()
-
-def generate_layout(text):
-    if layoutlm_tokenizer is not None and layoutlm_model is not None:
-        try:
-            inputs = layoutlm_tokenizer(text, return_tensors="pt")
-            outputs = layoutlm_model(**inputs)
-            layout = outputs.logits.argmax(dim=-1)
-            return layout, None
-        except Exception as e:
-            return None, f"Error in generate_layout: {str(e)}"
-    else:
-        return None, "LayoutLM models not loaded. Skipping layout generation."
-
-## 2.4 Image Generation with Stable Diffusion
+# Load Stable Diffusion model for image generation
 def load_stable_diffusion_model():
     try:
         login(token=HF_TOKEN)
-        sd_pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
-        return sd_pipeline, None
+        pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
+        return pipeline, None
     except Exception as e:
         return None, f"Error loading Stable Diffusion model: {str(e)}"
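Likewise, only the Stable Diffusion loader is edited here; `generate_image` itself lies outside the hunks shown. A hedged sketch of how that helper presumably wraps the pipeline call (the `sd_pipeline` name is assumed):

```python
# Assumed module-level load (name is illustrative; the real assignment is outside this diff)
sd_pipeline, sd_error = load_stable_diffusion_model()

def generate_image(prompt):
    if sd_pipeline is not None:
        try:
            # Text-to-image call; the pipeline returns a list of PIL images
            image = sd_pipeline(prompt).images[0]
            return image, None
        except Exception as e:
            return None, f"Error in generate_image: {str(e)}"
    else:
        return None, "Stable Diffusion model not loaded. Skipping image generation."
```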
@@ -124,69 +63,32 @@ def generate_image(prompt):
     else:
         return None, "Stable Diffusion model not loaded. Skipping image generation."
 
-## 2.5 Super-Resolution
-def load_upscale_pipeline():
-    try:
-        upscale_pipeline = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler").to("cuda")
-        return upscale_pipeline, None
-    except Exception as e:
-        return None, f"Error loading Upscale Pipeline: {str(e)}"
-
-upscale_pipeline, upscale_error = load_upscale_pipeline()
-
-def super_resolve(image):
-    if upscale_pipeline is not None:
-        try:
-            if not isinstance(image, Image.Image):
-                raise ValueError("Input must be a PIL image.")
-            upscaled_image = upscale_pipeline(image=image).images[0]
-            return upscaled_image, None
-        except Exception as e:
-            return None, f"Error in super_resolve: {str(e)}"
-    else:
-        return image, "Upscale Pipeline not loaded. Skipping super-resolution."
-
-# Step 3: Gradio Interface and Integration
-def process_image(image, style_image, text_prompt):
+# Gradio Interface
+def process_image(image):
     try:
         # Detect objects
         object_results, detect_error = detect_objects(image)
         if detect_error:
             return None, detect_error
 
-        # Style transfer
-        styled_image, style_error = style_transfer(image, style_image)
-        if style_error:
-            return None, style_error
-
-        # Generate layout
-        layout_results, layout_error = generate_layout(text_prompt)
-        if layout_error:
-            return None, layout_error
-
-        # Generate image based on layout
-        generated_image, gen_image_error = generate_image("modern interior design based on layout")
+        # Generate a modern redesign of the image based on the detected objects
+        # For simplicity, we'll use a fixed prompt for image generation
+        prompt = "modern redesign of an interior room"
+        generated_image, gen_image_error = generate_image(prompt)
         if gen_image_error:
             return None, gen_image_error
 
-        # Super-resolve the generated image
-        final_image, upscale_error = super_resolve(generated_image)
-        if upscale_error:
-            return None, upscale_error
-
-        return final_image, None
+        return generated_image, None
     except Exception as e:
        return None, f"Error in process_image: {str(e)}"
 
 iface = gr.Interface(
     fn=process_image,
     inputs=[
-        gr.Image(type="pil", label="Upload Room Image"),
-        gr.Image(type="pil", label="Upload Style Image"),
-        gr.Textbox(label="Enter Design Prompt")
+        gr.Image(type="pil", label="Upload Room Image")
     ],
     outputs=[
-        gr.Image(type="pil", label="Generated Image"),
+        gr.Image(type="pil", label="Redesigned Image"),
         gr.Textbox(label="Error Message")
     ]
 )
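The rewritten `process_image` deliberately uses a fixed prompt, as its comments state. If the detected objects were later meant to steer the redesign, the prompt could be assembled from the DETR labels; a small sketch of that variation (not part of this commit, and it assumes `detect_objects` returns a list of label strings):

```python
# Illustrative only: build the Stable Diffusion prompt from detected object labels
def build_prompt(detected_labels):
    if detected_labels:
        objects = ", ".join(sorted(set(detected_labels)))
        return f"modern redesign of an interior room containing {objects}"
    return "modern redesign of an interior room"
```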
@@ -195,4 +97,4 @@ try:
     iface.launch()
 except Exception as e:
     print(f"Error occurred while launching the interface: {str(e)}")
-    traceback.print_exc()
+    traceback.print_exc()
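One caveat: `load_stable_diffusion_model` still moves the pipeline to "cuda" unconditionally, so the app will fail to load the model on CPU-only hardware. A common guard, shown only as an illustrative alternative rather than anything in this commit:

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative device fallback (not part of this commit)
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
```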