|
import gradio as gr |
|
from transformers import pipeline |
|
from PIL import Image |
|
|
|
|
|
# Image-classification pipeline for the astrophotography model. It is built
# once at module level so that every call to predict() reuses the same
# pipeline object instead of reconstructing it per request.
model_pipeline = pipeline(
    "image-classification",
    model="bortle/astrophotography-object-classifier-alpha5",
)
|
|
|
def resize_to_width(image, width=1080):
    """Return *image* scaled to *width* pixels wide, preserving aspect ratio.

    Args:
        image: A PIL-style image exposing ``width``, ``height`` and ``resize``.
        width: Target width in pixels (defaults to 1080, the app's original size).

    Returns:
        The resized image returned by ``image.resize``.

    Raises:
        ValueError: If *image* has a zero (or negative) dimension, which would
            otherwise surface as a ZeroDivisionError or a PIL resize error.
    """
    if image.width <= 0 or image.height <= 0:
        raise ValueError(
            f"cannot resize an empty image of size {image.width}x{image.height}"
        )
    ratio = width / image.width
    # Clamp to 1px: for extremely wide images the scaled height truncates to 0,
    # and PIL refuses to resize to a zero-sized target.
    height = max(1, int(image.height * ratio))
    return image.resize((width, height))


def predict(image):
    """Classify an uploaded astrophotography image.

    Args:
        image: The PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A dict mapping each predicted label to its score, the format that
        ``gr.Label`` expects.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of a generic AttributeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")
    # Normalise alpha/greyscale/palette images to 3 channels so the model's
    # image processor always receives RGB input, even when predict() is called
    # outside the Gradio component.
    if image.mode != "RGB":
        image = image.convert("RGB")
    resized_image = resize_to_width(image)
    predictions = model_pipeline(resized_image)
    return {p["label"]: p["score"] for p in predictions}
|
|
|
|
|
# Base names of the bundled example images; each resolves to examples/<name>.jpg.
EXAMPLE_NAMES = (
    "Andromeda", "Heart", "Pleiades",
    "Rosette", "Moon", "GreatHercules",
    "Leo-Triplet", "Crab", "North-America",
    "Horsehead-Flame", "Pinwheel", "Saturn",
)

# Bound to a module-level name so tooling such as Gradio's reload mode can
# find the app, and so importing this module does not start a server.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Astrophotography image"),
    # Show the five highest-scoring labels from predict()'s label->score dict.
    outputs=gr.Label(num_top_classes=5),
    title="Astrophotography Object Classifier",
    allow_flagging="manual",
    examples=[f"examples/{name}.jpg" for name in EXAMPLE_NAMES],
    # Gradio precomputes predict() on every example, so clicking one is instant.
    cache_examples=True,
)

if __name__ == "__main__":
    demo.launch()
|
|