SuriRaja commited on
Commit
dce33d4
·
verified ·
1 Parent(s): b3ef395

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import gradio as gr
5
+ from PIL import Image
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from transformers import pipeline
9
+
10
+ # Load the DeepVTO model from Hugging Face
11
+ deepvto_pipeline = pipeline("image-to-image", model="huggingface/deepvto")
12
+
13
+ # Load sample product data
14
+ product_data = pd.DataFrame({
15
+ 'Product': ['Dress 1', 'Dress 2', 'Dress 3'],
16
+ 'Size': ['S', 'M', 'L'],
17
+ 'Color': ['Red', 'Blue', 'Green'],
18
+ 'Image': ['sample_dress1.jpg', 'sample_dress2.jpg', 'sample_dress3.jpg']
19
+ })
20
+
21
+ def process_image(image, product):
22
+ # Convert the uploaded image to a PIL image
23
+ person_image = Image.fromarray(image).convert("RGB")
24
+
25
+ # Fetch the garment image corresponding to the selected product
26
+ garment_filename = product_data[product_data['Product'] == product]['Image'].values[0]
27
+ garment_path = os.path.join(os.getcwd(), garment_filename)
28
+
29
+ if not os.path.exists(garment_path):
30
+ raise FileNotFoundError(f"File not found: {garment_path}")
31
+
32
+ garment_image = Image.open(garment_path).convert("RGB")
33
+
34
+ # Convert images to the format required by the model
35
+ person_image_tensor = transforms.ToTensor()(person_image).unsqueeze(0)
36
+ garment_image_tensor = transforms.ToTensor()(garment_image).unsqueeze(0)
37
+
38
+ # Run the DeepVTO model
39
+ with torch.no_grad():
40
+ output = deepvto_pipeline(person_image_tensor, garment_image_tensor)
41
+
42
+ # Convert the output to a PIL image
43
+ output_image = transforms.ToPILImage()(output[0])
44
+
45
+ # Convert to numpy array for Gradio
46
+ result_array = np.array(output_image)
47
+
48
+ # Fetch product details
49
+ product_details = product_data[product_data['Product'] == product].iloc[0].to_dict()
50
+
51
+ return result_array, product_details
52
+
53
+ # Gradio interface
54
+ iface = gr.Interface(
55
+ fn=process_image,
56
+ inputs=[
57
+ gr.Image(type="numpy", label="Upload Your Image"),
58
+ gr.Dropdown(choices=product_data['Product'].tolist(), label="Select Product")
59
+ ],
60
+ outputs=[
61
+ gr.Image(type="numpy", label="Output Image"),
62
+ gr.JSON(label="Product Details")
63
+ ],
64
+ title="Virtual Dress Fitting",
65
+ description="Upload an image and select a product to see how it fits on you."
66
+ )
67
+
68
+ if __name__ == "__main__":
69
+ iface.launch()