i-dont-hug-face committed
Commit
e49dc7d
1 Parent(s): 9fcd911

Create inference.py

Files changed (1):
  1. inference.py +50 -0
inference.py ADDED
import os
import io
import re
import json
import base64

import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel


def model_fn(model_dir):
    # Load the Donut processor and model once, at container startup.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    processor = DonutProcessor.from_pretrained(model_dir)
    model = VisionEncoderDecoderModel.from_pretrained(model_dir)
    model.to(device)
    return model, processor, device


def input_fn(request_body, content_type):
    # Expect a JSON payload with a base64-encoded image under "inputs".
    if content_type == 'application/json':
        data = json.loads(request_body)
        image_data = data['inputs']
        image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
        return image
    else:
        raise ValueError(f"Unsupported content type: {content_type}")


def predict_fn(input_data, model_and_processor):
    model, processor, device = model_and_processor
    pixel_values = processor(input_data, return_tensors="pt").pixel_values.to(device)

    model.eval()
    with torch.no_grad():
        # Prime the decoder with the receipt-parsing task prompt.
        task_prompt = "<s_receipt>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)
        generated_outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.config.decoder.max_position_embeddings,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            early_stopping=True,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

    # Strip special tokens and the leading task-prompt tag from the output.
    decoded_text = processor.batch_decode(generated_outputs.sequences)[0]
    decoded_text = decoded_text.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    decoded_text = re.sub(r"<.*?>", "", decoded_text, count=1).strip()
    return decoded_text


def output_fn(prediction, accept):
    # Serialize the decoded text as JSON for the response body.
    return json.dumps({'result': prediction}), accept
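
These four handlers follow the SageMaker inference-toolkit contract (model_fn / input_fn / predict_fn / output_fn), so they can be smoke-tested locally by chaining them by hand before deploying. A minimal sketch, assuming a hypothetical local model directory ./model and sample image receipt.png (both names are placeholders):

import base64
import json

# Hypothetical paths; substitute your own model directory and test image.
MODEL_DIR = "./model"
IMAGE_PATH = "receipt.png"

# Chain the handlers exactly as the serving container would.
model_and_processor = model_fn(MODEL_DIR)

with open(IMAGE_PATH, "rb") as f:
    body = json.dumps({"inputs": base64.b64encode(f.read()).decode("utf-8")})

image = input_fn(body, "application/json")
prediction = predict_fn(image, model_and_processor)
response, content_type = output_fn(prediction, "application/json")
print(response)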
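
Once the script is deployed behind an endpoint, a client builds the same JSON payload and invokes it through the SageMaker runtime API. A sketch using boto3, with "donut-receipts" as a placeholder endpoint name and receipt.png again standing in for a real image:

import base64
import json

import boto3

runtime = boto3.client("sagemaker-runtime")

with open("receipt.png", "rb") as f:
    payload = json.dumps({"inputs": base64.b64encode(f.read()).decode("utf-8")})

# "donut-receipts" is a placeholder; use your deployed endpoint's name.
response = runtime.invoke_endpoint(
    EndpointName="donut-receipts",
    ContentType="application/json",
    Body=payload,
)
print(json.loads(response["Body"].read())["result"])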