Spaces:
Runtime error
Runtime error
File size: 8,806 Bytes
3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e 3bc69b8 6e6426e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 |
import gc
import logging
import os
import sys
from io import BytesIO
from pathlib import Path
from queue import Queue
from threading import Thread

import numpy as np
import torch
from flask import Flask, request, jsonify, send_file, Response
from flask_cors import CORS
from flask_sse import sse
from PIL import Image, ImageOps

# The project root must be on sys.path *before* the project-local imports
# below are executed, otherwise the insertion cannot help them resolve.
PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
sys.path.insert(0, str(PROJECT_ROOT))

from run.utils_ootd import get_mask_location
from run.cloths_db import cloths_map, modeL_db
from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_dc import OOTDiffusionDC
# Run the Python garbage collector and release cached CUDA memory before the
# models are loaded.
gc.collect()
torch.cuda.empty_cache()

# Setup Flask server
app = Flask(__name__)
CORS(app, origins="*")  # Enable CORS for the entire app
# Flask-SSE publishes and streams events through this Redis instance.
app.config["REDIS_URL"] = "redis://localhost:6379"
app.register_blueprint(sse, url_prefix='/stream')
logger = logging.getLogger()

# Models are created once at import time and shared by every request.
# NOTE(review): the 0 argument is presumably a GPU index -- confirm against
# the OpenPose / Parsing / OOTDiffusionDC constructors.
openpose_model = OpenPose(0)
parsing_model_dc = Parsing(0)
ootd_model_dc = OOTDiffusionDC(0)

example_path = os.path.join(os.path.dirname(__file__), 'examples')
garment_path = os.path.join(os.path.dirname(__file__), 'examples', 'garment')

# Move the model weights to the GPU once, up front, instead of per request.
openpose_model.preprocessor.body_estimation.model.to('cuda')
ootd_model_dc.pipe.to('cuda')
ootd_model_dc.image_encoder.to('cuda')
ootd_model_dc.text_encoder.to('cuda')

# Category names indexed by the numeric category (0, 1, 2): the first list is
# passed to the diffusion model, the second to get_mask_location.
category_dict = ['upperbody', 'lowerbody', 'dress']
category_dict_utils = ['upper_body', 'lower_body', 'dresses']

# Ensure the upload directory exists. exist_ok=True avoids the
# check-then-create race when several workers start at the same time.
UPLOAD_FOLDER = 'temp_images'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# progress_queue = Queue()
# def progress_callback(step, total_steps):
# if total_steps is not None and total_steps > 0:
# progress = int((step + 1) / total_steps * 100)
# progress_queue.put(progress)
# else:
# progress_queue.put(step + 1)
def progress_callback(step, total_steps):
    """Publish the current inference progress as an SSE 'progress' event.

    Args:
        step: Zero-based index of the step that just finished.
        total_steps: Total number of steps, or None / non-positive when
            the total is unknown.

    When the total is known the event carries {"progress": <percent>},
    truncated to an integer; otherwise it carries {"step": <1-based step>}.
    """
    has_total = total_steps is not None and total_steps > 0
    if has_total:
        payload = {"progress": int((step + 1) / total_steps * 100)}
    else:
        payload = {"step": step + 1}
    sse.publish(payload, type='progress')
def process_dc(vton_img, garm_img, category):
    """Run the 'dc' try-on pipeline for one person image and one garment.

    Args:
        vton_img: Person image; anything ``Image.open`` accepts.
        garm_img: Garment image; anything ``Image.open`` accepts.
        category: 'Upper-body' or 'Lower-body'; any other value is treated
            as the third category (dress).

    Returns:
        Whatever ``ootd_model_dc`` returns for the request.
        NOTE(review): presumably a list of PIL images, since the caller
        takes element 0 and calls ``.save`` on it -- confirm.

    The shoulder-width estimate computed from the pose keypoints is only
    printed for diagnostics; it does not influence the generated images.
    """
    model_type = 'dc'
    # Translate the label into an index into category_dict(_utils).
    category = 0 if category == 'Upper-body' else 1 if category == 'Lower-body' else 2

    full_size = (768, 1024)
    pose_size = (384, 512)

    with torch.no_grad():
        garment = Image.open(garm_img).resize(full_size)
        person = Image.open(vton_img).resize(full_size)

        keypoints = openpose_model(person.resize(pose_size))
        pose = keypoints["pose_keypoints_2d"]
        print(len(pose))
        print(pose)

        # Keypoint indices used here: 2 and 5 for the shoulders, 1 for the
        # neck and 8 for the hip.
        left_point, right_point, neck_point, hip_point = (
            pose[idx] for idx in (2, 5, 1, 8)
        )
        print(f'left shoulder - {left_point}')
        print(f'right shoulder - {right_point}')

        # Euclidean distances in pixels, rounded to two decimals.
        shoulder_dx = right_point[0] - left_point[0]
        shoulder_dy = right_point[1] - left_point[1]
        shoulder_width_pixels = round(
            np.sqrt(np.power(shoulder_dx, 2) + np.power(shoulder_dy, 2)), 2
        )
        torso_dx = neck_point[0] - hip_point[0]
        torso_dy = neck_point[1] - hip_point[1]
        # The height estimate is twice the neck-to-hip distance.
        height_pixels = round(
            np.sqrt(np.power(torso_dx, 2) + np.power(torso_dy, 2)), 2
        ) * 2

        # Assumed average human height, scaled by 1.5.
        average_height_cm = 172.72 * 1.5
        # Conversion factor from pixels to cm
        conversion_factor = average_height_cm / height_pixels
        # Convert shoulder width to real-world units
        shoulder_width_cm = shoulder_width_pixels * conversion_factor

        print(f'Shoulder width (in pixels): {shoulder_width_pixels}')
        print(f'Estimated height (in pixels): {height_pixels}')
        print(f'Conversion factor (pixels to cm): {conversion_factor}')
        print(f'Shoulder width (in cm): {shoulder_width_cm}')
        print(f'Shoulder width (in INCH): {round(shoulder_width_cm/2.54,1)}')

        model_parse, _ = parsing_model_dc(person.resize(pose_size))
        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category], model_parse, keypoints
        )
        # The masks are scaled up to the full working resolution.
        mask = mask.resize(full_size, Image.NEAREST)
        mask_gray = mask_gray.resize(full_size, Image.NEAREST)

        # Person image with the mask region replaced by the gray mask image.
        masked_person = Image.composite(mask_gray, person, mask)
        print(f'category is {category}')

        inference_args = {
            "model_type": model_type,
            "category": category_dict[category],
            "image_garm": garment,
            "image_vton": masked_person,
            "mask": mask,
            "image_ori": person,
            "num_samples": 2,
            "num_steps": 10,
            "image_scale": 2.0,
            "seed": 42,
            "progress_callback": progress_callback,
            "progress_interval": 1,  # Update progress every step
        }
        images = ootd_model_dc(**inference_args)
    return images
