OttoYu committed on
Commit
28ea5ec
·
verified ·
1 Parent(s): e5cb2e8

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +101 -0
  2. requirements.txt +72 -0
main.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from PIL import Image
4
+ from transformers import pipeline
5
+ import torch
6
+ import tifffile
7
+ import gradio as gr
8
+ import os
9
+
10
# Step 1: Setup
# Pick the pipeline device index: 0 when CUDA is available, -1 otherwise.
print("Step 1: Setting up the environment...")
if torch.cuda.is_available():
    device = 0
else:
    device = -1
print(f" > Device selected: {'GPU' if device == 0 else 'CPU'}")

# Step 2: Load SAM Model
# Built once at import time and reused by segment_image for every request.
print("Step 2: Loading SAM Model...")
generator = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-huge",
    device=device,
)
print(" > SAM Model loaded successfully.")
19
+
20
+
21
def segment_image(image):
    """Segment an image with SAM and save the result as a labeled TIFF mask.

    The image is converted to RGB, downscaled to one third of its size and
    passed to the module-level ``generator`` pipeline. Every returned mask is
    painted into a single label image (label ``i + 1`` for mask ``i``; later
    masks overwrite earlier ones where they overlap, 0 is background). The
    label image is written to ``labeled_mask.tif`` and a side-by-side preview
    (original vs. coloured overlay) is written to ``segmented_overlay.png``.

    Args:
        image: A PIL image (the Gradio input is declared with ``type="pil"``).

    Returns:
        The path of the saved TIFF label mask, as a string.
    """
    print("Step 3: Starting image segmentation...")

    # Resize Image
    print(" > Resizing image...")
    raw_image = image.convert("RGB")
    original_size = raw_image.size
    # Clamp to at least 1 px: a plain `// 3` yields a zero dimension for
    # images smaller than 3 px, which makes the resize fail.
    resized_size = (
        max(1, original_size[0] // 3),
        max(1, original_size[1] // 3),
    )
    raw_image = raw_image.resize(resized_size)
    print(f" > Original size: {original_size}, Resized size: {resized_size}")

    # Run SAM Segmentation
    print(" > Running SAM segmentation...")
    outputs = generator(raw_image, points_per_batch=64)
    masks = outputs["masks"]
    print(f" > {len(masks)} masks generated.")

    # Create Labeled Mask
    print(" > Creating labeled mask...")
    if len(masks) > 0:
        h, w = np.shape(masks[0])
    else:
        # No segments found: fall back to the resized image dimensions
        # (PIL sizes are (width, height)) instead of crashing on masks[0].
        w, h = resized_size
    # uint16 can hold 65535 labels; widen only when that would overflow.
    label_dtype = np.uint16 if len(masks) <= np.iinfo(np.uint16).max else np.uint32
    labeled_mask = np.zeros((h, w), dtype=label_dtype)
    for label, mask in enumerate(masks, start=1):
        labeled_mask[np.asarray(mask, dtype=bool)] = label
    print(" > Labeled mask created.")

    # Generate Overlay
    print(" > Generating overlay...")
    # A private RandomState seeded with 42 produces the same colour stream as
    # np.random.seed(42) without resetting the process-wide RNG on each call.
    rng = np.random.RandomState(42)
    present = np.unique(labeled_mask)
    present = present[present != 0]
    # One RGBA row per label value; row 0 (background) stays fully transparent.
    palette = np.zeros((int(labeled_mask.max()) + 1, 4))
    palette[present, :3] = rng.rand(present.size, 3)
    palette[present, 3] = 0.5
    overlay = palette[labeled_mask]  # (h, w, 4) RGBA via a single lookup
    print(" > Overlay generated.")

    # Save the labeled mask as TIFF
    output_path = "labeled_mask.tif"
    print(" > Saving labeled mask as TIFF...")
    tifffile.imwrite(output_path, labeled_mask)
    print(f" > Mask saved to: {output_path}")

    # Plotting results
    print("Step 4: Plotting results...")
    fig, (ax_original, ax_overlay) = plt.subplots(1, 2, figsize=(15, 5))
    try:
        # Original Image
        ax_original.imshow(image)
        ax_original.set_title("Original Image")
        ax_original.axis("off")

        # Segmented Overlay
        ax_overlay.imshow(raw_image)
        ax_overlay.imshow(overlay)
        ax_overlay.set_title("Segmented Overlay")
        ax_overlay.axis("off")

        fig.tight_layout()
        fig.savefig("segmented_overlay.png")  # Save the overlay plot
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(" > Results plotted.")

    return output_path  # Return path to the saved mask
86
+
87
+
88
# Step 5: Gradio Interface
# Wire segment_image to a PIL image input and a downloadable file output.
print("Step 5: Setting up Gradio interface...")
iface = gr.Interface(
    segment_image,
    gr.Image(type="pil"),
    gr.File(label="Download Mask"),
    title="Image Segmentation with SAM",
    description="Upload an image to segment it and visualize the results.",
)

# Step 6: Launch the interface
print("Step 6: Launching the interface...")
iface.launch()
# NOTE(review): launch() may block until the server stops when run as a
# script, in which case this message only appears at shutdown; confirm.
print(" > Interface launched successfully.")
requirements.txt ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ annotated-types==0.7.0
3
+ anyio==4.9.0
4
+ certifi==2025.6.15
5
+ charset-normalizer==3.4.2
6
+ click==8.2.1
7
+ contourpy==1.3.2
8
+ cycler==0.12.1
9
+ exceptiongroup==1.3.0
10
+ fastapi==0.115.12
11
+ ffmpy==0.6.0
12
+ filelock==3.18.0
13
+ fonttools==4.58.4
14
+ fsspec==2025.5.1
15
+ gradio==5.34.0
16
+ gradio_client==1.10.3
17
+ groovy==0.1.2
18
+ h11==0.16.0
19
+ hf-xet==1.1.4
20
+ httpcore==1.0.9
21
+ httpx==0.28.1
22
+ huggingface-hub==0.33.0
23
+ idna==3.10
24
+ Jinja2==3.1.6
25
+ kiwisolver==1.4.8
26
+ laspy==2.5.4
27
+ markdown-it-py==3.0.0
28
+ MarkupSafe==3.0.2
29
+ matplotlib==3.10.3
30
+ mdurl==0.1.2
31
+ mpmath==1.3.0
32
+ networkx==3.4.2
33
+ numpy==2.2.6
34
+ orjson==3.10.18
35
+ packaging==25.0
36
+ pandas==2.3.0
37
+ pillow==11.2.1
38
+ pydantic==2.11.7
39
+ pydantic_core==2.33.2
40
+ pydub==0.25.1
41
+ Pygments==2.19.1
42
+ pyparsing==3.2.3
43
+ python-dateutil==2.9.0.post0
44
+ python-multipart==0.0.20
45
+ pytz==2025.2
46
+ PyYAML==6.0.2
47
+ regex==2024.11.6
48
+ requests==2.32.4
49
+ rich==14.0.0
50
+ ruff==0.11.13
51
+ safehttpx==0.1.6
52
+ safetensors==0.5.3
53
+ semantic-version==2.10.0
54
+ shellingham==1.5.4
55
+ six==1.17.0
56
+ sniffio==1.3.1
57
+ starlette==0.46.2
58
+ sympy==1.14.0
59
+ tifffile==2025.5.10
60
+ tokenizers==0.21.1
61
+ tomlkit==0.13.3
62
+ torch==2.7.1
63
+ torchvision==0.22.1
64
+ tqdm==4.67.1
65
+ transformers==4.52.4
66
+ typer==0.16.0
67
+ typing-inspection==0.4.1
68
+ typing_extensions==4.14.0
69
+ tzdata==2025.2
70
+ urllib3==2.4.0
71
+ uvicorn==0.34.3
72
+ websockets==15.0.1