amjadfqs commited on
Commit
07952a9
·
verified ·
1 Parent(s): 22a9bc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -15
app.py CHANGED
@@ -1,23 +1,25 @@
1
  import gradio as gr
2
- import requests
 
3
  from PIL import Image
4
 
5
- def load_model(model_name):
6
- # Use the Hugging Face API to load the model
7
- api_url = f"https://huggingface.co/{model_name}"
8
- response = requests.get(api_url)
9
- if response.status_code == 200:
10
- # Assume the model is an image processing model
11
- return lambda image: image # Replace with actual model processing code
12
- else:
13
- raise ValueError("Model not found")
14
-
15
- # Load the model once during initialization
16
- model = load_model("amjadfqs/finalProject")
17
 
18
  def predict(image):
19
- # Use the model to make a prediction
20
- return model(image)
 
 
 
 
 
 
 
 
 
21
 
22
  # Set up the Gradio interface
23
  image_cp = gr.Image(type="pil", label='Brain')
 
1
  import gradio as gr
2
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
+ import torch
4
  from PIL import Image
5
 
6
+ # Load the model and feature extractor once during initialization
7
+ model_name = "amjadfqs/finalProject"
8
+ model = AutoModelForImageClassification.from_pretrained(model_name)
9
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
 
 
 
 
 
 
 
 
10
 
11
  def predict(image):
12
+ # Preprocess the image
13
+ inputs = feature_extractor(images=image, return_tensors="pt")
14
+ # Make prediction
15
+ with torch.no_grad():
16
+ outputs = model(**inputs)
17
+ logits = outputs.logits
18
+ # Get the predicted class
19
+ predicted_class = logits.argmax(-1).item()
20
+ # You may need to adjust the following line based on your class labels
21
+ class_names = ["class1", "class2", "class3", "class4"]
22
+ return predicted_class
23
 
24
  # Set up the Gradio interface
25
  image_cp = gr.Image(type="pil", label='Brain')