zedwone commited on
Commit
af15e39
·
verified ·
1 Parent(s): 6a3024d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -1,8 +1,21 @@
 
 
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return f"Hello {name}!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
7
 
8
- demo.launch()
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
  import gradio as gr
5
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
6
+ from PIL import Image
7
 
8
+ # 加载 Segment Anything 模型
9
+ sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to("cuda")
10
+ mask_generator = SamAutomaticMaskGenerator(sam)
11
 
12
+ def segment(image):
13
+ image = np.array(image)
14
+ masks = mask_generator.generate(image)
15
+ largest_mask = max(masks, key=lambda x: x['area'])['segmentation']
16
+ binary_mask = np.where(largest_mask, 255, 0).astype(np.uint8)
17
+ return Image.fromarray(binary_mask)
18
 
19
+ # Gradio API
20
+ demo = gr.Interface(fn=segment, inputs="image", outputs="image")
21
+ demo.launch()