Msaqibsharif commited on
Commit
d57c17f
·
verified ·
1 Parent(s): b17a5ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -140
app.py CHANGED
@@ -1,140 +1 @@
1
- import os
2
- import torch
3
- from PIL import Image
4
- import gradio as gr
5
- from transformers import DetrImageProcessor, DetrForObjectDetection
6
- from diffusers import StableDiffusionPipeline
7
- from huggingface_hub import login
8
- from dotenv import load_dotenv
9
-
10
- # Load environment variables from .env file
11
- load_dotenv()
12
-
13
- # Retrieve Hugging Face token from environment variable
14
- HF_TOKEN = os.getenv('HF_TOKEN')
15
-
16
- if HF_TOKEN is None:
17
- raise ValueError("Hugging Face token not found in environment variables.")
18
-
19
- # Login to Hugging Face using the token
20
- login(token=HF_TOKEN)
21
-
22
- # Load DETR model for object detection
23
- def load_detr_model():
24
- try:
25
- model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
26
- processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
27
- return model, processor, None
28
- except Exception as e:
29
- return None, None, f"Error loading DETR model: {str(e)}"
30
-
31
- detr_model, detr_processor, detr_error = load_detr_model()
32
-
33
- def detect_objects(image):
34
- if image is None:
35
- return None, "Invalid image: image is None."
36
-
37
- if detr_model is not None and detr_processor is not None:
38
- try:
39
- inputs = detr_processor(images=image, return_tensors="pt")
40
- outputs = detr_model(**inputs)
41
- target_sizes = torch.tensor([image.size[::-1]])
42
- results = detr_processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
43
-
44
- detected_objects = [
45
- {"label": detr_model.config.id2label[label.item()],
46
- "box": box.tolist()}
47
- for label, box in zip(results['labels'], results['boxes'])
48
- ]
49
- return detected_objects, None
50
- except Exception as e:
51
- return None, f"Error in detect_objects: {str(e)}"
52
- else:
53
- return None, "DETR models not loaded. Skipping object detection."
54
-
55
- # Load Stable Diffusion model for image generation
56
- def load_stable_diffusion_model():
57
- try:
58
- device = "cuda" if torch.cuda.is_available() else "cpu"
59
- pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
60
- return pipeline, None
61
- except Exception as e:
62
- return None, f"Error loading Stable Diffusion model: {str(e)}"
63
-
64
- sd_pipeline, sd_error = load_stable_diffusion_model()
65
-
66
- def adjust_dimensions(width, height):
67
- # Adjust width and height to be divisible by 8
68
- adjusted_width = (width // 8) * 8
69
- adjusted_height = (height // 8) * 8
70
- return adjusted_width, adjusted_height
71
-
72
- def generate_image(prompt, width, height):
73
- if sd_pipeline is not None:
74
- try:
75
- adjusted_width, adjusted_height = adjust_dimensions(width, height)
76
- image = sd_pipeline(prompt, width=adjusted_width, height=adjusted_height).images[0]
77
- # Resize back to original dimensions if needed
78
- image = image.resize((width, height), Image.LANCZOS)
79
- return image, None
80
- except Exception as e:
81
- return None, f"Error in generate_image: {str(e)}"
82
- else:
83
- return None, "Stable Diffusion model not loaded. Skipping image generation."
84
-
85
- def process_image(image):
86
- if image is None:
87
- return None, "Invalid image: image is None."
88
-
89
- try:
90
- # Detect objects in the provided image
91
- detected_objects, detect_error = detect_objects(image)
92
- if detect_error:
93
- return None, detect_error
94
-
95
- # Create a prompt based on detected objects
96
- prompt = "modern redesign of an interior room with "
97
- if detected_objects:
98
- prompt += ", ".join([obj['label'] for obj in detected_objects])
99
- else:
100
- prompt += "empty room"
101
-
102
- # Generate a redesigned image based on the prompt
103
- width, height = image.size
104
- generated_image, gen_image_error = generate_image(prompt, width, height)
105
- if gen_image_error:
106
- return None, gen_image_error
107
-
108
- return generated_image, None
109
- except Exception as e:
110
- return None, f"Error in process_image: {str(e)}"
111
-
112
- # Custom CSS for styling
113
- custom_css = """
114
- body {
115
- background-color: black;
116
- }
117
-
118
- h1 {
119
- background: linear-gradient(to right, blue, purple);
120
- -webkit-background-clip: text;
121
- color: transparent;
122
- font-size: 3em;
123
- text-align: center;
124
- margin-bottom: 20px;
125
- }
126
- """
127
-
128
- # Creating the Gradio interface with custom styling
129
- iface = gr.Interface(
130
- fn=process_image,
131
- inputs=[gr.Image(type="pil", label="Upload Room Image")],
132
- outputs=[gr.Image(type="pil", label="Redesigned Image"), gr.Textbox(label="Error Message")],
133
- title="Interior Redesign",
134
- css=custom_css
135
- )
136
-
137
- try:
138
- iface.launch()
139
- except Exception as e:
140
- print(f"Error occurred while launching the interface: {str(e)}")
 
1
+ print("hello")