rishh76 commited on
Commit
3ada0c7
·
verified ·
1 Parent(s): e3d0946

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -110
app.py CHANGED
@@ -1,48 +1,41 @@
1
- import os
2
- import cv2
3
  import gradio as gr
4
  import numpy as np
5
  import random
6
- import time
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- def tryon(person_img, garment_prompt, seed, randomize_seed):
9
- post_start_time = time.time()
10
-
11
- if person_img is None or garment_prompt.strip() == "":
12
- return None, None, "Empty image or prompt"
13
-
14
  if randomize_seed:
15
  seed = random.randint(0, MAX_SEED)
16
-
17
- # Create a copy of the person image to overlay text
18
- result_img = person_img.copy()
19
-
20
- # Convert the image to OpenCV format (if needed)
21
- if len(result_img.shape) == 2: # Convert grayscale to RGB
22
- result_img = cv2.cvtColor(result_img, cv2.COLOR_GRAY2RGB)
23
-
24
- # Set text position and properties
25
- text_position = (10, 30)
26
- font = cv2.FONT_HERSHEY_SIMPLEX
27
- font_scale = 1
28
- font_color = (0, 255, 0) # Green color for the text
29
- thickness = 2
30
 
31
- # Overlay the garment description text on the image
32
- cv2.putText(result_img, f'Garment: {garment_prompt}', text_position, font, font_scale, font_color, thickness, cv2.LINE_AA)
 
33
 
34
- post_end_time = time.time()
35
- print(f"post time used: {post_end_time - post_start_time}")
36
 
37
- # Return the resulting image, used seed, and success message
38
  return result_img, seed, "Success"
39
 
40
- MAX_SEED = 999999
 
 
41
 
42
- example_path = os.path.join(os.path.dirname(__file__), 'assets')
 
 
43
 
44
- human_list = os.listdir(os.path.join(example_path, "human"))
45
- human_list_path = [os.path.join(example_path, "human", human) for human in human_list]
46
 
47
  css = """
48
  #col-left {
@@ -61,90 +54,32 @@ css = """
61
  margin: 0 auto;
62
  max-width: 1100px;
63
  }
