jays009 commited on
Commit
1f6dbae
·
verified ·
1 Parent(s): 2251f70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -68
app.py CHANGED
@@ -1,80 +1,32 @@
1
- import gradio as gr
2
- import json
3
- import torch
4
- from torch import nn
5
- from torchvision import models, transforms
6
- from huggingface_hub import hf_hub_download
7
- from PIL import Image
8
- import requests
9
- import base64
10
- from io import BytesIO
11
- import os
12
-
13
- # Define the number of classes
14
- num_classes = 2
15
-
16
- # Download model from Hugging Face
17
- def download_model():
18
- try:
19
- model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
20
- return model_path
21
- except Exception as e:
22
- print(f"Error downloading model: {e}")
23
- return None
24
-
25
- # Load the model from Hugging Face
26
- def load_model(model_path):
27
  try:
28
- model = models.resnet50(pretrained=False)
29
- model.fc = nn.Linear(model.fc.in_features, num_classes)
30
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
31
- model.eval()
32
- return model
33
- except Exception as e:
34
- print(f"Error loading model: {e}")
35
- return None
36
-
37
- # Download the model and load it
38
- model_path = download_model()
39
- model = load_model(model_path) if model_path else None
40
 
41
- # Define the transformation for the input image
42
- transform = transforms.Compose([
43
- transforms.Resize(256),
44
- transforms.CenterCrop(224),
45
- transforms.ToTensor(),
46
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
47
- ])
48
-
49
- def predict(image):
50
- try:
51
- print(f"Received image input: {image}")
52
 
53
  # Check if the input is a PIL Image type
54
- if isinstance(image, Image.Image):
55
- print(f"Image is already loaded as PIL Image: {image}")
56
  else:
57
  # Check if the input contains a base64-encoded string or URL
58
- if isinstance(image, dict) and image.get("data"):
59
- image_data = image["data"]
60
- if image_data.startswith("http"): # URL case
61
- response = requests.get(image_data)
62
  image = Image.open(BytesIO(response.content))
63
  print(f"Fetched image from URL: {image}")
64
- else: # Base64-encoded image case
65
- image_data = base64.b64decode(image_data)
 
 
 
 
66
  image = Image.open(BytesIO(image_data))
67
  print(f"Decoded base64 image: {image}")
68
- elif isinstance(image, str) and image.startswith("http"):
69
- response = requests.get(image)
70
- image = Image.open(BytesIO(response.content))
71
- print(f"Fetched image from URL: {image}")
72
- elif isinstance(image, str) and os.path.isfile(image):
73
- image = Image.open(image)
74
- print(f"Loaded image from local path: {image}")
75
- else:
76
- print("Invalid image format received.")
77
- return json.dumps({"error": "Invalid image format received."})
78
 
79
  # Apply transformations
80
  image = transform(image).unsqueeze(0)
@@ -109,4 +61,3 @@ iface = gr.Interface(
109
 
110
  # Launch the Gradio interface
111
  iface.launch(share=True, show_error=True)
112
-
 
1
+ def predict(data):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  try:
3
+ image_input = data.get('image', None)
4
+ if not image_input:
5
+ return json.dumps({"error": "No image provided."})
 
 
 
 
 
 
 
 
 
6
 
7
+ print(f"Received image input: {image_input}")
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Check if the input is a PIL Image type
10
+ if isinstance(image_input, Image.Image):
11
+ print(f"Image is already loaded as PIL Image: {image_input}")
12
  else:
13
  # Check if the input contains a base64-encoded string or URL
14
+ if image_input.startswith("http"): # URL case
15
+ try:
16
+ response = requests.get(image_input)
 
17
  image = Image.open(BytesIO(response.content))
18
  print(f"Fetched image from URL: {image}")
19
+ except Exception as e:
20
+ print(f"Error fetching image from URL: {e}")
21
+ return json.dumps({"error": f"Error fetching image from URL: {e}"})
22
+ else: # Assuming it is base64-encoded image data
23
+ try:
24
+ image_data = base64.b64decode(image_input)
25
  image = Image.open(BytesIO(image_data))
26
  print(f"Decoded base64 image: {image}")
27
+ except Exception as e:
28
+ print(f"Error decoding base64 image: {e}")
29
+ return json.dumps({"error": f"Error decoding base64 image: {e}"})
 
 
 
 
 
 
 
30
 
31
  # Apply transformations
32
  image = transform(image).unsqueeze(0)
 
61
 
62
  # Launch the Gradio interface
63
  iface.launch(share=True, show_error=True)