Jangai commited on
Commit
7d27275
·
verified ·
1 Parent(s): 77e4539

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -7
app.py CHANGED
@@ -3,21 +3,39 @@ import torch
3
  from transformers import ViTImageProcessor, ViTForImageClassification
4
  from PIL import Image
5
  import numpy as np
 
 
 
 
6
 
7
  # Load the pre-trained model and image processor
8
  model_name = "google/vit-base-patch16-224"
 
9
  image_processor = ViTImageProcessor.from_pretrained(model_name)
10
  model = ViTForImageClassification.from_pretrained(model_name)
11
 
12
  # Define the prediction function
13
  def predict(image):
14
- # Convert the dictionary to a PIL image
15
- image = Image.fromarray(image.astype('uint8'), 'RGB')
16
- inputs = image_processor(images=image, return_tensors="pt")
17
- outputs = model(**inputs)
18
- logits = outputs.logits
19
- predicted_class_idx = logits.argmax(-1).item()
20
- return model.config.id2label[predicted_class_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Create the Gradio interface
23
  iface = gr.Interface(
 
3
  from transformers import ViTImageProcessor, ViTForImageClassification
4
  from PIL import Image
5
  import numpy as np
6
+ import logging
7
+
8
+ # Configure logging
9
+ logging.basicConfig(level=logging.DEBUG)
10
 
11
  # Load the pre-trained model and image processor
12
  model_name = "google/vit-base-patch16-224"
13
+ logging.info("Loading image processor and model...")
14
  image_processor = ViTImageProcessor.from_pretrained(model_name)
15
  model = ViTForImageClassification.from_pretrained(model_name)
16
 
17
  # Define the prediction function
18
  def predict(image):
19
+ try:
20
+ logging.info("Received image of type: %s", type(image))
21
+ # Convert the dictionary to a PIL image
22
+ if isinstance(image, dict):
23
+ logging.debug("Converting dictionary to NumPy array...")
24
+ image = np.array(image['image_data'])
25
+ logging.debug("Converting NumPy array to PIL image...")
26
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
27
+ logging.debug("Image converted successfully.")
28
+
29
+ logging.info("Processing image...")
30
+ inputs = image_processor(images=image, return_tensors="pt")
31
+ outputs = model(**inputs)
32
+ logits = outputs.logits
33
+ predicted_class_idx = logits.argmax(-1).item()
34
+ logging.info("Prediction successful.")
35
+ return model.config.id2label[predicted_class_idx]
36
+ except Exception as e:
37
+ logging.error("Error during prediction: %s", e)
38
+ return str(e)
39
 
40
  # Create the Gradio interface
41
  iface = gr.Interface(