import gradio as gr
import predict as predict
import extract_gradcam as extract_gradcam
import extract_lime as extract_lime
def upload_file(files):
    """Collect the ``.name`` of every uploaded file object.

    Args:
        files: Iterable of uploaded file objects, each exposing ``.name``.

    Returns:
        A list of the ``.name`` values, in the same order as ``files``.
    """
    paths = []
    for uploaded in files:
        paths.append(uploaded.name)
    return paths
def process_file(
    upload_filepath,
    gradcam_toggle,
    lime_toggle,
):
    """Classify an uploaded image and optionally render GradCAM / LIME images.

    Args:
        upload_filepath: Path of the uploaded image, forwarded to
            ``predict.predict_image`` and the explainer helpers.
        gradcam_toggle: Truthy to generate a GradCAM image via
            ``extract_gradcam.extract_gradcam``.
        lime_toggle: Truthy to generate a LIME image via
            ``extract_lime.generate_lime``.

    Returns:
        A list of exactly 6 items. The first 4 are ``"label: prob%"``
        strings for the top predictions, in the order returned by
        ``predict_image``; they are padded with ``None`` when fewer classes
        come back, so the list always has the same length. Item 5 is the
        GradCAM image path or ``None``, and item 6 is the LIME image path
        or ``None``.
    """
    from itertools import islice

    print("Upload filepath:", upload_filepath)
    print("GradCAM toggle:", gradcam_toggle)
    print("LIME toggle:", lime_toggle)

    # Number of prediction slots. Every call returns this many label entries,
    # so the two image outputs always land in the same positions.
    top_k = 4

    sorted_classes = predict.predict_image(upload_filepath)
    # class_prob presumably is a 0-d tensor (it exposes .item()).
    # NOTE(review): confirm predict_image reports probabilities on a 0-100
    # scale; if they are 0-1 fractions, the "%" suffix is misleading.
    result = [
        f"{class_label}: {round(class_prob.item(), 2)}%"
        for class_label, class_prob in islice(sorted_classes, top_k)
    ]
    # Pad short predictions so the image paths below keep their fixed
    # positions. Assumes the label outputs accept None.
    result.extend([None] * (top_k - len(result)))

    # NOTE(review): the explainers write to fixed filenames in the working
    # directory; concurrent requests would overwrite each other's images.
    if gradcam_toggle:
        extract_gradcam.extract_gradcam(upload_filepath, save_path="gradcam.jpg")
        result.append("gradcam.jpg")
    else:
        result.append(None)

    if lime_toggle:
        extract_lime.generate_lime(upload_filepath, save_path="lime.jpg")
        result.append("lime.jpg")
    else:
        result.append(None)

    return result
# Custom CSS passed to gr.Blocks below. It does the following:
# - Gives `.block` full width with auto horizontal margins.
# - Pins the #image_input preview and its <img> to 300x300 px.
# - Sets `.output-image` to 70% width with 300 px images.
# - Sets `.toggle` elements to 17% width.
# - Hides the `.show-api` and `.built-with` elements (presumably Gradio's
#   "Use via API" / "Built with Gradio" footer links).
# - Styles the #title-label / #desc-label headings.
# NOTE(review): `.output-class.svelte-75gm11...` targets a Svelte-generated
# class hash tied to one Gradio build; it will silently stop matching after
# a Gradio upgrade. Re-check on version bumps.
# NOTE(review): `text-align: -webkit-center` is a non-standard,
# WebKit/Blink-only value; other browser engines ignore it.
css = """
.block {
margin-left: auto;
margin-right: auto;
width: 100%;
}
#image_input {
width: 300px !important;
height: 300px !important;
}
#image_input img {
width: 300px !important;
height: 300px !important;
}
.output-image {
width: 70% !important;
text-align: -webkit-center !important;
}
.output-image img {
width: 300px !important;
}
.toggle {
width: 17% !important;
}
.show-api {
visibility: hidden !important;
}
.built-with {
visibility: hidden !important;
}
#title-label {
font-size: 35px !important;
text-align: -webkit-center !important;
margin-block-end: -55px;
}
#desc-label {
font-size: 15px !important;
text-align: -webkit-center !important;
}
.output-class.svelte-75gm11.svelte-75gm11.svelte-75gm11 {
font-size: unset !important;
}
"""
# Top-level Blocks container for the app. Its configuration is:
# - Browser-tab title "NeuroDetect".
# - The custom CSS above.
# - The "gradio/soft" theme.
# NOTE(review): Gradio treats a theme string that is not one of its built-in
# names (e.g. "soft") as a Hugging Face Hub theme id and fetches it at
# startup. Confirm network access is available, or switch to "soft".
# NOTE(review): the tab title "NeuroDetect" differs from the on-page
# "Spiral-ND" title label. Confirm which name is intended.
block = gr.Blocks(title="NeuroDetect", css=css, theme="gradio/soft")
# Enable Gradio's request queue for this app's events.
block.queue()
with block as demo:
with gr.Column():
gr.Label("Spiral-ND", elem_id="title-label", show_label=False)
# gr.Label(
# "A Stress-free, Portable, and Cost-effective Machine Learning-Powered Web Application for Early Detection of Multiple Neurological Disorders through Spiral Drawing Analysis",
# elem_id="desc-label",
# show_label=False,
# )
gr.Markdown("
")
with gr.Row():
image_input = gr.Image(
type="filepath",
label="Choose Image",
sources="upload",
elem_id="image_input",
)
with gr.Column():
gr.Markdown(
"""