i-dont-hug-face committed
Commit 1550120
1 Parent(s): 27b2057

Update inference.py

Files changed (1)
  1. inference.py +49 -32
inference.py CHANGED
@@ -4,6 +4,9 @@ from PIL import Image
 from transformers import DonutProcessor, VisionEncoderDecoderModel
 import io
 import json
+import logging
+
+logging.basicConfig(level=logging.INFO)
 
 def model_fn(model_dir, context=None):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -12,37 +15,51 @@ def model_fn(model_dir, context=None):
     model.to(device)
     return model, processor, device
 
-def transform_fn(model, request_body, request_content_type, response_content_type, context=None):
+def input_fn(input_data, content_type, context=None):
+    """Deserialize the input data."""
+    logging.info("Entering input_fn")
+    if content_type == 'application/json':
+        image = Image.open(io.BytesIO(input_data))
+        return image
+    else:
+        raise ValueError(f"Unsupported content type: {content_type}")
+
+def predict_fn(data, model, context=None):
+    """Apply the model to the input data."""
+    logging.info("Entering predict_fn")
     model, processor, device = model
-    if request_content_type == 'application/json':
-        image = Image.open(io.BytesIO(request_body))
-
-        # Preprocess the image
-        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
-
-        # Run inference
-        model.eval()
-        with torch.no_grad():
-            task_prompt = "<s_receipt>"
-            decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
-            generated_outputs = model.generate(
-                pixel_values,
-                decoder_input_ids=decoder_input_ids,
-                max_length=model.config.decoder.max_position_embeddings,
-                pad_token_id=processor.tokenizer.pad_token_id,
-                eos_token_id=processor.tokenizer.eos_token_id,
-                early_stopping=True,
-                bad_words_ids=[[processor.tokenizer.unk_token_id]],
-                return_dict_in_generate=True
-            )
-
-        # Decode the output
-        decoded_text = processor.batch_decode(generated_outputs.sequences)[0]
-        decoded_text = decoded_text.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
-        decoded_text = re.sub(r"<.*?>", "", decoded_text, count=1).strip()
-
-        # Prepare the response
-        prediction = {'result': decoded_text}
-        return json.dumps(prediction), response_content_type
+
+    # Preprocess the image
+    pixel_values = processor(data, return_tensors="pt").pixel_values.to(device)
+
+    # Run inference
+    model.eval()
+    with torch.no_grad():
+        task_prompt = "<s_receipt>"
+        decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
+        generated_outputs = model.generate(
+            pixel_values,
+            decoder_input_ids=decoder_input_ids,
+            max_length=model.config.decoder.max_position_embeddings,
+            pad_token_id=processor.tokenizer.pad_token_id,
+            eos_token_id=processor.tokenizer.eos_token_id,
+            early_stopping=True,
+            bad_words_ids=[[processor.tokenizer.unk_token_id]],
+            return_dict_in_generate=True
+        )
+
+    # Decode the output
+    decoded_text = processor.batch_decode(generated_outputs.sequences)[0]
+    decoded_text = decoded_text.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+    decoded_text = re.sub(r"<.*?>", "", decoded_text, count=1).strip()
+
+    prediction = {'result': decoded_text}
+    return prediction
+
+def output_fn(prediction, accept):
+    """Serialize the prediction output."""
+    logging.info("Entering output_fn")
+    if accept == 'application/json':
+        return json.dumps(prediction)
     else:
-        raise ValueError(f"Unsupported content type: {request_content_type}")
+        raise ValueError(f"Unsupported response content type: {accept}")