willco-afk commited on
Commit
9c5fbe6
·
verified ·
1 Parent(s): 48fe6c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -1,11 +1,24 @@
1
  import streamlit as st
2
- from transformers import TFAutoModelForImageClassification, AutoFeatureExtractor
3
  from PIL import Image
4
  import numpy as np
 
 
 
5
 
6
- # Load the model and feature extractor
7
- model = TFAutoModelForImageClassification.from_pretrained('/content/drive/MyDrive/my_christmas_tree_model')
8
- feature_extractor = AutoFeatureExtractor.from_pretrained(model.config._name_or_path)
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Streamlit UI
11
  st.title("Christmas Tree Classifier")
@@ -19,15 +32,17 @@ if uploaded_file is not None:
19
  st.image(image, caption="Uploaded Image.", use_column_width=True)
20
 
21
  # Preprocess the image
22
- inputs = feature_extractor(images=image, return_tensors="tf")
 
 
23
 
24
  # Make prediction
25
- logits = model(**inputs).logits
26
- predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
27
-
28
- # Map class index to label
29
- class_names = model.config.id2label # Get class names from model config
30
- predicted_class = class_names[predicted_class_idx]
31
 
32
  # Display the prediction
33
  st.write(f"Prediction: **{predicted_class}**")
 
1
  import streamlit as st
2
+ import tensorflow as tf
3
  from PIL import Image
4
  import numpy as np
5
+ import tempfile
6
+ import zipfile
7
+ import os
8
 
9
+ # Function to load model from a zip file
10
+ def load_model_from_zip(zip_file_path):
11
+ with tempfile.TemporaryDirectory() as temp_dir:
12
+ with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
13
+ zip_ref.extractall(temp_dir)
14
+
15
+ # Load the model (assuming your model is 'your_trained_model.keras' inside the zip)
16
+ model_path = os.path.join(temp_dir, 'your_trained_model.keras')
17
+ model = tf.keras.models.load_model(model_path)
18
+ return model
19
+
20
+ # Load the model from the zip file
21
+ model = load_model_from_zip('/content/drive/MyDrive/my_christmas_tree_model.zip') # Replace 'my_christmas_tree_model.zip' with the actual name of the zip file
22
 
23
  # Streamlit UI
24
  st.title("Christmas Tree Classifier")
 
32
  st.image(image, caption="Uploaded Image.", use_column_width=True)
33
 
34
  # Preprocess the image
35
+ image = image.resize((150, 150)) # Resize to match your model's input shape
36
+ image_array = np.array(image) / 255.0 # Normalize
37
+ image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
38
 
39
  # Make prediction
40
+ prediction = model.predict(image_array)
41
+
42
+ # Interpret prediction (assuming binary classification)
43
+ class_names = ['Undecorated', 'Decorated'] #Update your class names here
44
+ predicted_class_index = 1 if prediction[0][0] >= 0.5 else 0 # Adjust threshold if needed
45
+ predicted_class = class_names[predicted_class_index]
46
 
47
  # Display the prediction
48
  st.write(f"Prediction: **{predicted_class}**")