i-dont-hug-face
commited on
Commit
•
86209be
1
Parent(s):
8a8276e
Update inference.py
Browse files- inference.py +6 -5
inference.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import torch
|
2 |
import re
|
3 |
from PIL import Image
|
@@ -6,16 +7,16 @@ import base64
|
|
6 |
import io
|
7 |
import json
|
8 |
|
9 |
-
def model_fn(model_dir):
|
10 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
11 |
processor = DonutProcessor.from_pretrained(model_dir)
|
12 |
model = VisionEncoderDecoderModel.from_pretrained(model_dir)
|
13 |
model.to(device)
|
14 |
return model, processor, device
|
15 |
|
16 |
-
def transform_fn(model, request_body,
|
17 |
model, processor, device = model
|
18 |
-
if
|
19 |
data = json.loads(request_body)
|
20 |
image_data = data['inputs']
|
21 |
image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
|
@@ -46,6 +47,6 @@ def transform_fn(model, request_body, input_content_type, output_content_type):
|
|
46 |
|
47 |
# Prepare the response
|
48 |
prediction = {'result': decoded_text}
|
49 |
-
return json.dumps(prediction),
|
50 |
else:
|
51 |
-
raise ValueError(f"Unsupported content type: {
|
|
|
1 |
+
import os
|
2 |
import torch
|
3 |
import re
|
4 |
from PIL import Image
|
|
|
7 |
import io
|
8 |
import json
|
9 |
|
10 |
+
def model_fn(model_dir, context=None):
|
11 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
processor = DonutProcessor.from_pretrained(model_dir)
|
13 |
model = VisionEncoderDecoderModel.from_pretrained(model_dir)
|
14 |
model.to(device)
|
15 |
return model, processor, device
|
16 |
|
17 |
+
def transform_fn(model, request_body, request_content_type, response_content_type, context=None):
|
18 |
model, processor, device = model
|
19 |
+
if request_content_type == 'application/json':
|
20 |
data = json.loads(request_body)
|
21 |
image_data = data['inputs']
|
22 |
image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
|
|
|
47 |
|
48 |
# Prepare the response
|
49 |
prediction = {'result': decoded_text}
|
50 |
+
return json.dumps(prediction), response_content_type
|
51 |
else:
|
52 |
+
raise ValueError(f"Unsupported content type: {request_content_type}")
|