File size: 5,514 Bytes
f67b8d5
 
 
 
 
 
 
 
 
 
36f1223
f67b8d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36f1223
f67b8d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177

import torch
from PIL import Image
from torchvision import transforms
from clipseg import CLIPDensePredT


# Preprocessing applied to every input image before it reaches the model:
#   1. convert the PIL image to a float tensor,
#   2. normalize each of the three channels (these values match the standard
#      ImageNet channel statistics),
#   3. resize to the fixed 352x352 resolution used throughout this module.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        transforms.Resize(size=(352, 352), antialias=True),
    ]
)


# Build the segmentation network once at import time and switch it to
# evaluation mode.
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()

# Load the checkpoint onto the CPU. strict=False makes load_state_dict
# tolerate keys that are missing from, or unexpected in, the checkpoint.
# NOTE(review): torch.load unpickles the file; only load trusted weights.
model.load_state_dict(
    torch.load(
        'weights/rd64-uni.pth',
        map_location=torch.device('cpu'),
    ),
    strict=False,
)


def predict(image, prompts):
    """
    Predict one segmentation mask per text prompt for the given image.

    Parameters:
    - image (PIL.Image): The input image. It is converted to RGB before
      preprocessing, so non-RGB modes (e.g. RGBA or grayscale) are accepted.
    - prompts (str): A comma-separated string of prompts. Whitespace around
      each prompt is stripped and empty entries are ignored.

    Returns:
    - tuple: A tuple containing the input image resized to 352x352 and a list
      of (mask, prompt) pairs, one per non-empty prompt in input order. Each
      mask is a NumPy array of sigmoid-activated model outputs (values in
      [0, 1]). The list is empty when no usable prompt was given.
    """

    # Split the prompts string into clean labels: 'deer, tree' must yield
    # 'tree' rather than ' tree', and blank entries (trailing comma, empty
    # input) must not be sent to the model as prompts.
    labels = [part.strip() for part in (prompts or '').split(',')]
    labels = [label for label in labels if label]

    resized_image = image.resize((352, 352), Image.LANCZOS)

    # Nothing to segment: skip the model call entirely
    if not labels:
        return (resized_image, [])

    # The normalization step is configured for exactly three channels, so
    # force RGB before preprocessing.
    img = transform(image.convert('RGB')).unsqueeze(0)

    # Ensure no gradient computation during prediction for performance
    with torch.no_grad():
        # One copy of the image per prompt, evaluated as a single batch
        preds = model(img.repeat(len(labels), 1, 1, 1), labels)[0]
        # Channel 0 of every prediction, squashed to probabilities
        probabilities = torch.sigmoid(preds[:, 0])

    # Pair each mask with the prompt that produced it. squeeze(0) is kept
    # from the original code; it only acts if the leading dimension is 1.
    masks = [
        (probability.squeeze(0).numpy(), label)
        for probability, label in zip(probabilities, labels)
    ]

    # Return the resized input image and the list of segmentation masks
    return (resized_image, masks)

def get_examples():
    """
    Return the bundled demo inputs.

    Returns:
    - list: A list of [image_path, prompts] pairs, where prompts is a
      comma-separated string in the format accepted by predict().
    """
    image_prompt_pairs = (
        ('images/000013.jpg', 'deer, tree, grass'),
        ('images/000002.jpg', 'train, tracks, electric pole, house'),
        ('images/00125.jpg', 'dog, flowers'),
        ('images/000010.jpg', 'horse, man, fence, buildings, hill'),
        ('images/000004.jpg', 'car, truck, building, sky, traffic light, tree, clouds'),
    )
    # Each example is handed out as a fresh two-element list
    return [[path, prompt_text] for path, prompt_text in image_prompt_pairs]


def get_html():
    """
    Return the static HTML document used as the application header.

    Returns:
    - str: A complete HTML page (inline CSS plus a header banner with the
      app title and a short usage description).
    """
    # The markup is returned verbatim; its leading and trailing whitespace
    # is part of the string.
    return """
      <!DOCTYPE html>
      <html lang="en">

      <head>
          <meta charset="UTF-8">
          <meta name="viewport" content="width=device-width, initial-scale=1.0">
          <title>Multi-Prompt Image Segmentation</title>
          <link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@400;700&display=swap" rel="stylesheet">

          <style>
              /* General styling */
              body {
                  font-family: 'Roboto Slab', serif;
                  margin: 0;
                  padding: 0;
                  background-color: #f4f4f4;
              }

              .app-header {
                  background: linear-gradient(135deg, #4a90e2, #50e3c2);
                  color: #fff;
                  text-align: center;
                  padding: 40px 0;
                  border-radius: 20px;
                  position: relative;
                  overflow: hidden;
                  box-shadow: 0px 10px 20px rgba(0, 0, 0, 0.1);
              }

              /* Ellipse Overlay */
              .app-header::before {
                  content: "";
                  position: absolute;
                  top: -50%;
                  left: -50%;
                  width: 200%;
                  height: 200%;
                  background: rgba(255, 255, 255, 0.1);
                  transform: rotate(45deg);
                  border-radius: 50%;
              }

              /* Floating Shapes */
              .app-header::after {
                  content: "";
                  position: absolute;
                  top: 20%;
                  right: 10%;
                  width: 70px;
                  height: 70px;
                  background: rgba(255, 255, 255, 0.2);
                  border-radius: 50%;
              }

              .floating-shape {
                  content: "";
                  position: absolute;
                  top: 10%;
                  left: 5%;
                  width: 50px;
                  height: 50px;
                  background: rgba(255, 255, 255, 0.2);
                  border-radius: 50%;
              }

              /* Text Styling */
              .app-title {
                  font-size: 28px;
                  margin: 0;
                  font-weight: 700;
                  text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.2);
              }

              .app-description {
                  font-size: 18px;
                  margin-top: 15px;
                  opacity: 0.9;
                  text-shadow: 1px 1px 3px rgba(0, 0, 0, 0.1);
              }

              /* Wavy Bottom */
              .wavy-bottom {
                  position: absolute;
                  bottom: -10px;
                  left: 0;
                  width: 100%;
                  height: 20px;
                  background: #f4f4f4;
                  border-radius: 100% 100% 0 0;
              }
          </style>
      </head>

      <body>

          <!-- App Header -->
          <div class="app-header">
              <h1 class="app-title">Multi-Prompt Image Segmentation</h1>
              <p class="app-description">Upload an image and provide multiple text prompts separated by commas. Get segmented masks for each prompt.</p>
              <div class="floating-shape"></div>
              <div class="wavy-bottom"></div>
          </div>

          <!-- Rest of the app content will go here -->

      </body>

      </html>


  """