VishalD1234 commited on
Commit
976dad6
1 Parent(s): 4183ac5

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -69
handler.py CHANGED
@@ -1,79 +1,15 @@
1
- import torch
2
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
- from PIL import Image
4
- import json
5
- import base64
6
- from io import BytesIO
7
 
8
- # Load the model and feature extractor when the handler is initialized
9
  class VisionModelHandler:
10
  def __init__(self, model_name_or_path="https://huggingface.co/VishalD1234/Florence-metere1"):
11
  self.model_name_or_path = model_name_or_path
12
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
- # Load model and feature extractor
15
- self.model = AutoModelForImageClassification.from_pretrained(self.model_name_or_path)
16
- self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_name_or_path)
17
 
18
- # Move the model to the appropriate device (GPU/CPU)
19
  self.model.to(self.device)
20
- self.model.eval() # Set model to evaluation mode
21
-
22
- def preprocess_image(self, image_data):
23
- """
24
- Preprocess the image for the model. Convert it from base64 and apply the necessary transformations.
25
- """
26
- image = Image.open(BytesIO(base64.b64decode(image_data)))
27
- inputs = self.feature_extractor(images=image, return_tensors="pt").to(self.device)
28
- return inputs
29
-
30
- def predict(self, inputs):
31
- """
32
- Perform inference and return the model's predictions.
33
- """
34
- with torch.no_grad():
35
- outputs = self.model(**inputs)
36
- logits = outputs.logits
37
- predicted_class_idx = logits.argmax(-1).item() # Get the index of the highest score
38
- return predicted_class_idx
39
-
40
- def handle(self, event, context):
41
- """
42
- Entry point for the inference request. This will be called by the inference endpoint.
43
-
44
- Args:
45
- event: This will contain the input data, usually in the form of a JSON with an image in base64.
46
- context: Optional, can contain metadata about the request (not used here).
47
-
48
- Returns:
49
- A JSON response with the prediction result.
50
- """
51
- # Extract image data from the request body
52
- body = json.loads(event.get("body", "{}"))
53
- image_data = body.get("image_base64", None)
54
-
55
- if image_data is None:
56
- return {
57
- "statusCode": 400,
58
- "body": json.dumps({"error": "No image data found in the request"})
59
- }
60
-
61
- # Preprocess the image and make predictions
62
- inputs = self.preprocess_image(image_data)
63
- prediction = self.predict(inputs)
64
-
65
- # You can add more details to this mapping depending on your use case
66
- response = {
67
- "statusCode": 200,
68
- "body": json.dumps({"prediction": prediction})
69
- }
70
-
71
- return response
72
-
73
-
74
- # Instantiate the handler
75
- vision_model_handler = VisionModelHandler()
76
 
77
- # If running in an environment like AWS Lambda or Sagemaker, ensure this is exposed
78
- def lambda_handler(event, context):
79
- return vision_model_handler.handle(event, context)
 
 
1
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
 
 
 
 
2
 
 
3
  class VisionModelHandler:
4
  def __init__(self, model_name_or_path="https://huggingface.co/VishalD1234/Florence-metere1"):
5
  self.model_name_or_path = model_name_or_path
6
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
8
+ # Add the trust_remote_code parameter
9
+ self.model = AutoModelForImageClassification.from_pretrained(self.model_name_or_path, trust_remote_code=True)
10
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_name_or_path, trust_remote_code=True)
11
 
 
12
  self.model.to(self.device)
13
+ self.model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # Rest of the code stays the same...