i-dont-hug-face
commited on
Commit
•
47d7b51
1
Parent(s):
1550120
Update inference.py
Browse files- inference.py +6 -7
inference.py
CHANGED
@@ -4,9 +4,6 @@ from PIL import Image
|
|
4 |
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
5 |
import io
|
6 |
import json
|
7 |
-
import logging
|
8 |
-
|
9 |
-
logging.basicConfig(level=logging.INFO)
|
10 |
|
11 |
def model_fn(model_dir, context=None):
|
12 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
@@ -18,16 +15,16 @@ def model_fn(model_dir, context=None):
|
|
18 |
def input_fn(input_data, content_type, context=None):
|
19 |
"""Deserialize the input data."""
|
20 |
logging.info("Entering input_fn")
|
21 |
-
if content_type == 'application/
|
22 |
image = Image.open(io.BytesIO(input_data))
|
23 |
return image
|
24 |
else:
|
25 |
raise ValueError(f"Unsupported content type: {content_type}")
|
26 |
|
27 |
-
def predict_fn(data,
|
28 |
"""Apply the model to the input data."""
|
29 |
logging.info("Entering predict_fn")
|
30 |
-
model, processor, device =
|
31 |
|
32 |
# Preprocess the image
|
33 |
pixel_values = processor(data, return_tensors="pt").pixel_values.to(device)
|
@@ -60,6 +57,8 @@ def output_fn(prediction, accept):
|
|
60 |
"""Serialize the prediction output."""
|
61 |
logging.info("Entering output_fn")
|
62 |
if accept == 'application/json':
|
63 |
-
return json.dumps(prediction)
|
64 |
else:
|
65 |
raise ValueError(f"Unsupported response content type: {accept}")
|
|
|
|
|
|
4 |
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
5 |
import io
|
6 |
import json
|
|
|
|
|
|
|
7 |
|
8 |
def model_fn(model_dir, context=None):
|
9 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
15 |
def input_fn(input_data, content_type, context=None):
|
16 |
"""Deserialize the input data."""
|
17 |
logging.info("Entering input_fn")
|
18 |
+
if content_type == 'application/x-image' or content_type == 'application/octet-stream':
|
19 |
image = Image.open(io.BytesIO(input_data))
|
20 |
return image
|
21 |
else:
|
22 |
raise ValueError(f"Unsupported content type: {content_type}")
|
23 |
|
24 |
+
def predict_fn(data, model_data, context=None):
|
25 |
"""Apply the model to the input data."""
|
26 |
logging.info("Entering predict_fn")
|
27 |
+
model, processor, device = model_data
|
28 |
|
29 |
# Preprocess the image
|
30 |
pixel_values = processor(data, return_tensors="pt").pixel_values.to(device)
|
|
|
57 |
"""Serialize the prediction output."""
|
58 |
logging.info("Entering output_fn")
|
59 |
if accept == 'application/json':
|
60 |
+
return json.dumps(prediction), 'application/json'
|
61 |
else:
|
62 |
raise ValueError(f"Unsupported response content type: {accept}")
|
63 |
+
|
64 |
+
|