i-dont-hug-face committed on
Commit 86209be
1 Parent(s): 8a8276e

Update inference.py

Files changed (1)
  1. inference.py +6 -5
inference.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import torch
 import re
 from PIL import Image
@@ -6,16 +7,16 @@ import base64
 import io
 import json
 
-def model_fn(model_dir):
+def model_fn(model_dir, context=None):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     processor = DonutProcessor.from_pretrained(model_dir)
     model = VisionEncoderDecoderModel.from_pretrained(model_dir)
     model.to(device)
     return model, processor, device
 
-def transform_fn(model, request_body, input_content_type, output_content_type):
+def transform_fn(model, request_body, request_content_type, response_content_type, context=None):
     model, processor, device = model
-    if input_content_type == 'application/json':
+    if request_content_type == 'application/json':
         data = json.loads(request_body)
         image_data = data['inputs']
         image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
@@ -46,6 +47,6 @@ def transform_fn(model, request_body, input_content_type, output_content_type):
 
         # Prepare the response
         prediction = {'result': decoded_text}
-        return json.dumps(prediction), output_content_type
+        return json.dumps(prediction), response_content_type
     else:
-        raise ValueError(f"Unsupported content type: {input_content_type}")
+        raise ValueError(f"Unsupported content type: {request_content_type}")
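The renamed arguments and the optional context parameter line up with the handler interface SageMaker's inference toolkit passes to a custom inference.py: model_fn receives the model directory, and transform_fn receives the loaded model, the raw request body, the request content type, and the response content type. As a minimal client-side sketch, the snippet below builds the JSON payload this handler accepts and reads back the prediction; the boto3 invocation, endpoint name, and image path are illustrative assumptions, not part of the commit.

# Illustrative client sketch. Assumptions (not from this commit): a deployed
# SageMaker endpoint named "donut-endpoint" and a local image "sample.png".
import base64
import json

import boto3


def build_request_body(image_path):
    # transform_fn expects a JSON object with the base64-encoded image under "inputs".
    with open(image_path, "rb") as f:
        return json.dumps({"inputs": base64.b64encode(f.read()).decode("utf-8")})


runtime = boto3.client("sagemaker-runtime")
response = runtime.invoke_endpoint(
    EndpointName="donut-endpoint",          # placeholder endpoint name
    ContentType="application/json",         # matches the request_content_type check above
    Body=build_request_body("sample.png"),  # placeholder image path
)
prediction = json.loads(response["Body"].read())  # handler returns {'result': decoded_text}
print(prediction["result"])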