Tanusree88 commited on
Commit
4c7d7ed
·
verified ·
1 Parent(s): 6df7be1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -3
app.py CHANGED
@@ -8,20 +8,34 @@ from PIL import Image
8
  from torch.utils.data import Dataset, DataLoader
9
  import streamlit as st
10
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
11
 
 
12
  feature_extractor = SegformerFeatureExtractor.from_pretrained('nvidia/segformer-b0-finetuned-ade-512-512')
13
  segformer_model = SegformerForSemanticSegmentation.from_pretrained('nvidia/segformer-b0-finetuned-ade-512-512')
14
 
15
- # Inference function to use the Segformer model
16
  def segment_image(image):
17
  inputs = feature_extractor(images=image, return_tensors="pt")
18
  outputs = segformer_model(**inputs)
19
  segmentation = outputs.logits.argmax(dim=1).squeeze().cpu().numpy()
20
  return segmentation
21
 
22
- # Use Gradio interface to display the segmentation result
23
  iface = gr.Interface(fn=segment_image, inputs=gr.Image(type="pil"), outputs="image")
24
- iface.launch() # Launch Gradio UI
 
 
 
25
 
26
  # Function to extract zip files
27
  def extract_zip(zip_file, extract_to):
 
8
  from torch.utils.data import Dataset, DataLoader
9
  import streamlit as st
10
  import gradio as gr
11
+ import os
12
+ import zipfile
13
+ import numpy as np
14
+ import torch
15
+ from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
16
+ from transformers import ResNetForImageClassification, AdamW
17
+ from PIL import Image
18
+ from torch.utils.data import Dataset, DataLoader
19
+ import streamlit as st
20
+ import gradio as gr
21
 
22
+ # Load feature extractor and model
23
  feature_extractor = SegformerFeatureExtractor.from_pretrained('nvidia/segformer-b0-finetuned-ade-512-512')
24
  segformer_model = SegformerForSemanticSegmentation.from_pretrained('nvidia/segformer-b0-finetuned-ade-512-512')
25
 
26
+ # Inference function for segmentation
27
  def segment_image(image):
28
  inputs = feature_extractor(images=image, return_tensors="pt")
29
  outputs = segformer_model(**inputs)
30
  segmentation = outputs.logits.argmax(dim=1).squeeze().cpu().numpy()
31
  return segmentation
32
 
33
+ # Gradio interface at the end of the script
34
  iface = gr.Interface(fn=segment_image, inputs=gr.Image(type="pil"), outputs="image")
35
+
36
+ # Specify a custom port if needed to avoid conflicts (optional)
37
+ iface.launch(server_port=7861) # Change port if 7860 is occupied
38
+
39
 
40
  # Function to extract zip files
41
  def extract_zip(zip_file, extract_to):