Tanusree88 commited on
Commit
6df7be1
·
verified ·
1 Parent(s): 6beb603

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -9,8 +9,19 @@ from torch.utils.data import Dataset, DataLoader
9
  import streamlit as st
10
  import gradio as gr
11
 
12
- # Load the Segformer model using Gradio (Optional)
13
- gr.load("models/nvidia/segformer-b0-finetuned-ade-512-512").launch()
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Function to extract zip files
16
  def extract_zip(zip_file, extract_to):
 
9
  import streamlit as st
10
  import gradio as gr
11
 
12
+ feature_extractor = SegformerFeatureExtractor.from_pretrained('nvidia/segformer-b0-finetuned-ade-512-512')
13
+ segformer_model = SegformerForSemanticSegmentation.from_pretrained('nvidia/segformer-b0-finetuned-ade-512-512')
14
+
15
+ # Inference function to use the Segformer model
16
+ def segment_image(image):
17
+ inputs = feature_extractor(images=image, return_tensors="pt")
18
+ outputs = segformer_model(**inputs)
19
+ segmentation = outputs.logits.argmax(dim=1).squeeze().cpu().numpy()
20
+ return segmentation
21
+
22
+ # Use Gradio interface to display the segmentation result
23
+ iface = gr.Interface(fn=segment_image, inputs=gr.Image(type="pil"), outputs="image")
24
+ iface.launch() # Launch Gradio UI
25
 
26
  # Function to extract zip files
27
  def extract_zip(zip_file, extract_to):