Spaces:
Runtime error
Runtime error
File size: 12,269 Bytes
3bc69b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 |
from flask import Flask, request, jsonify,send_file, Response
from flask_cors import CORS
import logging
import gc
import os
from threading import Thread
from flask_sse import sse
import uuid
import redis
import multiprocessing
from werkzeug.exceptions import NotFound, InternalServerError
import threading
from collections import OrderedDict
from flask import current_app
import time
from celery import Celery
from io import BytesIO
from pathlib import Path
import sys
import torch
from PIL import Image, ImageOps
import numpy as np
from run.utils_ootd import get_mask_location
from run.cloths_db import cloths_map, modeL_db
from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_dc import OOTDiffusionDC
PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
sys.path.insert(0, str(PROJECT_ROOT))
from queue import Queue
from celery_worker import process_image
# Run the Python garbage collector and ask PyTorch to release its cached
# CUDA memory before the models below are loaded.
gc.collect()
torch.cuda.empty_cache()
# Set the start method to 'spawn'
# multiprocessing.set_start_method('spawn', force=True)

# Setup Flask server
app = Flask(__name__)
app.config.update(
    CELERY_BROKER_URL='redis://localhost:6379',
    CELERY_RESULT_BACKEND='redis://localhost:6379',
    # Flask-SSE looks up SSE_REDIS_URL / REDIS_URL in the app config when
    # sse.publish() is called; without one of them publishing fails.
    REDIS_URL='redis://localhost:6379',
)
# NOTE(review): no app.register_blueprint(sse, ...) call is visible in this
# file, so this app publishes events but does not serve an event stream.

# Initialize Celery; the whole Flask config is copied into the Celery config.
celery = Celery(app.name, broker=app.config['CELERY_BROKER_URL'])
celery.conf.update(app.config)

# Root logger, shared by the request handlers below.
logger = logging.getLogger()

# Models are created once at import time and reused by every request.
# NOTE(review): the 0 argument is presumably a GPU id -- confirm against
# the model classes.
openpose_model = OpenPose(0)
parsing_model_dc = Parsing(0)
ootd_model_dc = OOTDiffusionDC(0)

example_path = os.path.join(os.path.dirname(__file__), 'examples')
garment_path = os.path.join(os.path.dirname(__file__), 'examples', 'garment')

# Move the model components to the GPU up front.
openpose_model.preprocessor.body_estimation.model.to('cuda')
ootd_model_dc.pipe.to('cuda')
ootd_model_dc.image_encoder.to('cuda')
ootd_model_dc.text_encoder.to('cuda')

# Category names indexed by category id (0, 1, 2): the first list is passed
# to the diffusion model, the second to get_mask_location.
category_dict = ['upperbody', 'lowerbody', 'dress']
category_dict_utils = ['upper_body', 'lower_body', 'dresses']
# Directory that uploaded input images are written to. It is created at
# import time; exist_ok avoids the check-then-create race when several
# processes import this module at once.
UPLOAD_FOLDER = 'temp_images'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# NOTE(review): this is still a placeholder path -- set it to a real
# directory before anything writes results there.
OUTPUT_FOLDER = 'path/to/output/folder'

# Guards every access to image_results.
image_results_lock = threading.Lock()
# Maps a result UUID to the path of its output image. An OrderedDict is used
# so the number of stored results can be limited.
image_results = OrderedDict()
# NOTE(review): no code visible in this file enforces this limit yet.
MAX_RESULTS = 100  # Adjust this value based on your needs
def _log_shoulder_measurements(keypoints):
    """Print a rough shoulder-width estimate derived from the pose keypoints.

    This is diagnostic output only: nothing computed here is used for the
    try-on result, so unusable keypoints are reported and skipped instead of
    raising.

    Args:
        keypoints: Mapping with a "pose_keypoints_2d" entry whose items are
            indexed as point[0] (x) and point[1] (y). Items 2 and 5 are
            treated as the shoulders, 1 as the neck and 8 as the hip
            (presumably the OpenPose keypoint layout -- TODO confirm).
    """
    points = keypoints["pose_keypoints_2d"]
    print(len(points))
    print(points)
    if len(points) <= 8:
        print('Skipping shoulder measurement: fewer than 9 pose keypoints')
        return

    left_point = points[2]
    right_point = points[5]
    neck_point = points[1]
    hip_point = points[8]
    print(f'left shoulder - {left_point}')
    print(f'right shoulder - {right_point}')

    # Euclidean distance between the two shoulder points.
    shoulder_width_pixels = round(
        np.sqrt(
            np.power((right_point[0] - left_point[0]), 2)
            + np.power((right_point[1] - left_point[1]), 2)
        ),
        2,
    )
    # Body height is approximated as twice the neck-to-hip distance.
    height_pixels = round(
        np.sqrt(
            np.power((neck_point[0] - hip_point[0]), 2)
            + np.power((neck_point[1] - hip_point[1]), 2)
        ),
        2,
    ) * 2
    if height_pixels == 0:
        # Neck and hip coincide (e.g. both undetected); dividing would
        # only produce inf/nan values.
        print('Skipping shoulder measurement: neck and hip keypoints coincide')
        return

    # Assumed average human height in cm.
    # NOTE(review): the reason for the 1.5 multiplier is not documented.
    average_height_cm = 172.72 * 1.5
    # Conversion factor from pixels to cm
    conversion_factor = average_height_cm / height_pixels
    # Convert shoulder width to real-world units
    shoulder_width_cm = shoulder_width_pixels * conversion_factor
    print(f'Shoulder width (in pixels): {shoulder_width_pixels}')
    print(f'Estimated height (in pixels): {height_pixels}')
    print(f'Conversion factor (pixels to cm): {conversion_factor}')
    print(f'Shoulder width (in cm): {shoulder_width_cm}')
    print(f'Shoulder width (in INCH): {round(shoulder_width_cm/2.54,1)}')


