till-onethousand commited on
Commit
7a1b8c6
·
1 Parent(s): 9c855c1

trained model

Browse files
Files changed (2) hide show
  1. app.py +19 -2
  2. config.py +3 -1
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import ViTForImageClassification
4
- from config import UNTRAINED, labels
5
  from utils import predict
6
 
7
 
@@ -12,6 +12,13 @@ model_untrained = ViTForImageClassification.from_pretrained(
12
  label2id={c: str(i) for i, c in enumerate(labels)},
13
  )
14
 
 
 
 
 
 
 
 
15
  st.title("Detect Hurricane Damage")
16
 
17
  col1, col2 = st.columns(2)
@@ -22,5 +29,15 @@ with col1:
22
  if file_name is not None:
23
  image = Image.open(file_name)
24
  col1.image(image, use_container_width=True)
25
- label = predict
 
 
 
 
 
 
 
 
 
 
26
  st.write(f"Predicted label: {label}")
 
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import ViTForImageClassification
4
+ from config import UNTRAINED, labels, TRAINED
5
  from utils import predict
6
 
7
 
 
12
  label2id={c: str(i) for i, c in enumerate(labels)},
13
  )
14
 
15
+ model_trained = ViTForImageClassification.from_pretrained(
16
+ TRAINED,
17
+ num_labels=len(labels),
18
+ id2label={str(i): c for i, c in enumerate(labels)},
19
+ label2id={c: str(i) for i, c in enumerate(labels)},
20
+ )
21
+
22
  st.title("Detect Hurricane Damage")
23
 
24
  col1, col2 = st.columns(2)
 
29
  if file_name is not None:
30
  image = Image.open(file_name)
31
  col1.image(image, use_container_width=True)
32
+ label = predict(model_untrained, image)
33
+ st.write(f"Predicted label: {label}")
34
+
35
+ with col2:
36
+ st.markdown("## Fine-Tuned Model")
37
+ file_name = st.file_uploader("Upload a satellite image")
38
+
39
+ if file_name is not None:
40
+ image = Image.open(file_name)
41
+ col2.image(image, use_container_width=True)
42
+ label = predict(model_trained, image)
43
  st.write(f"Predicted label: {label}")
config.py CHANGED
@@ -4,4 +4,6 @@ dataset_name = "jonathan-roberts1/Satellite-Images-of-Hurricane-Damage"
4
  ds = load_dataset(dataset_name)
5
  labels = ds['train'].features['label'].names
6
 
7
- UNTRAINED = 'google/vit-base-patch16-224-in21k'
 
 
 
4
  ds = load_dataset(dataset_name)
5
  labels = ds['train'].features['label'].names
6
 
7
+ UNTRAINED = 'google/vit-base-patch16-224-in21k'
8
+
9
+ TRAINED = '"till-onethousand/hurricane_model"'