yolac commited on
Commit
eae2c3c
·
verified ·
1 Parent(s): 864e1fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -20
app.py CHANGED
@@ -1,28 +1,39 @@
 
 
1
  import requests
2
  import tempfile
3
- import tensorflow as tf
4
 
5
- # URL to the model file in your Hugging Face repository
6
- url = 'https://huggingface.co/datasets/yolac/BacterialMorphologyClassification'
7
- # Download the model file
8
- response = requests.get(url, stream=True)
9
 
10
- # Check if the download was successful
 
11
  if response.status_code == 200:
12
- # Create a temporary file to save the model
13
- with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as temp_file:
14
- for chunk in response.iter_content(chunk_size=8192):
15
- if chunk:
16
- temp_file.write(chunk)
17
- temp_file_path = temp_file.name # Get the path to the temporary file
18
-
19
- # Load the pre-trained model
20
- try:
21
  model = tf.keras.models.load_model(temp_file_path)
22
- print("Model loaded successfully.")
23
- except OSError as e:
24
- print(f"Error loading the model: {e}")
25
  else:
26
- print("Failed to download the model. Status code:", response.status_code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # Now you can use the `model` object for predictions or further processing.
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
  import requests
4
  import tempfile
 
5
 
6
+ # Define the URL for your model and dataset
7
+ model_url = 'https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/bacterial_morphology_classification_model.h5'
 
 
8
 
9
+ # Download the model from the URL and load it into a temporary file
10
+ response = requests.get(model_url, stream=True)
11
  if response.status_code == 200:
12
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.h5') as temp_file:
13
+ temp_file.write(response.content)
14
+ temp_file_path = temp_file.name
 
 
 
 
 
 
15
  model = tf.keras.models.load_model(temp_file_path)
 
 
 
16
  else:
17
+ raise Exception(f"Failed to download the model. Status code: {response.status_code}")
18
+
19
+ # Define the function for prediction
20
+ def predict_image(img):
21
+ img = img.resize((224, 224))
22
+ img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0
23
+ img_array = tf.expand_dims(img_array, axis=0)
24
+ prediction = model.predict(img_array)
25
+ classes = ['cocci', 'bacilli', 'spirilla']
26
+ predicted_class = classes[prediction.argmax()]
27
+ return predicted_class
28
+
29
+ # Create the Gradio interface
30
+ iface = gr.Interface(
31
+ fn=predict_image,
32
+ inputs=gr.inputs.Image(shape=(224, 224)),
33
+ outputs=gr.outputs.Label(num_top_classes=3),
34
+ title="Bacterial Morphology Classification",
35
+ description="Upload an image of bacterial morphology to classify it as cocci, bacilli, or spirilla."
36
+ )
37
 
38
+ # Launch the Gradio app
39
+ iface.launch()