def process_dc(vton_img, garm_img, category, progress_callback):
    """Run the 'dc' try-on pipeline for one person image and one garment image.

    Args:
        vton_img: Person image; anything PIL.Image.open accepts.
        garm_img: Garment image; anything PIL.Image.open accepts.
        category: 'Upper-body' or 'Lower-body'; any other value is treated
            as a dress.
        progress_callback: Forwarded unchanged to ootd_model_dc, which is
            called with progress_interval=1.

    Returns:
        Whatever ootd_model_dc returns for num_samples=2 (presumably a list
        of generated images -- TODO confirm).
    """
    model_type = 'dc'
    if category == 'Upper-body':
        category_index = 0
    elif category == 'Lower-body':
        category_index = 1
    else:
        category_index = 2

    with torch.no_grad():
        garm_img = Image.open(garm_img).resize((768, 1024))
        vton_img = Image.open(vton_img).resize((768, 1024))

        # Pose estimation and human parsing both run on a 384x512 copy.
        keypoints = openpose_model(vton_img.resize((384, 512)))
        _log_shoulder_measurements(keypoints)
        model_parse, _ = parsing_model_dc(vton_img.resize((384, 512)))

        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category_index], model_parse, keypoints
        )
        # Scale the masks back up to the working resolution.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)

        # Paste the grey mask over the person image where the mask is set.
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)
        print(f'category is {category_index}')

        # An earlier configuration used num_samples=3, num_steps=20, seed=-1
        # and no progress reporting.
        images = ootd_model_dc(
            model_type=model_type,
            category=category_dict[category_index],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=2,
            num_steps=10,
            image_scale=2.0,
            seed=42,
            progress_callback=progress_callback,
            progress_interval=1,  # Update progress every step
        )
    return images
# def create_progress_callback(session_id):
# def progress_callback(step, total_steps):
# progress = int((step + 1) / total_steps * 100)
# print(f"Publishing progress {progress} for session {session_id}")
# sse.publish({"progress": progress}, type='progress', channel=session_id)
# return progress_callback
# @celery.task(bind=True)
# def process_image(self, session_id, garm_path, vton_path, category):
# try:
# print(f"Starting process_image task for session {session_id}")
# progress_callback = create_progress_callback(session_id)
# output_images = process_dc(garm_img=garm_path,
# vton_img=vton_path,
# category=category,
# progress_callback=progress_callback)
# if not output_images:
# sse.publish({"error": "No output image generated"}, type='error', channel=session_id)
# return None
# output_image = output_images[0]
# # Generate a UUID for the output image
# image_uuid = str(uuid.uuid4())
# # Create the output filename with the UUID
# output_filename = f"{image_uuid}.png"
# output_path = os.path.join(OUTPUT_FOLDER, output_filename)
# # Save the output image
# output_image.save(output_path, format='PNG')
# # Add the UUID and path to the image_results map
# with image_results_lock:
# image_results[image_uuid] = output_path
# sse.publish({"message": "Processing complete", "uuid": image_uuid}, type='complete', channel=session_id)
# return image_uuid
# except Exception as e:
# sse.publish({"error": str(e)}, type='error', channel=session_id)
# return print(f"panic in process_image: {str(e)}")
@app.route('/')
def root():
    """Return a JSON banner identifying the API.

    Returns:
        JSON {"message": "This is VTR API v1.0"} on success, or a JSON
        error message with status 500 if building the response fails.
    """
    try:
        return jsonify({"message": "This is VTR API v1.0"})
    except Exception as exc:
        logger.error(f"Root endpoint error: {str(exc)}")
        failure = {"message": "Internal server Error"}
        return jsonify(failure), 500
