# Spaces:
# Sleeping
# Sleeping
import logging
import os
import random
import tempfile
from zipfile import BadZipFile, ZipFile

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from openai import AzureOpenAI
from PIL import Image
from tensorflow import keras
from tensorflow.keras.preprocessing import image as keras_image
# Configure the root logger once at import time; the app reports through the
# module-level logging.* helpers, so records at INFO and above are emitted.
logging.basicConfig(level="INFO")
class DiseaseDetectionApp:
    """Gradio app that labels a chest X-ray as ``Normal`` or ``Pneumonia``.

    A Keras model classifies the uploaded image; for any label other than
    ``Normal`` an Azure OpenAI chat completion supplies a short description.
    """

    # Lower-case file extensions accepted as example images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp')
    # Size passed to load_img(target_size=...). NOTE(review): must match the
    # saved model's input layer -- confirm against the model file.
    TARGET_SIZE = (256, 256)
    # Text shown when no label could be produced (no upload / unreadable image).
    PREDICTION_FAILED_MESSAGE = (
        "Could not analyse the image. Please upload a valid chest X-ray image."
    )
    # Text shown when the description request to the chat model fails.
    DESCRIPTION_FAILED_MESSAGE = (
        "Could not generate a description right now. Please try again later."
    )

    def __init__(self):
        """Load the classifier and create the Azure OpenAI client."""
        # Position i of the model's output vector is read as class_names[i].
        self.class_names = ['Normal', 'Pneumonia']
        self.model = tf.keras.models.load_model("pneumonia_xray_prediction.keras")
        # Constructed without arguments: presumably endpoint/key/API version
        # come from environment variables (e.g. Space secrets) -- confirm.
        self.client = AzureOpenAI()

    def predict_disease(self, image_path):
        """Predict which class the X-ray image at ``image_path`` belongs to.

        Args:
            image_path: Path to an image file; ``None``/empty when nothing
                was uploaded.

        Returns:
            One of ``self.class_names``, or ``None`` if no image was given or
            it could not be loaded/classified (the error is logged).
        """
        if not image_path:
            logging.warning("No image supplied for prediction")
            return None
        try:
            img = keras_image.load_img(image_path, target_size=self.TARGET_SIZE)
            # The model takes a batch: (height, width, channels) -> (1, height, width, channels).
            batch = np.expand_dims(keras_image.img_to_array(img), axis=0)
            # verbose=0 keeps Keras' per-call progress bar out of the server log.
            scores = self.model.predict(batch, verbose=0)[0]
            # NOTE(review): argmax assumes one score per entry of class_names
            # (e.g. a softmax head); a single sigmoid output would always give index 0.
            best = int(np.argmax(scores))
            predicted = self.class_names[best]
            confidence = round(100 * float(scores[best]), 2)
        except Exception:
            # Request boundary: log with traceback and report failure as None so
            # the UI can show a message instead of erroring out.
            logging.exception("Error predicting disease for %s", image_path)
            return None
        logging.info("Predicted %s (%.2f%% confidence)", predicted, confidence)
        return predicted

    def classify_disease(self, image_path):
        """Classify an X-ray and build the two values shown in the UI.

        Args:
            image_path: Path to the uploaded image (``None`` if nothing was uploaded).

        Returns:
            ``(label, text)`` for the ``Label`` and ``Textbox`` outputs. ``label``
            is ``None`` when prediction failed. A tuple is always returned so
            both outputs get a value.
        """
        disease_name = self.predict_disease(image_path)
        if disease_name is None:
            return None, self.PREDICTION_FAILED_MESSAGE
        if disease_name == "Normal":
            return disease_name, "No problem in your xray image"
        return disease_name, self._describe_disease(disease_name)

    def _describe_disease(self, disease_name):
        """Ask the chat model for a short summary of ``disease_name``.

        Returns the model's reply, or ``DESCRIPTION_FAILED_MESSAGE`` if the
        request fails (the error is logged).
        """
        prompt = (
            " your task describe(classify) about the given disease as a summary only in 5 lines.\n"
            f"```{disease_name}```\n"
        )
        conversation = [
            {"role": "system", "content": "You are a medical assistant"},
            {"role": "user", "content": prompt},
        ]
        try:
            response = self.client.chat.completions.create(
                # With AzureOpenAI this names the deployment -- confirm it
                # matches the Azure resource.
                model="GPT-4o",
                messages=conversation,
                temperature=0.4,
                max_tokens=1000,
            )
            return response.choices[0].message.content
        except Exception:
            # Network/API failures should not crash the request handler.
            logging.exception("Error generating description for %s", disease_name)
            return self.DESCRIPTION_FAILED_MESSAGE

    def unzip_image_data(self, filespath):
        """Extract a zip archive into a fresh directory under the working directory.

        Args:
            filespath: Path to the ``.zip`` archive.

        Returns:
            str: Path of the directory holding the extracted files, or ``""``
            if the archive is missing, corrupt, or could not be extracted.
        """
        stem = os.path.splitext(os.path.basename(filespath))[0]
        try:
            with ZipFile(filespath, "r") as archive:
                # mkdtemp yields a unique directory, so two archives can never
                # share (and mix) one extraction folder. It is created only after
                # the archive opened successfully. It stays under the working
                # directory like before, presumably so Gradio may serve the files.
                target_dir = tempfile.mkdtemp(prefix=f"{stem}_", dir=".")
                archive.extractall(target_dir)
                return target_dir
        except (OSError, BadZipFile) as e:
            logging.error(f"An error occurred during extraction: {e}")
            return ""

    def example_images(self, filespath):
        """Extract an example archive and list the image files it contains.

        Args:
            filespath: Path to a ``.zip`` archive of example images.

        Returns:
            List[str]: Sorted paths of every file (at any depth) whose extension
            is in ``IMAGE_EXTENSIONS``, each listed once; ``[]`` if extraction failed.
        """
        folder = self.unzip_image_data(filespath)
        if not folder:
            return []
        images = []
        for root, _dirs, files in os.walk(folder):
            images.extend(
                os.path.join(root, name)
                for name in files
                # Hidden files (e.g. macOS "._*" metadata) are not real images.
                if not name.startswith(".")
                and os.path.splitext(name)[1].lower() in self.IMAGE_EXTENSIONS
            )
        # Directory listing order is arbitrary; sort for a stable example gallery.
        return sorted(images)

    def get_example_image(self):
        """Return ``(normal_examples, pneumonia_examples)`` lists of image paths."""
        normal_examples = self.example_images("Normal_dataset.zip")
        # The "Pnemonia" spelling presumably matches the archive shipped with the app.
        pneumonia_examples = self.example_images("Pnemonia_dataset.zip")
        return normal_examples, pneumonia_examples

    def gradio_interface(self):
        """Build the Gradio UI and launch it with ``debug=True``.

        Layout: an upload widget next to the predicted label and description,
        a detect button, and example galleries for each class. Returns ``None``.
        """
        normal_examples, pneumonia_examples = self.get_example_image()
        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            gr.HTML("""<center><h1>Pneumonia Disease Detection</h1></center>""")
            with gr.Row():
                input_image = gr.Image(type="filepath", sources="upload")
                with gr.Column():
                    output = gr.Label(label="Disease Name")
                    with gr.Row():
                        classify_disease_ = gr.Textbox(label="About disease")
            with gr.Row():
                button = gr.Button(value="Detect The Disease")
            button.click(self.classify_disease, [input_image], [output, classify_disease_])
            # Skip a gallery whose archive failed to extract or held no images.
            for examples, gallery_label in (
                (normal_examples, "Normal X-ray Images"),
                (pneumonia_examples, "Pneumonia X-ray Images"),
            ):
                if examples:
                    gr.Examples(
                        examples=examples,
                        label=gallery_label,
                        inputs=[input_image],
                        outputs=[output, classify_disease_],
                        fn=self.classify_disease,
                        examples_per_page=5,
                        cache_examples=False,
                    )
        demo.launch(debug=True)
def main():
    """Create the detection app and start its Gradio interface."""
    app = DiseaseDetectionApp()
    # gradio_interface() returns nothing, so there is no result to keep.
    app.gradio_interface()


if __name__ == "__main__":
    main()