ppicazo commited on
Commit
03d257e
·
verified ·
1 Parent(s): 3160c4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -11
app.py CHANGED
@@ -1,27 +1,38 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
3
 
4
- pipeline = pipeline(task="image-classification", model="bortle/astrophotography-object-classifier-alpha5")
 
 
 
 
5
 
6
  def predict(image):
7
- predictions = pipeline(image)
8
- return {p["label"]: p["score"] for p in predictions}
9
-
10
-
11
- def process_image(image):
12
  width = 1080
13
  ratio = width / image.width
14
  height = int(image.height * ratio)
15
  resized_image = image.resize((width, height))
16
- return resized_image
 
 
 
 
 
17
 
 
18
  gr.Interface(
19
- predict,
20
- fn=process_image,
21
  inputs=gr.Image(type="pil", label="Upload Astrophotography image"),
22
  outputs=gr.Label(num_top_classes=5),
23
  title="Astrophotography Object Classifier",
24
  allow_flagging="manual",
25
- examples=["examples/Andromeda.jpg", "examples/Heart.jpg", "examples/Pleiades.jpg", "examples/Rosette.jpg", "examples/Moon.jpg", "examples/GreatHercules.jpg", "examples/Leo-Triplet.jpg", "examples/Crab.jpg", "examples/North-America.jpg", "examples/Horsehead-Flame.jpg", "examples/Pinwheel.jpg", "examples/Saturn.jpg"],
 
 
 
 
 
26
  cache_examples=True
27
- ).launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ from PIL import Image
4
 
5
+ # Load your model pipeline
6
+ model_pipeline = pipeline(
7
+ task="image-classification",
8
+ model="bortle/astrophotography-object-classifier-alpha5"
9
+ )
10
 
11
  def predict(image):
12
+ # Resize the image to have width 1080 while keeping aspect ratio
 
 
 
 
13
  width = 1080
14
  ratio = width / image.width
15
  height = int(image.height * ratio)
16
  resized_image = image.resize((width, height))
17
+
18
+ # Perform predictions
19
+ predictions = model_pipeline(resized_image)
20
+
21
+ # Return predictions as a dictionary
22
+ return {p["label"]: p["score"] for p in predictions}
23
 
24
+ # Define the Gradio Interface
25
  gr.Interface(
26
+ fn=predict,
 
27
  inputs=gr.Image(type="pil", label="Upload Astrophotography image"),
28
  outputs=gr.Label(num_top_classes=5),
29
  title="Astrophotography Object Classifier",
30
  allow_flagging="manual",
31
+ examples=[
32
+ "examples/Andromeda.jpg", "examples/Heart.jpg", "examples/Pleiades.jpg",
33
+ "examples/Rosette.jpg", "examples/Moon.jpg", "examples/GreatHercules.jpg",
34
+ "examples/Leo-Triplet.jpg", "examples/Crab.jpg", "examples/North-America.jpg",
35
+ "examples/Horsehead-Flame.jpg", "examples/Pinwheel.jpg", "examples/Saturn.jpg"
36
+ ],
37
  cache_examples=True
38
+ ).launch()