import json
import logging

from qa_generator_pipeline import QAGeneratorPipeline

logger = logging.getLogger(__name__)

JSON_CONTENT_TYPE = 'application/json'


def model_fn(model_dir):
    """Load the QA generator pipeline from the model directory (SageMaker model_fn)."""
    logger.info('[### model_fn ###] Loading model from {}'.format(model_dir))
    model = QAGeneratorPipeline(model_dir=model_dir, use_cuda=True)
    return model


def predict_fn(input_data, model):
    """Run inference on the deserialized input (SageMaker predict_fn)."""
    logger.info('[### predict_fn ###] Entering predict_fn() method')
    logger.info('input text: {}'.format(input_data))
    prediction = model(input_data)
    logger.info('prediction: {}'.format(prediction))
    return prediction


def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
    """Deserialize the request body (SageMaker input_fn); only JSON is supported."""
    logger.info('[### input_fn ###] Entering input_fn() method')
    logger.info('[### input_fn ###] request_content_type: {}'.format(content_type))
    logger.info('[### input_fn ###] request_body type: {}'.format(type(serialized_input_data)))

    if content_type == JSON_CONTENT_TYPE:
        return json.loads(serialized_input_data)
    raise Exception('Unsupported Content Type: {}'.format(content_type))


def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
    """Serialize the prediction to the requested accept type (SageMaker output_fn)."""
    logger.info('[### output_fn ###] Entering output_fn() method')
    logger.info('[### output_fn ###] prediction: {}'.format(prediction_output))
    if accept == JSON_CONTENT_TYPE:
        return json.dumps(prediction_output), accept
    raise Exception('Unsupported Content Type: {}'.format(accept))
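

# A minimal local smoke-test sketch, not part of the SageMaker handler contract:
# it chains input_fn -> predict_fn -> output_fn in the same order the inference
# container does. The payload shape ({'context': ...}) and the model directory
# '/opt/ml/model' are assumptions; adjust them to whatever QAGeneratorPipeline
# actually expects, and note that use_cuda=True in model_fn requires a GPU.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    model = model_fn('/opt/ml/model')  # hypothetical local model directory
    request_body = json.dumps({'context': 'SageMaker hosts models behind an HTTPS endpoint.'})
    input_data = input_fn(request_body, JSON_CONTENT_TYPE)
    prediction = predict_fn(input_data, model)
    body, content_type = output_fn(prediction, JSON_CONTENT_TYPE)
    logger.info('response ({}): {}'.format(content_type, body))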