sigyllly committed
Commit 3991df5 · verified · 1 Parent(s): f77dac9

Update app.py

Files changed (1)
  1. app.py  +12 −10
app.py CHANGED
@@ -8,11 +8,23 @@ import threading
 processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
 model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
 
+# Function to process image and generate mask
 # Function to process image and generate mask
 def process_image(image, prompt):
     inputs = processor(
         text=prompt, images=image, padding="max_length", return_tensors="pt"
     )
+
+    # Extract image tensor and normalize it
+    image_tensor = inputs["pixel_values"].squeeze().permute(1, 2, 0).cpu().numpy()
+    image_tensor = (image_tensor * 255).astype(np.uint8)
+    image_tensor = Image.fromarray(image_tensor)
+    image_tensor = image_tensor.convert("RGB")
+
+    # Perform CLIPSeg processing
+    inputs = processor(
+        text=prompt, images=image_tensor, padding="max_length", return_tensors="pt"
+    )
     with torch.no_grad():
         outputs = model(**inputs)
     preds = outputs.logits
@@ -30,16 +42,6 @@ def process_image(image, prompt):
 
     return mask
 
-# Function to get masks from positive or negative prompts
-def get_masks(prompts, img, threshold):
-    prompts = prompts.split(",")
-    masks = []
-    for prompt in prompts:
-        mask = process_image(img, prompt)
-        mask = mask > threshold
-        masks.append(mask)
-
-    return masks
 
 # Function to extract image using positive and negative prompts
 def extract_image(pos_prompts, neg_prompts, img, threshold):
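
For reference, a minimal usage sketch of the patched process_image, assuming it is exercised from within app.py (which must already import torch, numpy as np, and PIL.Image alongside the processor and model above, since the added lines rely on them). The image path, prompt, and threshold below are illustrative only and are not part of the commit:

    from PIL import Image

    # Hypothetical inputs for illustration.
    img = Image.open("example.jpg").convert("RGB")
    prompt = "a dog"
    threshold = 0.4

    # As patched, process_image preprocesses the image, scales the resulting
    # pixel_values back to a uint8 PIL image, then runs the processor and
    # CLIPSeg model on that re-encoded image before building the mask.
    mask = process_image(img, prompt)

    # Downstream code thresholds the mask, as the removed get_masks helper did.
    binary_mask = mask > threshold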