Spaces: Saad0KH / Running on Zero

Saad0KH committed
Commit ccd0584 · verified · 1 Parent(s): af7056a

Update app.py

Files changed (1):
  1. app.py +64 -67
app.py CHANGED
@@ -1,13 +1,11 @@
 from flask import Flask, request, jsonify
 import torch
 from transformers import (
-    UNet2DConditionModel,
-    AutoTokenizer,
-    CLIPTextModel,
-    CLIPTextModelWithProjection,
-    CLIPVisionModelWithProjection,
-    AutoencoderKL,
-    DDPMScheduler
+    UNet2DConditionModel,
+    AutoTokenizer,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPVisionModelWithProjection
 )
 from PIL import Image
 import base64
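
Note on this import block: in current releases, UNet2DConditionModel, AutoencoderKL, and DDPMScheduler are shipped by diffusers, not transformers; only the tokenizer and CLIP classes belong in the transformers import. The commit also drops AutoencoderKL and DDPMScheduler from the import even though the new load_models() below still calls both, which will raise a NameError unless they are imported in the lines this hunk elides. A corrected split might look like this sketch (assuming no other imports hide in the elided region):

    # Sketch of a corrected import split; the diffusion-model classes
    # come from diffusers, the tokenizer/CLIP classes from transformers.
    from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
    from transformers import (
        AutoTokenizer,
        CLIPTextModel,
        CLIPTextModelWithProjection,
        CLIPVisionModelWithProjection,
    )
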
@@ -28,71 +26,70 @@ UNet_Encoder = None

 # Load models once at startup
 def load_models():
-    global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder
-    base_path = "your_base_path_here"
-    unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16, force_download=False)
-    tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False, force_download=False)
-    tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False, force_download=False)
-    noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
-    text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16, force_download=False)
-    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16, force_download=False)
-    image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16, force_download=False)
-    vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16, force_download=False)
-    UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16, force_download=False)
-
-# Call the function to load models at startup
-load_models()
-
-# Helper function to free up GPU memory after processing
-def clear_gpu_memory():
-    torch.cuda.empty_cache()
-    torch.cuda.synchronize()
-
-# Helper function to convert base64 to image
-def base64_to_image(base64_str):
-    image_data = base64.b64decode(base64_str)
+    global unet, tokenizer_one, tokenizer_two, noise_scheduler
+    global text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder
+
+    if unet is None:
+        # Load models only when required to reduce memory usage
+        unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-v1-4")
+
+    if tokenizer_one is None:
+        tokenizer_one = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+
+    if tokenizer_two is None:
+        tokenizer_two = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14-336")
+
+    if noise_scheduler is None:
+        noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4")
+
+    if text_encoder_one is None:
+        text_encoder_one = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+    if text_encoder_two is None:
+        text_encoder_two = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14-336")
+
+    if image_encoder is None:
+        image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+
+    if vae is None:
+        vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-v1-4")
+
+    if UNet_Encoder is None:
+        UNet_Encoder = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-v1-4")
+
+# Helper function to process base64 image
+def decode_image(image_base64):
+    image_data = base64.b64decode(image_base64)
     image = Image.open(BytesIO(image_data)).convert("RGB")
     return image

-# Helper function to resize images for faster processing
-def resize_image(image, size=(512, 768)):
-    return image.resize(size)
+# Helper function to encode image to base64
+def encode_image(image):
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')

-# Example try-on function
-@app.route('/start_tryon', methods=['POST'])
-def start_tryon():
-    data = request.get_json()
-    garm_img_base64 = data['garm_img']
-    human_img_base64 = data['human_img']
-
-    # Decode and resize images
-    garm_img = resize_image(base64_to_image(garm_img_base64))
-    human_img = resize_image(base64_to_image(human_img_base64))
+# Route for image processing
+@app.route('/process_image', methods=['POST'])
+def process_image():
+    data = request.json

-    # Convert images to tensors and move to GPU
-    garm_img_tensor = torch.tensor(garm_img, dtype=torch.float16).unsqueeze(0).to('cuda')
-    human_img_tensor = torch.tensor(human_img, dtype=torch.float16).unsqueeze(0).to('cuda')
-
-    try:
-        # Processing steps (dummy example, replace with your logic)
-        with torch.inference_mode():
-            # Run the inference for both images
-            result_tensor = unet(garm_img_tensor, human_img_tensor)  # Replace with your actual logic
-
-        # Free GPU memory after inference
-        clear_gpu_memory()
-
-        # Convert result back to base64 for return
-        result_img = Image.fromarray(result_tensor.squeeze(0).cpu().numpy())
-        buffered = BytesIO()
-        result_img.save(buffered, format="JPEG")
-        result_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-        return jsonify({"result": result_base64})
-
-    except Exception as e:
-        clear_gpu_memory()
-        return jsonify({"error": str(e)}), 500
+    # Load the models (this will only happen once)
+    load_models()
+
+    # Extract the image from the request
+    image_base64 = data.get('image_base64')
+    if not image_base64:
+        return jsonify({"error": "No image provided"}), 400
+
+    image = decode_image(image_base64)
+
+    # Perform inference with the models (example, modify as needed)
+    processed_image = image  # Placeholder for actual image processing
+
+    # Return the processed image as base64
+    processed_image_base64 = encode_image(processed_image)
+    return jsonify({"processed_image": processed_image_base64})

 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=7860)
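
The rewritten load_models() swaps eager startup loading for a lazy-initialization guard: each global is fetched only if it is still None, so calling load_models() on every request costs nothing after the first call. Two details are worth double-checking against the actual Space: the old code passed subfolder= when loading from a pipeline repo, which the new calls omit, and the stabilityai/stable-diffusion-v1-4 ids are suspect, since the v1-4 checkpoints are published under CompVis. A minimal sketch of the guard with those assumptions made explicit:

    # Minimal sketch of the lazy-load guard used in the new load_models().
    # Repo id and subfolder are assumptions, not values confirmed by the diff.
    from diffusers import DDPMScheduler

    noise_scheduler = None  # module-level singleton, as in app.py

    def load_models():
        global noise_scheduler
        if noise_scheduler is None:  # load at most once per process
            noise_scheduler = DDPMScheduler.from_pretrained(
                "CompVis/stable-diffusion-v1-4",  # assumed checkpoint
                subfolder="scheduler",            # config sits in a subfolder of the pipeline repo
            )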
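decode_image() and encode_image() give the route a clean base64 round-trip; both depend on BytesIO, so from io import BytesIO must be among the imports in the region the hunk elides. A self-contained check of the same two helpers:

    # Round-trip check of the two helpers from the new app.py.
    import base64
    from io import BytesIO  # assumed present in the elided imports of app.py
    from PIL import Image

    def decode_image(image_base64):
        image_data = base64.b64decode(image_base64)
        return Image.open(BytesIO(image_data)).convert("RGB")

    def encode_image(image):
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    img = Image.new("RGB", (64, 64), "red")  # dummy test image
    assert decode_image(encode_image(img)).size == (64, 64)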
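As committed, process_image() assumes a well-formed JSON body: request.json raises for requests without a JSON content type, and data.get would fail if parsing ever returned None. A slightly more defensive variant of the same route (a sketch, not the committed code):

    # Defensive variant of the /process_image handler (sketch).
    # get_json(silent=True) returns None instead of raising when the
    # body is missing or is not valid JSON.
    @app.route('/process_image', methods=['POST'])
    def process_image():
        data = request.get_json(silent=True)
        if not data or not data.get('image_base64'):
            return jsonify({"error": "No image provided"}), 400
        load_models()             # no-op after the first call
        image = decode_image(data['image_base64'])
        processed_image = image   # placeholder, as in the commit
        return jsonify({"processed_image": encode_image(processed_image)})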
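From the caller's side, the new endpoint takes {"image_base64": ...} and answers {"processed_image": ...}. A hypothetical client call with the requests library (file names are placeholders; host and port follow the app.run line):

    # Hypothetical client for POST /process_image; paths are placeholders.
    import base64
    import requests

    with open("input.png", "rb") as f:
        payload = {"image_base64": base64.b64encode(f.read()).decode("utf-8")}

    resp = requests.post("http://localhost:7860/process_image", json=payload)
    resp.raise_for_status()

    with open("output.png", "wb") as f:
        f.write(base64.b64decode(resp.json()["processed_image"]))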
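Worth noting what the commit removes along the way: the old clear_gpu_memory() helper has no counterpart in the new file. Once the placeholder processing is replaced with real CUDA inference, restoring something like the deleted helper is cheap:

    # Mirrors the helper deleted by this commit; useful again once real
    # GPU inference replaces the placeholder in process_image().
    import torch

    def clear_gpu_memory():
        torch.cuda.empty_cache()   # release cached blocks back to the allocator
        torch.cuda.synchronize()   # wait for in-flight kernels to finish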