@app.route('/')
def root():
    """Return a small JSON banner identifying the API.

    On an unexpected failure the error is logged and a JSON error message
    is returned with HTTP status 500.
    """
    try:
        return jsonify({"message": "This is VTR API v1.0"})
    except Exception as e:
        logger.error(f"Root endpoint error: {str(e)}")
        return jsonify({"message": "Internal server Error"}), 500
@app.route('/stream')
def stream():
    """Serve the server-sent-events stream.

    Flask-SSE's ``sse.stream()`` view already builds the streaming
    ``text/event-stream`` response, so it is returned as-is. Wrapping it in
    another ``Response`` would use a response object as the body instead of
    an iterable of chunks.

    NOTE(review): the ``sse`` blueprint is also registered under the
    ``/stream`` prefix at module setup, so this route may be shadowed by
    the blueprint's own rule -- confirm whether it is still needed.

    Returns:
        The streaming response produced by Flask-SSE.
    """
    return sse.stream()
# Flask API endpoint "generate" (POST): takes two input images and returns one generated image.
@app.route('/generate', methods=['POST'])
def generate():
    """Handle POST /generate: run the try-on pipeline and return a PNG.

    Expected multipart form fields:
        garm_img: garment image file (required).
        vton_img: person image file (required).
        category: "0" (Upper-body), "1" (Lower-body) or "2" (Dress);
            defaults to "0" when omitted.

    Returns:
        200 with the first generated image as ``image/png`` on success.
        400 when a file is missing or the category is not 0, 1 or 2.
        500 with the error text when processing fails or produces no image.

    A 'complete' SSE event is published just before the image is returned.
    """
    cloths_type = ["Upper-body", "Lower-body", "Dress"]

    garm_img = request.files.get('garm_img')
    vton_img = request.files.get('vton_img')
    if garm_img is None or vton_img is None:
        return Response("Both 'garm_img' and 'vton_img' files are required",
                        status=400)

    cat = request.form.get('category', '0')  # Default to Upper-body if not specified
    print(f'category is {cat}')
    try:
        cat_index = int(cat)
    except (TypeError, ValueError):
        cat_index = -1
    # Reject out-of-range values explicitly; a negative index would
    # otherwise silently wrap around to another category.
    if not 0 <= cat_index < len(cloths_type):
        return Response("category must be 0, 1 or 2", status=400)
    category = cloths_type[cat_index]

    try:
        # Save the uploaded files.
        # NOTE(review): the file names are fixed, so concurrent requests
        # overwrite each other's inputs -- consider per-request names.
        garm_path = os.path.join(UPLOAD_FOLDER, 'garm_input.png')
        vton_path = os.path.join(UPLOAD_FOLDER, 'vton_input.png')
        garm_img.save(garm_path)
        vton_img.save(vton_path)

        # Pass the saved paths, not the upload objects: save() has already
        # consumed the upload streams, so re-reading them yields no data.
        output_images = process_dc(garm_img=garm_path,
                                   vton_img=vton_path,
                                   category=category)
        if not output_images:
            return Response("No output image generated", status=500)

        # Convert the first PIL image to PNG bytes.
        img_buffer = BytesIO()
        output_images[0].save(img_buffer, format='PNG')

        # Send the final "complete" event via SSE
        sse.publish({"message": "Processing complete"}, type='complete')
        return Response(img_buffer.getvalue(), mimetype='image/png')
    except Exception as e:
        # Top-level boundary of the request: log with traceback and report.
        logger.exception("Error: %s", e)
        return Response(str(e), status=500)
if __name__ == '__main__':
    # Development entry point: Flask's built-in server on all interfaces.
    # Example production command kept from the original file.
    # NOTE(review): the module name and port look copied from another
    # project -- confirm:
    # nohup gunicorn -b 0.0.0.0:5003 sentiment_api:app &
    app.run(host='0.0.0.0', port=5009, debug=False)
|