yolac commited on
Commit
8a4cd9d
·
verified ·
1 Parent(s): 616be32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -26
app.py CHANGED
@@ -1,31 +1,28 @@
1
- import gradio as gr
 
2
  import tensorflow as tf
3
- import numpy as np
4
- from tensorflow.keras.utils import load_img, img_to_array
5
 
6
- # Load the pre-trained model
7
- model = tf.keras.models.load_model('https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model')
 
 
8
 
9
- # Define a function for making predictions
10
- def predict_bacterial_morphology(image):
11
- # Preprocess the input image
12
- img_array = img_to_array(image) / 255.0
13
- img_array = np.expand_dims(img_array, axis=0)
14
- # Make a prediction
15
- prediction = model.predict(img_array, verbose=0)
16
- class_labels = ['cocci', 'bacilli', 'spirilla']
17
- predicted_label = class_labels[np.argmax(prediction)]
18
- return predicted_label
19
 
20
- # Define the Gradio interface
21
- iface = gr.Interface(
22
- fn=predict_bacterial_morphology,
23
- inputs=gr.inputs.Image(shape=(224, 224)),
24
- outputs=gr.outputs.Textbox(label="Predicted Class"),
25
- title="Bacterial Morphology Classification",
26
- description="Upload an image of a bacterium to classify it into one of the following categories: cocci, bacilli, or spirilla."
27
- )
28
 
29
- # Launch the app
30
- if __name__ == "__main__":
31
- iface.launch()
 
1
+ import requests
2
+ import tempfile
3
  import tensorflow as tf
 
 
4
 
5
+ # URL to the model file in your Hugging Face repository
6
+ url = 'https://huggingface.co/datasets/yolac/BacterialMorphologyClassification'
7
+ # Download the model file
8
+ response = requests.get(url, stream=True)
9
 
10
+ # Check if the download was successful
11
+ if response.status_code == 200:
12
+ # Create a temporary file to save the model
13
+ with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as temp_file:
14
+ for chunk in response.iter_content(chunk_size=8192):
15
+ if chunk:
16
+ temp_file.write(chunk)
17
+ temp_file_path = temp_file.name # Get the path to the temporary file
 
 
18
 
19
+ # Load the pre-trained model
20
+ try:
21
+ model = tf.keras.models.load_model(temp_file_path)
22
+ print("Model loaded successfully.")
23
+ except OSError as e:
24
+ print(f"Error loading the model: {e}")
25
+ else:
26
+ print("Failed to download the model. Status code:", response.status_code)
27
 
28
+ # Now you can use the `model` object for predictions or further processing.