i-dont-hug-face committed on
Commit 8a8276e
1 Parent(s): e49dc7d

Update inference.py

Files changed (1)
1. inference.py +33 -32
inference.py CHANGED
@@ -1,10 +1,10 @@
-import os
 import torch
 import re
 from PIL import Image
 from transformers import DonutProcessor, VisionEncoderDecoderModel
 import base64
 import io
+import json
 
 def model_fn(model_dir):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -13,38 +13,39 @@ def model_fn(model_dir):
     model.to(device)
     return model, processor, device
 
-def input_fn(request_body, content_type):
-    if content_type == 'application/json':
+def transform_fn(model, request_body, input_content_type, output_content_type):
+    model, processor, device = model
+    if input_content_type == 'application/json':
         data = json.loads(request_body)
         image_data = data['inputs']
         image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
-        return image
+
+        # Preprocess the image
+        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
+
+        # Run inference
+        model.eval()
+        with torch.no_grad():
+            task_prompt = "<s_receipt>"
+            decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
+            generated_outputs = model.generate(
+                pixel_values,
+                decoder_input_ids=decoder_input_ids,
+                max_length=model.config.decoder.max_position_embeddings,
+                pad_token_id=processor.tokenizer.pad_token_id,
+                eos_token_id=processor.tokenizer.eos_token_id,
+                early_stopping=True,
+                bad_words_ids=[[processor.tokenizer.unk_token_id]],
+                return_dict_in_generate=True
+            )
+
+        # Decode the output
+        decoded_text = processor.batch_decode(generated_outputs.sequences)[0]
+        decoded_text = decoded_text.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+        decoded_text = re.sub(r"<.*?>", "", decoded_text, count=1).strip()
+
+        # Prepare the response
+        prediction = {'result': decoded_text}
+        return json.dumps(prediction), output_content_type
     else:
-        raise ValueError(f"Unsupported content type: {content_type}")
-
-def predict_fn(input_data, model_and_processor):
-    model, processor, device = model_and_processor
-    pixel_values = processor(input_data, return_tensors="pt").pixel_values.to(device)
-
-    model.eval()
-    with torch.no_grad():
-        task_prompt = "<s_receipt>"
-        decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
-        generated_outputs = model.generate(
-            pixel_values,
-            decoder_input_ids=decoder_input_ids,
-            max_length=model.config.decoder.max_position_embeddings,
-            pad_token_id=processor.tokenizer.pad_token_id,
-            eos_token_id=processor.tokenizer.eos_token_id,
-            early_stopping=True,
-            bad_words_ids=[[processor.tokenizer.unk_token_id]],
-            return_dict_in_generate=True
-        )
-
-    decoded_text = processor.batch_decode(generated_outputs.sequences)[0]
-    decoded_text = decoded_text.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
-    decoded_text = re.sub(r"<.*?>", "", decoded_text, count=1).strip()
-    return decoded_text
-
-def output_fn(prediction, accept):
-    return json.dumps({'result': prediction}), accept
+        raise ValueError(f"Unsupported content type: {input_content_type}")
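
The updated handler expects a JSON request body of the form {"inputs": "<base64-encoded image bytes>"} and returns a JSON body of the form {"result": "<decoded text>"}. A minimal local smoke test of the new transform_fn might look like the sketch below; the ./model directory and receipt.png paths are hypothetical placeholders, not part of this repository.

import base64
import json

from inference import model_fn, transform_fn

MODEL_DIR = "./model"       # hypothetical: directory holding the Donut model artifacts
IMAGE_PATH = "receipt.png"  # hypothetical: a sample receipt image

# Load (model, processor, device) the same way the serving container would.
artifacts = model_fn(MODEL_DIR)

# Build the JSON payload the handler expects: a base64-encoded image under "inputs".
with open(IMAGE_PATH, "rb") as f:
    payload = json.dumps({"inputs": base64.b64encode(f.read()).decode("utf-8")})

# Call the handler directly; it returns (response_body, content_type).
body, content_type = transform_fn(artifacts, payload, "application/json", "application/json")
print(json.loads(body)["result"])

Against a deployed endpoint, the same payload would be sent with the content type application/json (for example via the SageMaker runtime's invoke_endpoint call) and the JSON response parsed in the same way.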