Update app.py
app.py
CHANGED
@@ -1,13 +1,11 @@
 from flask import Flask, request, jsonify
 import torch
 from transformers import (
-    UNet2DConditionModel,
-    AutoTokenizer,
-    CLIPTextModel,
-    CLIPTextModelWithProjection,
-    CLIPVisionModelWithProjection
-    AutoencoderKL,
-    DDPMScheduler
+    UNet2DConditionModel,
+    AutoTokenizer,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPVisionModelWithProjection
 )
 from PIL import Image
 import base64
@@ -28,71 +26,70 @@ UNet_Encoder = None

 # Load models once at startup
 def load_models():
-    global unet, tokenizer_one, tokenizer_two, noise_scheduler
-    ...
+    global unet, tokenizer_one, tokenizer_two, noise_scheduler
+    global text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder
+
+    if unet is None:
+        # Load models only when required to reduce memory usage
+        unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-v1-4")
+
+    if tokenizer_one is None:
+        tokenizer_one = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+
+    if tokenizer_two is None:
+        tokenizer_two = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14-336")
+
+    if noise_scheduler is None:
+        noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4")
+
+    if text_encoder_one is None:
+        text_encoder_one = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+    if text_encoder_two is None:
+        text_encoder_two = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14-336")
+
+    if image_encoder is None:
+        image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+
+    if vae is None:
+        vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-v1-4")
+
+    if UNet_Encoder is None:
+        UNet_Encoder = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-v1-4")
+
+# Helper function to process base64 image
+def decode_image(image_base64):
+    image_data = base64.b64decode(image_base64)
     image = Image.open(BytesIO(image_data)).convert("RGB")
     return image

-# Helper function to ...
-def ...
-
+# Helper function to encode image to base64
+def encode_image(image):
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')

-# ...
-@app.route('/...
-def ...
-    data = request...
-    garm_img_base64 = data['garm_img']
-    human_img_base64 = data['human_img']
-
-    # Decode and resize images
-    garm_img = resize_image(base64_to_image(garm_img_base64))
-    human_img = resize_image(base64_to_image(human_img_base64))
+# Route for image processing
+@app.route('/process_image', methods=['POST'])
+def process_image():
+    data = request.json

-    # ...
-    ...
-        result_img.save(buffered, format="JPEG")
-        result_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-        return jsonify({"result": result_base64})
-
-    except Exception as e:
-        clear_gpu_memory()
-        return jsonify({"error": str(e)}), 500
+    # Load the models (this will only happen once)
+    load_models()
+
+    # Extract the image from the request
+    image_base64 = data.get('image_base64')
+    if not image_base64:
+        return jsonify({"error": "No image provided"}), 400
+
+    image = decode_image(image_base64)
+
+    # Perform inference with the models (example, modify as needed)
+    processed_image = image  # Placeholder for actual image processing
+
+    # Return the processed image as base64
+    processed_image_base64 = encode_image(processed_image)
+    return jsonify({"processed_image": processed_image_base64})

 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=7860)
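Two things are worth flagging in the import hunk. UNet2DConditionModel, AutoencoderKL, and DDPMScheduler are classes from the diffusers package, not transformers, so the old import block could not have worked as written. And the new load_models() still calls AutoencoderKL.from_pretrained and DDPMScheduler.from_pretrained even though this commit drops those names from the import list, which will raise a NameError the first time those branches run. A minimal sketch of a corrected import split, assuming the stock diffusers and transformers packages:

    # Sketch: diffusion components come from diffusers, CLIP pieces from transformers.
    from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
    from transformers import (
        AutoTokenizer,
        CLIPTextModel,
        CLIPTextModelWithProjection,
        CLIPVisionModelWithProjection,
    )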
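The from_pretrained repo ids also look suspect: the v1-4 weights are published as CompVis/stable-diffusion-v1-4, so the stabilityai/stable-diffusion-v1-4 id used for the UNet and VAE needs checking, and loading a single component out of a full pipeline checkpoint normally requires a subfolder argument. A hedged sketch of those loads under that assumption:

    # Sketch, assuming the standard diffusers layout of CompVis/stable-diffusion-v1-4,
    # where each component lives in its own subfolder of the pipeline repo.
    unet = UNet2DConditionModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="unet"
    )
    vae = AutoencoderKL.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="vae"
    )
    noise_scheduler = DDPMScheduler.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
    )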
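For reference, a client-side sketch of calling the new /process_image route. The filenames are placeholders, and the server is assumed to be reachable on the port passed to app.run:

    # Sketch of a client call; "input.jpg" / "output.png" are hypothetical filenames.
    import base64
    import requests

    with open("input.jpg", "rb") as f:
        image_base64 = base64.b64encode(f.read()).decode("utf-8")

    resp = requests.post(
        "http://localhost:7860/process_image",
        json={"image_base64": image_base64},
    )
    resp.raise_for_status()

    # The route returns the processed image re-encoded as base64 PNG.
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(resp.json()["processed_image"]))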