File size: 3,366 Bytes
03fc4f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from time import time
from typing import List, Optional, Tuple, Union
import json
import os
import requests
from urllib.parse import urljoin

from django.core.files.uploadedfile import InMemoryUploadedFile, TemporaryUploadedFile

import logging


# Base URL of the model service; taken from the MODEL_ENDPOINT_URL environment
# variable when it is set, otherwise the hard-coded default below is used.
MODEL_ENDPOINT_URL = os.environ.get("MODEL_ENDPOINT_URL", "https://0.0.0.0:2000")


def try_make_request(request_kwargs, request_type: str):
    """Send a GET or POST request and return the JSON-decoded response body.

    This function never raises: every failure is logged and converted into a
    placeholder result so callers can always render something.

    Args:
        request_kwargs: Keyword arguments forwarded to ``requests.get`` or
            ``requests.post`` (e.g. ``url``, ``params``, ``files``).  The
            mapping is copied before the timeout is added, so the caller's
            dict is left untouched.
        request_type: ``"get"`` or ``"post"`` (case-insensitive).

    Returns:
        The JSON-decoded body of the response on success.  On any failure
        (unsupported ``request_type``, connection error, timeout, HTTP error
        status, body that is not valid JSON) a five-element placeholder list
        of strings is returned instead.
    """
    fallback = ["Image Failed to Predict", "Try Another Image", "", "", ""]
    try:
        method = request_type.lower()
        if method not in ("get", "post"):
            raise Exception("Request Type not Supported.  Only get, post supported.")

        # Work on a copy so the caller's dict is not modified; the timeout is
        # always forced to 3 seconds, overriding any value that was passed in.
        kwargs = {**request_kwargs, "timeout": 3}
        send = requests.get if method == "get" else requests.post
        response = send(**kwargs)

        # Treat 4xx/5xx answers as failures rather than decoding an error
        # payload and handing it back as if it were a prediction.
        response.raise_for_status()
        return json.loads(response.content)
    except Exception:
        # requests' ConnectionError is an Exception subclass, so one handler
        # covers connection problems as well as everything else.
        logging.warning("Failed Model prediction", exc_info=True)
        return fallback


def predict_url(url: str) -> List[str]:
    """Request predictions for the image located at ``url``.

    Issues a GET to the ``predict_url`` endpoint (resolved against
    ``MODEL_ENDPOINT_URL``) with the image URL as the ``url`` query parameter,
    and returns whatever ``try_make_request`` yields for that call.
    """
    return try_make_request(
        {
            "url": urljoin(MODEL_ENDPOINT_URL, "predict_url"),
            "params": {"url": url},
            "headers": {
                "content-type": "application/json",
                "Accept-Charset": "UTF-8",
            },
        },
        "get",
    )


def predict_file(image_file) -> List[str]:
    """Request predictions for an uploaded image file.

    The file is rewound to its start, read fully, and sent as the
    ``upload_file`` multipart field in a POST to the ``predict_file`` endpoint
    (resolved against ``MODEL_ENDPOINT_URL``).  Returns whatever
    ``try_make_request`` yields for that call.

    Args:
        image_file: Object exposing ``seek``, ``read``, ``name`` and
            ``content_type``.
    """
    # Rewind first so the whole file is sent even if it was read earlier.
    image_file.seek(0)
    filename = image_file.name
    raw_bytes = image_file.read()
    mime_type = image_file.content_type

    endpoint = urljoin(MODEL_ENDPOINT_URL, "predict_file")
    multipart = {"upload_file": (filename, raw_bytes, mime_type)}
    return try_make_request({"url": endpoint, "files": multipart}, "post")


def get_color_labels(guesses: List[str], actual_label: Optional[str]) -> List[str]:
    """Return one colour name per guess, marking guesses equal to the label.

    A guess that equals ``actual_label`` gets ``"lime"``; every other guess
    gets ``"white"``.  When ``actual_label`` is empty or ``None`` nothing can
    match, so every entry is ``"white"``.
    """
    has_label = bool(actual_label)
    colors = []
    for guess in guesses:
        is_match = has_label and guess == actual_label
        colors.append("lime" if is_match else "white")
    return colors


def url_image_vars(
    input_img: Union[str, InMemoryUploadedFile, TemporaryUploadedFile],
    label: Optional[str],
) -> Tuple[List[str], List[str]]:
    """Get model guesses for an image plus a display colour for each guess.

    Args:
        input_img: Either an image URL (``str``) or an uploaded file object.
        label: The expected label for the image.  It is title-cased before
            being compared with the guesses.  ``None`` or an empty string
            means no label is known, in which case no guess is highlighted.

    Returns:
        A ``(top_guesses, color_labels)`` tuple.  ``top_guesses`` is the
        prediction result, or a placeholder message list when the model fails
        its healthcheck or ``input_img`` has an unsupported type.
        ``color_labels`` is the output of ``get_color_labels`` for those
        guesses.
    """
    # Guard against a missing label: calling .title() on None would raise,
    # while get_color_labels already handles "no label" gracefully.
    actual_label = label.title() if label else None

    if not is_healthy():
        logging.error("Model failed healthcheck")
        top_guesses = ["Model Offline", "Try Again Later", "", "", ""]
    elif isinstance(input_img, str):
        top_guesses = predict_url(input_img)
    elif isinstance(input_img, (InMemoryUploadedFile, TemporaryUploadedFile)):
        top_guesses = predict_file(input_img)
    else:
        logging.error(f"Unknown input type: {type(input_img)=}")
        top_guesses = ["Unknown Input Type", "", "", "", ""]

    color_labels = get_color_labels(top_guesses, actual_label)
    return top_guesses, color_labels


def is_healthy() -> bool:
    """Report whether the model service answers its healthcheck correctly.

    Sends a GET (1 second timeout) to the ``healthcheck`` endpoint resolved
    against ``MODEL_ENDPOINT_URL``.  Returns ``True`` only when the request
    succeeds, the status code is 200, and the JSON body equals
    ``{"status": "alive"}``; every other outcome returns ``False``.
    """
    healthcheck_url = urljoin(MODEL_ENDPOINT_URL, "healthcheck")

    try:
        reply = requests.get(url=healthcheck_url, timeout=1)
    except Exception:
        logging.error("Failed to make healthcheck request")
        return False

    if reply.status_code != 200:
        return False

    try:
        payload = json.loads(reply.content)
    except Exception:
        logging.error("Failed to load healthcheck content")
        return False

    return payload == {"status": "alive"}