i-dont-hug-face commited on
Commit
47d7b51
1 Parent(s): 1550120

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +6 -7
inference.py CHANGED
@@ -4,9 +4,6 @@ from PIL import Image
4
  from transformers import DonutProcessor, VisionEncoderDecoderModel
5
  import io
6
  import json
7
- import logging
8
-
9
- logging.basicConfig(level=logging.INFO)
10
 
11
  def model_fn(model_dir, context=None):
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -18,16 +15,16 @@ def model_fn(model_dir, context=None):
18
  def input_fn(input_data, content_type, context=None):
19
  """Deserialize the input data."""
20
  logging.info("Entering input_fn")
21
- if content_type == 'application/json':
22
  image = Image.open(io.BytesIO(input_data))
23
  return image
24
  else:
25
  raise ValueError(f"Unsupported content type: {content_type}")
26
 
27
- def predict_fn(data, model, context=None):
28
  """Apply the model to the input data."""
29
  logging.info("Entering predict_fn")
30
- model, processor, device = model
31
 
32
  # Preprocess the image
33
  pixel_values = processor(data, return_tensors="pt").pixel_values.to(device)
@@ -60,6 +57,8 @@ def output_fn(prediction, accept):
60
  """Serialize the prediction output."""
61
  logging.info("Entering output_fn")
62
  if accept == 'application/json':
63
- return json.dumps(prediction)
64
  else:
65
  raise ValueError(f"Unsupported response content type: {accept}")
 
 
 
4
  from transformers import DonutProcessor, VisionEncoderDecoderModel
5
  import io
6
  import json
 
 
 
7
 
8
  def model_fn(model_dir, context=None):
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
15
  def input_fn(input_data, content_type, context=None):
16
  """Deserialize the input data."""
17
  logging.info("Entering input_fn")
18
+ if content_type == 'application/x-image' or content_type == 'application/octet-stream':
19
  image = Image.open(io.BytesIO(input_data))
20
  return image
21
  else:
22
  raise ValueError(f"Unsupported content type: {content_type}")
23
 
24
+ def predict_fn(data, model_data, context=None):
25
  """Apply the model to the input data."""
26
  logging.info("Entering predict_fn")
27
+ model, processor, device = model_data
28
 
29
  # Preprocess the image
30
  pixel_values = processor(data, return_tensors="pt").pixel_values.to(device)
 
57
  """Serialize the prediction output."""
58
  logging.info("Entering output_fn")
59
  if accept == 'application/json':
60
+ return json.dumps(prediction), 'application/json'
61
  else:
62
  raise ValueError(f"Unsupported response content type: {accept}")
63
+
64
+