sagar007 commited on
Commit
2af5758
·
verified ·
1 Parent(s): 73a8e2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -14
app.py CHANGED
@@ -8,37 +8,37 @@ from transformers import AutoProcessor, CLIPSegForImageSegmentation
8
  processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
9
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
10
 
 
 
 
 
11
  def segment_everything(image):
12
- # Check if image is a list and extract the actual image data
13
  if isinstance(image, list):
14
  image = image[0]
15
 
16
- # Convert numpy array to PIL Image
17
  if isinstance(image, np.ndarray):
18
  image = Image.fromarray(image)
19
 
20
- inputs = processor(text=["object"], images=[image], padding="max_length", return_tensors="pt")
21
  with torch.no_grad():
22
  outputs = model(**inputs)
23
- preds = outputs.logits.squeeze().sigmoid()
24
  segmentation = (preds.numpy() * 255).astype(np.uint8)
25
  return Image.fromarray(segmentation)
26
 
27
  def segment_box(image, x1, y1, x2, y2):
28
- # Check if image is a list and extract the actual image data
29
  if isinstance(image, list):
30
  image = image[0]
31
 
32
- # Convert PIL Image to numpy array if necessary
33
  if isinstance(image, Image.Image):
34
  image = np.array(image)
35
 
36
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
37
  cropped_image = image[y1:y2, x1:x2]
38
- inputs = processor(text=["object"], images=[Image.fromarray(cropped_image)], padding="max_length", return_tensors="pt")
39
  with torch.no_grad():
40
  outputs = model(**inputs)
41
- preds = outputs.logits.squeeze().sigmoid()
42
  segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
43
  segmentation[y1:y2, x1:x2] = (preds.numpy() * 255).astype(np.uint8)
44
  return Image.fromarray(segmentation)
@@ -47,24 +47,19 @@ def update_image(image, segmentation):
47
  if segmentation is None:
48
  return image
49
 
50
- # Check if image is a list and extract the actual image data
51
  if isinstance(image, list):
52
  image = image[0]
53
 
54
- # Ensure image is in the correct format (PIL Image)
55
  if isinstance(image, np.ndarray):
56
  image_pil = Image.fromarray(image)
57
  else:
58
  image_pil = image
59
 
60
- # Convert segmentation to RGBA
61
  seg_pil = Image.fromarray(segmentation).convert('RGBA')
62
 
63
- # Resize segmentation to match input image if necessary
64
  if image_pil.size != seg_pil.size:
65
  seg_pil = seg_pil.resize(image_pil.size, Image.NEAREST)
66
 
67
- # Blend images
68
  blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
69
 
70
  return np.array(blended)
@@ -101,4 +96,4 @@ with gr.Blocks() as demo:
101
  outputs=[output_image]
102
  )
103
 
104
- demo.launch()
 
8
  processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
9
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
10
 
11
+ # Ensure that the model uses GPU if available
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ model.to(device)
14
+
15
  def segment_everything(image):
 
16
  if isinstance(image, list):
17
  image = image[0]
18
 
 
19
  if isinstance(image, np.ndarray):
20
  image = Image.fromarray(image)
21
 
22
+ inputs = processor(text=["object"], images=[image], padding="max_length", return_tensors="pt").to(device)
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
+ preds = outputs.logits.squeeze().sigmoid().cpu()
26
  segmentation = (preds.numpy() * 255).astype(np.uint8)
27
  return Image.fromarray(segmentation)
28
 
29
  def segment_box(image, x1, y1, x2, y2):
 
30
  if isinstance(image, list):
31
  image = image[0]
32
 
 
33
  if isinstance(image, Image.Image):
34
  image = np.array(image)
35
 
36
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
37
  cropped_image = image[y1:y2, x1:x2]
38
+ inputs = processor(text=["object"], images=[Image.fromarray(cropped_image)], padding="max_length", return_tensors="pt").to(device)
39
  with torch.no_grad():
40
  outputs = model(**inputs)
41
+ preds = outputs.logits.squeeze().sigmoid().cpu()
42
  segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
43
  segmentation[y1:y2, x1:x2] = (preds.numpy() * 255).astype(np.uint8)
44
  return Image.fromarray(segmentation)
 
47
  if segmentation is None:
48
  return image
49
 
 
50
  if isinstance(image, list):
51
  image = image[0]
52
 
 
53
  if isinstance(image, np.ndarray):
54
  image_pil = Image.fromarray(image)
55
  else:
56
  image_pil = image
57
 
 
58
  seg_pil = Image.fromarray(segmentation).convert('RGBA')
59
 
 
60
  if image_pil.size != seg_pil.size:
61
  seg_pil = seg_pil.resize(image_pil.size, Image.NEAREST)
62
 
 
63
  blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
64
 
65
  return np.array(blended)
 
96
  outputs=[output_image]
97
  )
98
 
99
+ demo.launch()