Update app.py
Browse files
app.py
CHANGED
@@ -9,8 +9,19 @@ from torch.utils.data import Dataset, DataLoader
|
|
9 |
import streamlit as st
|
10 |
import gradio as gr
|
11 |
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
# Function to extract zip files
|
16 |
def extract_zip(zip_file, extract_to):
|
|
|
9 |
import streamlit as st
|
10 |
import gradio as gr
|
11 |
|
12 |
+
feature_extractor = SegformerFeatureExtractor.from_pretrained('nvidia/segformer-b0-finetuned-ade-512-512')
|
13 |
+
segformer_model = SegformerForSemanticSegmentation.from_pretrained('nvidia/segformer-b0-finetuned-ade-512-512')
|
14 |
+
|
15 |
+
# Inference function to use the Segformer model
|
16 |
+
def segment_image(image):
|
17 |
+
inputs = feature_extractor(images=image, return_tensors="pt")
|
18 |
+
outputs = segformer_model(**inputs)
|
19 |
+
segmentation = outputs.logits.argmax(dim=1).squeeze().cpu().numpy()
|
20 |
+
return segmentation
|
21 |
+
|
22 |
+
# Use Gradio interface to display the segmentation result
|
23 |
+
iface = gr.Interface(fn=segment_image, inputs=gr.Image(type="pil"), outputs="image")
|
24 |
+
iface.launch() # Launch Gradio UI
|
25 |
|
26 |
# Function to extract zip files
|
27 |
def extract_zip(zip_file, extract_to):
|