VishalD1234 committed
Commit 4183ac5
1 Parent(s): 28421c3

Update handler.py

Files changed (1)
  1. handler.py +70 -39
handler.py CHANGED
@@ -1,48 +1,79 @@
-from typing import Dict, List, Any
-from PIL import Image
 import torch
-import os
+from transformers import AutoFeatureExtractor, AutoModelForImageClassification
+from PIL import Image
+import json
+import base64
 from io import BytesIO
-from transformers import BlipForConditionalGeneration, BlipProcessor
-# -
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-class EndpointHandler():
-    def __init__(self, path=""):
-        # load the optimized model
-        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-        self.model = BlipForConditionalGeneration.from_pretrained(
-            "Salesforce/blip-image-captioning-base"
-        ).to(device)
-        self.model.eval()
-        self.model = self.model.to(device)
-
-    def __call__(self, data: Any) -> Dict[str, Any]:
+
+# Load the model and feature extractor when the handler is initialized
+class VisionModelHandler:
+    def __init__(self, model_name_or_path="VishalD1234/Florence-metere1"):  # Hub repo id; from_pretrained expects an id or local path, not a full URL
+        self.model_name_or_path = model_name_or_path
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Load model and feature extractor
+        self.model = AutoModelForImageClassification.from_pretrained(self.model_name_or_path)
+        self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_name_or_path)
+
+        # Move the model to the appropriate device (GPU/CPU)
+        self.model.to(self.device)
+        self.model.eval()  # Set model to evaluation mode
+
+    def preprocess_image(self, image_data):
+        """
+        Preprocess the image for the model: decode it from base64 and apply the necessary transformations.
+        """
+        image = Image.open(BytesIO(base64.b64decode(image_data)))
+        inputs = self.feature_extractor(images=image, return_tensors="pt").to(self.device)
+        return inputs
+
+    def predict(self, inputs):
+        """
+        Perform inference and return the model's prediction.
+        """
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+            logits = outputs.logits
+            predicted_class_idx = logits.argmax(-1).item()  # Index of the highest-scoring class
+        return predicted_class_idx
+
+    def handle(self, event, context):
         """
+        Entry point for the inference request. This will be called by the inference endpoint.
+
         Args:
-            data (:obj:):
-                includes the input data and the parameters for the inference.
-        Return:
-            A :obj:`dict`. The returned object should be a dict with one list, like {"captions": ["A hugging face at the office"]}, containing:
-                - "caption": A string corresponding to the generated caption.
+            event: Contains the input data, usually JSON with a base64-encoded image.
+            context: Optional; may contain metadata about the request (not used here).
+
+        Returns:
+            A JSON response with the prediction result.
         """
-        inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", {})
-
-        raw_images = [Image.open(BytesIO(_img)) for _img in inputs]
-
-        processed_image = self.processor(images=raw_images, return_tensors="pt")
-        processed_image["pixel_values"] = processed_image["pixel_values"].to(device)
-        processed_image = {**processed_image, **parameters}
-
-        with torch.no_grad():
-            out = self.model.generate(
-                **processed_image
-            )
-        captions = self.processor.batch_decode(out, skip_special_tokens=True)
-        # postprocess the prediction
-        return {"captions": captions}
+        # Extract image data from the request body
+        body = json.loads(event.get("body", "{}"))
+        image_data = body.get("image_base64", None)
+
+        if image_data is None:
+            return {
+                "statusCode": 400,
+                "body": json.dumps({"error": "No image data found in the request"})
+            }
+
+        # Preprocess the image and make the prediction
+        inputs = self.preprocess_image(image_data)
+        prediction = self.predict(inputs)
+
+        # You can add more details to this response depending on your use case
+        response = {
+            "statusCode": 200,
+            "body": json.dumps({"prediction": prediction})
+        }
+
+        return response
+
+
+# Instantiate the handler
+vision_model_handler = VisionModelHandler()
+
+# If running in an environment like AWS Lambda or SageMaker, ensure this is exposed
+def lambda_handler(event, context):
+    return vision_model_handler.handle(event, context)
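
For a quick sanity check of the new handler, the sketch below builds a Lambda-style event around a base64-encoded test image and calls lambda_handler locally. This is a minimal sketch, not part of the commit: it assumes handler.py is importable from the working directory and that the VishalD1234/Florence-metere1 checkpoint actually loads through AutoModelForImageClassification (the diff does not show the model's config).

# Minimal local smoke test (not part of the commit).
import base64
import json
from io import BytesIO

from PIL import Image

from handler import lambda_handler  # the module defined in this commit

# Build a tiny in-memory RGB image and base64-encode it, as a client would.
buffer = BytesIO()
Image.new("RGB", (224, 224), color=(128, 64, 32)).save(buffer, format="PNG")
image_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

# Lambda-style event: the JSON payload travels as a string under "body".
event = {"body": json.dumps({"image_base64": image_b64})}
response = lambda_handler(event, context=None)

print(response["statusCode"])        # 200 on success, 400 if the image is missing
print(json.loads(response["body"]))  # e.g. {"prediction": <class index>}

One design note: the handle(event, context) signature follows the AWS Lambda calling convention. Hugging Face Inference Endpoints instead look for an EndpointHandler class with a __call__(self, data) method, which is the shape the removed BLIP handler had, so this commit effectively retargets the file at Lambda/SageMaker-style hosting.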