Msaqibsharif commited on
Commit
8b827ad
·
verified ·
1 Parent(s): c50cad0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -14
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
  import torch
3
- from PIL import Image
4
  import numpy as np
5
- import traceback
6
  import gradio as gr
7
  from transformers import DetrImageProcessor, DetrForObjectDetection
8
  from diffusers import StableDiffusionPipeline
@@ -14,8 +13,7 @@ from dotenv import load_dotenv
14
  load_dotenv()
15
 
16
  # Retrieve Hugging Face token from environment variable
17
- HF_TOKEN = os.getenv("HF_TOKEN")
18
-
19
  if HF_TOKEN is None:
20
  raise ValueError("Hugging Face token not found in environment variables.")
21
 
@@ -40,7 +38,13 @@ def detect_objects(image):
40
  outputs = detr_model(**inputs)
41
  target_sizes = torch.tensor([image.size[::-1]])
42
  results = detr_processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
43
- return results, None
 
 
 
 
 
 
44
  except Exception as e:
45
  return None, f"Error in detect_objects: {str(e)}"
46
  else:
@@ -57,27 +61,42 @@ def load_stable_diffusion_model():
57
 
58
  sd_pipeline, sd_error = load_stable_diffusion_model()
59
 
60
- def generate_image(prompt):
 
 
 
 
 
 
61
  if sd_pipeline is not None:
62
  try:
63
- image = sd_pipeline(prompt).images[0]
 
 
 
64
  return image, None
65
  except Exception as e:
66
  return None, f"Error in generate_image: {str(e)}"
67
  else:
68
  return None, "Stable Diffusion model not loaded. Skipping image generation."
69
 
70
- # Gradio Interface
71
  def process_image(image):
72
  try:
73
- # Detect objects
74
- object_results, detect_error = detect_objects(image)
75
  if detect_error:
76
  return None, detect_error
77
 
78
- # Generate a modern redesign of the image based on the detected objects
79
- prompt = "modern redesign of an interior room"
80
- generated_image, gen_image_error = generate_image(prompt)
 
 
 
 
 
 
 
81
  if gen_image_error:
82
  return None, gen_image_error
83
 
@@ -100,4 +119,3 @@ try:
100
  iface.launch()
101
  except Exception as e:
102
  print(f"Error occurred while launching the interface: {str(e)}")
103
- traceback.print_exc()
 
1
  import os
2
  import torch
3
+ from PIL import Image, ImageDraw
4
  import numpy as np
 
5
  import gradio as gr
6
  from transformers import DetrImageProcessor, DetrForObjectDetection
7
  from diffusers import StableDiffusionPipeline
 
13
  load_dotenv()
14
 
15
  # Retrieve Hugging Face token from environment variable
16
+ HF_TOKEN = os.getenv('HF_TOKEN')
 
17
  if HF_TOKEN is None:
18
  raise ValueError("Hugging Face token not found in environment variables.")
19
 
 
38
  outputs = detr_model(**inputs)
39
  target_sizes = torch.tensor([image.size[::-1]])
40
  results = detr_processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
41
+
42
+ detected_objects = [
43
+ {"label": detr_model.config.id2label[label.item()],
44
+ "box": box.tolist()}
45
+ for label, box in zip(results['labels'], results['boxes'])
46
+ ]
47
+ return detected_objects, None
48
  except Exception as e:
49
  return None, f"Error in detect_objects: {str(e)}"
50
  else:
 
61
 
62
  sd_pipeline, sd_error = load_stable_diffusion_model()
63
 
64
+ def adjust_dimensions(width, height):
65
+ # Adjust width and height to be divisible by 8
66
+ adjusted_width = (width // 8) * 8
67
+ adjusted_height = (height // 8) * 8
68
+ return adjusted_width, adjusted_height
69
+
70
+ def generate_image(prompt, width, height):
71
  if sd_pipeline is not None:
72
  try:
73
+ adjusted_width, adjusted_height = adjust_dimensions(width, height)
74
+ image = sd_pipeline(prompt, width=adjusted_width, height=adjusted_height).images[0]
75
+ # Resize back to original dimensions if needed
76
+ image = image.resize((width, height), Image.LANCZOS)
77
  return image, None
78
  except Exception as e:
79
  return None, f"Error in generate_image: {str(e)}"
80
  else:
81
  return None, "Stable Diffusion model not loaded. Skipping image generation."
82
 
 
83
  def process_image(image):
84
  try:
85
+ # Detect objects in the provided image
86
+ detected_objects, detect_error = detect_objects(image)
87
  if detect_error:
88
  return None, detect_error
89
 
90
+ # Create a prompt based on detected objects
91
+ prompt = "modern redesign of an interior room with "
92
+ if detected_objects:
93
+ prompt += ", ".join([obj['label'] for obj in detected_objects])
94
+ else:
95
+ prompt += "empty room"
96
+
97
+ # Generate a redesigned image based on the prompt
98
+ width, height = image.size
99
+ generated_image, gen_image_error = generate_image(prompt, width, height)
100
  if gen_image_error:
101
  return None, gen_image_error
102
 
 
119
  iface.launch()
120
  except Exception as e:
121
  print(f"Error occurred while launching the interface: {str(e)}")