zedwone commited on
Commit
9cb3e33
·
verified ·
1 Parent(s): 9f0f2cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -4,9 +4,16 @@ import cv2
4
  import gradio as gr
5
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
6
  from PIL import Image
 
 
 
 
 
 
 
7
 
8
  # 加载 Segment Anything 模型
9
- sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to("cuda")
10
  mask_generator = SamAutomaticMaskGenerator(sam)
11
 
12
  def segment(image):
@@ -18,4 +25,4 @@ def segment(image):
18
 
19
  # Gradio API
20
  demo = gr.Interface(fn=segment, inputs="image", outputs="image")
21
- demo.launch()
 
4
  import gradio as gr
5
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
6
  from PIL import Image
7
+ from huggingface_hub import hf_hub_download # 导入 Hugging Face Hub 下载工具
8
+
9
+ # 下载模型文件
10
+ chkpt_path = hf_hub_download(
11
+ repo_id="ybelkada/segment-anything", # 使用 ybelkada 的仓库
12
+ filename="checkpoints/sam_vit_b_01ec64.pth" # 下载 sam_vit_b_01ec64.pth 模型
13
+ )
14
 
15
  # 加载 Segment Anything 模型
16
+ sam = sam_model_registry["vit_b"](checkpoint=chkpt_path).to("cuda") # 注意模型类型改为 vit_b
17
  mask_generator = SamAutomaticMaskGenerator(sam)
18
 
19
  def segment(image):
 
25
 
26
  # Gradio API
27
  demo = gr.Interface(fn=segment, inputs="image", outputs="image")
28
+ demo.launch()