from base64 import b64encode

from django.shortcuts import render
from django.views.generic import TemplateView

from model_client import url_image_vars
from sample_loader import get_random_sample_image

from .forms import ImageUploadForm, URLInputForm


def get_default_form_context():
    """Return a fresh template context holding unbound input forms.

    A new dict with new form instances is built on every call, so callers
    may add keys to or replace entries in the result freely.
    """
    return {
        "url_form": URLInputForm(),
        "image_form": ImageUploadForm(),
    }


class HomeView(TemplateView):
    """Page showing the image input forms and the resulting predictions.

    GET renders ``template_name`` with empty forms. POST dispatches on the
    submitted ``action`` field:

    * ``"SubmitURL"``   -- use the image at the URL entered in ``URLInputForm``;
    * ``"SubmitImage"`` -- use the file uploaded through ``ImageUploadForm``;
    * ``"LoadRandom"``  -- use an image from ``get_random_sample_image()``.

    Successful submissions render ``result_template_name``. An invalid form,
    or a missing/unknown action, re-renders ``template_name`` (with the bound
    form, so its validation errors can be displayed).
    """

    template_name = "index.html"
    # Template used to display the predictions for a submitted image.
    result_template_name = "request.html"

    def _render_result(self, request, img_src, image_input, label):
        """Call ``url_image_vars`` on *image_input* and render the result page.

        Args:
            request: the current ``HttpRequest``.
            img_src: value exposed to the template as ``img_src`` (the image
                URL, or a ``data:`` URI for uploads).
            image_input: passed unchanged to ``url_image_vars``.
            label: user-supplied label; shown as "Unknown" when falsy.
        """
        # url_image_vars returns a 2-tuple. NOTE(review): the exact structure
        # of each element is defined in model_client — confirm there.
        top_guesses, color_labels = url_image_vars(image_input, label)
        context_dict = get_default_form_context()
        context_dict.update(
            {
                "img_src": img_src,
                "top_guesses": top_guesses,
                "actual_label": label or "Unknown",
                "color_labels": color_labels,
            }
        )
        return render(request, self.result_template_name, context_dict)

    def render_for_img_url_label(self, request, img_url: str, label: str):
        """Render the result page for an image referenced by URL."""
        return self._render_result(request, img_url, img_url, label)

    def render_for_img_file_label(self, request, img_file, label: str):
        """Render the result page for an uploaded image file.

        The file's bytes are embedded in the page as a base64 ``data:`` URI
        built from ``img_file.content_type``.
        """
        # Start from the beginning in case the stream was already read
        # (e.g. during form validation), so the data URI holds the full file.
        img_file.seek(0)
        encoded_img = b64encode(img_file.read()).decode("utf-8")
        img_src = f"data:{img_file.content_type};base64,{encoded_img}"
        # read() left the stream at EOF; rewind so url_image_vars sees the
        # whole file if it reads from the file object.
        img_file.seek(0)
        return self._render_result(request, img_src, img_file, label)

    def get(self, request, *args, **kwargs):
        """Render the page with empty input forms."""
        return render(request, self.template_name, get_default_form_context())

    def post(self, request, *args, **kwargs):
        """Handle a form submission selected by the POSTed ``action`` field."""
        # .get(): a request without "action" falls through to the default
        # page instead of raising MultiValueDictKeyError (HTTP 500).
        action = request.POST.get("action")
        context_dict = get_default_form_context()

        if action == "SubmitURL":
            form = URLInputForm(request.POST)
            if form.is_valid():
                return self.render_for_img_url_label(
                    request, form.cleaned_data["post"], form.cleaned_data["label"]
                )
            # Keep the bound form so its validation errors are rendered.
            context_dict["url_form"] = form
        elif action == "SubmitImage":
            form = ImageUploadForm(request.POST, request.FILES)
            if form.is_valid():
                return self.render_for_img_file_label(
                    request, form.cleaned_data["img"], form.cleaned_data["label"]
                )
            # Invalid uploads are reported to the user, not asserted away.
            context_dict["image_form"] = form
        elif action == "LoadRandom":
            img_url, label = get_random_sample_image()
            return self.render_for_img_url_label(request, img_url, label)

        return render(request, self.template_name, context_dict)