Saad0KH commited on
Commit
739eb5d
Β·
verified Β·
1 Parent(s): cbcdcd2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import base64
4
+ import spaces
5
+ from loadimg import load_img
6
+ from io import BytesIO
7
+ import numpy as np
8
+ import insightface
9
+ import onnxruntime as ort
10
+ import huggingface_hub
11
+ from SegCloth import segment_clothing
12
+ from transparent_background import Remover
13
+ import uuid
14
+ from transformers import AutoModelForImageSegmentation
15
+ import torch
16
+ from torchvision import transforms
17
+
18
+ # Load the model lazily
19
+ model = None
20
+ detector = None
21
+ def load_model():
22
+ global model, detector
23
+ path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
24
+ options = ort.SessionOptions()
25
+ options.intra_op_num_threads = 8
26
+ options.inter_op_num_threads = 8
27
+ session = ort.InferenceSession(
28
+ path, sess_options=options, providers=["CPUExecutionProvider", "CUDAExecutionProvider"]
29
+ )
30
+ model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
31
+ model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
32
+ detector = model
33
+
34
+ # Load the segmentation model
35
+ torch.set_float32_matmul_precision(["high", "highest"][0])
36
+ birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
37
+ birefnet.to("cuda")
38
+ transform_image = transforms.Compose([
39
+ transforms.Resize((1024, 1024)),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
42
+ ])
43
+
44
+ def save_image(img):
45
+ unique_name = str(uuid.uuid4()) + ".png"
46
+ img.save(unique_name)
47
+ return unique_name
48
+
49
+ def rm_background(image):
50
+ im = load_img(image, output_type="pil")
51
+ im = im.convert("RGB")
52
+ image_size = im.size
53
+ origin = im.copy()
54
+ image = load_img(im)
55
+ input_images = transform_image(image).unsqueeze(0).to("cuda")
56
+ # Prediction
57
+ with torch.no_grad():
58
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
59
+ pred = preds[0].squeeze()
60
+ pred_pil = transforms.ToPILImage()(pred)
61
+ mask = pred_pil.resize(image_size)
62
+ image.putalpha(mask)
63
+ return image
64
+
65
+ def detect_and_segment_persons(image, clothes):
66
+ img = np.array(image)
67
+ img = img[:, :, ::-1] # RGB -> BGR
68
+
69
+ if detector is None:
70
+ load_model() # Ensure the model is loaded
71
+
72
+ bboxes, kpss = detector.detect(img)
73
+ if bboxes.shape[0] == 0:
74
+ return [rm_background(image)]
75
+
76
+ height, width, _ = img.shape
77
+ bboxes = np.round(bboxes[:, :4]).astype(int)
78
+ bboxes[:, 0] = np.clip(bboxes[:, 0], 0, width)
79
+ bboxes[:, 1] = np.clip(bboxes[:, 1], 0, height)
80
+ bboxes[:, 2] = np.clip(bboxes[:, 2], 0, width)
81
+ bboxes[:, 3] = np.clip(bboxes[:, 3], 0, height)
82
+
83
+ all_segmented_images = []
84
+ for i in range(bboxes.shape[0]):
85
+ bbox = bboxes[i]
86
+ x1, y1, x2, y2 = bbox
87
+ person_img = img[y1:y2, x1:x2]
88
+ pil_img = Image.fromarray(person_img[:, :, ::-1])
89
+
90
+ img_rm_background = rm_background(pil_img)
91
+ segmented_result = segment_clothing(img_rm_background, clothes)
92
+ all_segmented_images.extend(segmented_result)
93
+
94
+ return all_segmented_images
95
+
96
+ def process_image(input_image):
97
+ try:
98
+ clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
99
+ results = detect_and_segment_persons(input_image, clothes)
100
+ return results
101
+ except Exception as e:
102
+ return f"Error occurred: {e}"
103
+
104
+ # Gradio Interface
105
+ def gradio_interface(image):
106
+ results = process_image(image)
107
+ if isinstance(results, list):
108
+ return results
109
+ else:
110
+ return "Error: " + results
111
+
112
+ # Create Gradio app
113
+ interface = gr.Interface(
114
+ fn=gradio_interface,
115
+ inputs=gr.Image(type="pil"),
116
+ outputs=gr.Gallery(label="Segmented Results"),
117
+ title="Clothing Segmentation API"
118
+ )
119
+
120
+ interface.launch(server_name="0.0.0.0", server_port=7860)