# @app.route('/stream')
# def stream():
# session_id = request.args.get('channel')
# if not session_id:
# return "No channel specified", 400
# return Response(sse.stream(), content_type='text/event-stream')
@app.route('/test_sse/<session_id>')
def test_sse(session_id):
    """Publish a test server-sent event on the given channel.

    Args:
        session_id: Channel name, taken from the URL.

    Returns:
        The text "SSE test message sent" on success, or a JSON error with
        status 500 if publishing fails.
    """
    try:
        sse.publish({"message": "Test SSE"}, type='test', channel=session_id)
    except Exception as e:
        # Handle publish failures the same way the other endpoints handle
        # errors: log, then return a JSON error body.
        logger.error(f"SSE test publish failed: {str(e)}")
        return jsonify({"error": "Failed to publish SSE test message"}), 500
    return "SSE test message sent"
# Flask API endpoint "generate" (POST): accepts two uploaded images and starts background processing, returning a session id
@app.route('/generate', methods=['POST'])
def generate():
    """Queue a try-on job for an uploaded garment/person image pair.

    Expects a multipart POST with:
        garm_img: Garment image file.
        vton_img: Person image file.
        category: Index into ["Upper-body", "Lower-body", "Dress"], i.e.
            0, 1 or 2. Defaults to 0 (Upper-body) when omitted.

    Both uploads are saved under UPLOAD_FOLDER and their paths are passed to
    the process_image Celery task. The generated image is not returned by
    this endpoint.

    Returns:
        202 with JSON {"session_id", "message"} once the task is queued;
        400 with JSON {"error"} for a missing file or an invalid category;
        500 with the error text if anything else fails.
    """
    cloths_type = ["Upper-body", "Lower-body", "Dress"]
    try:
        garm_img = request.files.get('garm_img')
        vton_img = request.files.get('vton_img')
        if garm_img is None or vton_img is None:
            return jsonify({"error": "Both 'garm_img' and 'vton_img' files are required"}), 400

        # Default to Upper-body if not specified
        cat = request.form.get('category', '0')
        print(f'category is {cat}')
        try:
            cat_index = int(cat)
        except (TypeError, ValueError):
            cat_index = -1
        # Reject out-of-range values explicitly; a negative index would
        # otherwise silently select from the end of the list.
        if not 0 <= cat_index < len(cloths_type):
            return jsonify({"error": "category must be 0, 1 or 2"}), 400
        category = cloths_type[cat_index]

        # Give each request its own file names so that concurrent requests
        # cannot overwrite each other's inputs before the worker reads them.
        # NOTE(review): these files are not removed here; cleanup is left to
        # whoever consumes them.
        session_id = str(uuid.uuid4())
        garm_path = os.path.join(UPLOAD_FOLDER, f'{session_id}_garm_input.png')
        vton_path = os.path.join(UPLOAD_FOLDER, f'{session_id}_vton_input.png')
        garm_img.save(garm_path)
        vton_img.save(vton_path)

        # Start processing in a background task
        process_image.apply_async(args=[session_id, garm_path, vton_path, category])

        # Immediately return the session_id to the client
        return jsonify({"session_id": session_id, "message": "Processing started"}), 202
    except Exception as e:
        logger.error(f"Generate endpoint error: {str(e)}")
        return Response(str(e), status=500)
@app.route('/get_image/<uuid>')
def get_image(uuid):
    """Serve a stored result image by its UUID.

    Args:
        uuid: Key into image_results, taken from the URL. The name shadows
            the uuid module inside this function; it is kept because it is
            the route's variable name.

    Returns:
        The image file (served as image/jpeg for .jpg/.jpeg paths, image/png
        otherwise); 404 with JSON {"error"} if the UUID is unknown or the
        file is missing; 500 with JSON {"error"} on any other failure.
    """
    try:
        # Hold the lock only for the dictionary lookup, not for file I/O.
        with image_results_lock:
            image_path = image_results.get(uuid)
        if image_path is None:
            raise NotFound("Invalid UUID or result not available")

        # Flask's send_file resolves a relative path against the application
        # root, while os.path.exists uses the working directory. An absolute
        # path makes the check and the response refer to the same file.
        image_path = os.path.abspath(image_path)
        if not os.path.exists(image_path):
            raise NotFound("Image file not found")

        # Determine the MIME type based on the file extension
        file_extension = os.path.splitext(image_path)[1].lower()
        mime_type = 'image/jpeg' if file_extension in ('.jpg', '.jpeg') else 'image/png'
        return send_file(image_path, mimetype=mime_type, as_attachment=False)
    except NotFound as e:
        logger.warning(f"Get image request failed: {str(e)}")
        return jsonify({"error": str(e)}), 404
    except Exception as e:
        logger.error(f"Unexpected error in get_image: {str(e)}")
        return jsonify({"error": "An unexpected error occurred"}), 500
if __name__ == '__main__':
    # Start Flask's built-in server on all interfaces, port 5009, with
    # debug mode off.
    app.run(host='0.0.0.0', port=5009, debug=False)
# Deployment note kept from the original (module name and port as written
# there -- verify before use):
# nohup gunicorn -b 0.0.0.0:5003 sentiment_api:app &