64
- #button {
65
- color: blue;
66
- }
67
  """
68
 
69
- def load_description(fp):
70
- with open(fp, 'r', encoding='utf-8') as f:
71
- content = f.read()
72
- return content
73
-
74
-
75
  with gr.Blocks(css=css) as Tryon:
76
- gr.HTML(load_description("assets/title.md"))
77
- with gr.Row():
78
- with gr.Column(elem_id="col-left"):
79
- gr.HTML("""
80
- <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
81
- <div>
82
- Step 1. Upload a person image ⬇️
83
- </div>
84
- </div>
85
- """)
86
- with gr.Column(elem_id="col-mid"):
87
- gr.HTML("""
88
- <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
89
- <div>
90
- Step 2. Enter a text prompt for the garment ⬇️
91
- </div>
92
- </div>
93
- """)
94
- with gr.Column(elem_id="col-right"):
95
- gr.HTML("""
96
- <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
97
- <div>
98
- Step 3. Press “Run” to get try-on results
99
- </div>
100
- </div>
101
- """)
102
  with gr.Row():
103
  with gr.Column(elem_id="col-left"):
104
- imgs = gr.Image(label="Person image", sources='upload', type="numpy")
105
- example = gr.Examples(
106
- inputs=imgs,
107
- examples_per_page=12,
108
- examples=human_list_path
109
- )
110
  with gr.Column(elem_id="col-mid"):
111
- garment_prompt = gr.Textbox(label="Garment text prompt", placeholder="Describe the garment...")
 
 
112
  with gr.Column(elem_id="col-right"):
113
- image_out = gr.Image(label="Result", show_share_button=False)
114
- with gr.Row():
115
- seed = gr.Slider(
116
- label="Seed",
117
- minimum=0,
118
- maximum=MAX_SEED,
119
- step=1,
120
- value=0,
121
- )
122
- randomize_seed = gr.Checkbox(label="Random seed", value=True)
123
- with gr.Row():
124
- seed_used = gr.Number(label="Seed used")
125
- result_info = gr.Text(label="Response")
126
- test_button = gr.Button(value="Run", elem_id="button")
127
-
128
- test_button.click(fn=tryon, inputs=[imgs, garment_prompt, seed, randomize_seed], outputs=[image_out, seed_used, result_info], concurrency_limit=40)
129
-
130
- with gr.Column(elem_id="col-showcase"):
131
- gr.HTML("""
132
- <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
133
- <div> </div>
134
- <br>
135
- <div>
136
- Virtual try-on examples in pairs of person and garment images
137
- </div>
138
- </div>
139
- """)
140
- show_case = gr.Examples(
141
- examples=[
142
- ["assets/examples/model2.png", "assets/examples/garment2.png", "assets/examples/result2.png"],
143
- ["assets/examples/model3.png", "assets/examples/garment3.png", "assets/examples/result3.png"],
144
- ["assets/examples/model1.png", "assets/examples/garment1.png", "assets/examples/result1.png"],
145
- ],
146
- inputs=[imgs, garment_prompt, image_out],
147
- label=None
148
- )
149
 
150
  Tryon.launch()
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+
7
+ # Load the Stable Diffusion model for text-based garment generation
8
+ model_id = "runwayml/stable-diffusion-v1-5"
9
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
10
+ pipe = pipe.to("cuda") # Use GPU for faster inference
11
+
12
+ MAX_SEED = 999999
13
+
14
+ def generate_garment(person_img, cloth_description, seed, randomize_seed):
15
+ if person_img is None or cloth_description is None or cloth_description.strip() == "":
16
+ return None, None, "Invalid input"
17
 
 
 
 
 
 
 
18
  if randomize_seed:
19
  seed = random.randint(0, MAX_SEED)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Generate garment image from the text description
22
+ torch.manual_seed(seed)
23
+ garment_img = pipe(cloth_description).images[0]
24
 
25
+ # Combine the generated garment with the person's image
26
+ result_img = combine_images(person_img, garment_img)
27
 
 
28
  return result_img, seed, "Success"
29
 
30
+ def combine_images(person_img, garment_img):
31
+ person_img = np.array(person_img)
32
+ garment_img = np.array(garment_img.resize((person_img.shape[1], person_img.shape[0])))
33
 
34
+ # Simple overlay of garment on the person image
35
+ # Further improvement may require segmentation/masking
36
+ result_img = np.where(garment_img[:, :, 3:] > 0, garment_img[:, :, :3], person_img)
37
 
38
+ return result_img
 
39
 
40
  css = """
41
  #col-left {
 
54
  margin: 0 auto;
55
  max-width: 1100px;
56
  }
 
 
 
57
  """
58
 
 
 
 
 
 
 
59
  with gr.Blocks(css=css) as Tryon:
60
+ gr.HTML("<h1>Virtual Try-On with Text-based Garment Generation</h1>")
61
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  with gr.Row():
63
  with gr.Column(elem_id="col-left"):
64
+ gr.HTML("<h3>Step 1: Upload a person image ⬇️</h3>")
65
+ person_img = gr.Image(label="Person Image", source='upload', type="numpy")
66
+
 
 
 
67
  with gr.Column(elem_id="col-mid"):
68
+ gr.HTML("<h3>Step 2: Describe the garment ⬇️</h3>")
69
+ cloth_description = gr.Textbox(label="Garment Description", placeholder="e.g., red dress with floral pattern")
70
+
71
  with gr.Column(elem_id="col-right"):
72
+ gr.HTML("<h3>Step 3: Generate Try-On Image ⬇️</h3>")
73
+ result_img = gr.Image(label="Result", show_share_button=False)
74
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
75
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
76
+ seed_used = gr.Number(label="Seed Used", interactive=False)
77
+ result_info = gr.Text(label="Status", interactive=False)
78
+
79
+ generate_button = gr.Button(value="Run")
80
+
81
+ generate_button.click(fn=generate_garment,
82
+ inputs=[person_img, cloth_description, seed, randomize_seed],
83
+ outputs=[result_img, seed_used, result_info])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  Tryon.launch()