"""Small Flask service that classifies an image fetched from a URL with ViT."""
import io

import requests
import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Checkpoint used for both the feature extractor and the classifier.
MODEL_NAME = 'google/vit-base-patch32-384'

# Upper bound (seconds) on the remote image fetch so that an unresponsive
# host cannot block a worker indefinitely.
FETCH_TIMEOUT_SECONDS = 10

# The model is loaded once at import time and shared by every request.
print("Loading models...")
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)

print("Starting webapp...")
app = Flask(__name__)
print("Ready")


def _error_response(message, url, status):
    """Build the JSON error payload (``error``, ``url``, empty ``classes``)."""
    return jsonify(error=message, url=url, classes=[]), status


@app.route("/")
def hello_world():
    """Classify the image referenced by the ``url`` query parameter.

    Returns:
        On success, a JSON object with ``url`` and ``classes``, where
        ``classes`` is the label string of the top-scoring class.
        On failure, a JSON object with ``error``, ``url`` and ``classes=[]``:
        a missing URL keeps the historical 200 status, a failed download
        yields 502, and a payload that cannot be decoded as an image
        yields 400.
    """
    url = request.args.get('url')
    if not url:
        # Status 200 is kept here so existing clients that only inspect the
        # ``error`` field keep working.
        return _error_response("Url is required", None, 200)

    # NOTE(security): the server fetches a caller-supplied URL (SSRF risk).
    # Restrict the allowed hosts at deployment level if this is exposed to
    # untrusted clients.
    try:
        response = requests.get(url, timeout=FETCH_TIMEOUT_SECONDS)
        response.raise_for_status()
    except requests.RequestException as exc:
        return _error_response(f"Could not fetch image: {exc}", url, 502)

    try:
        # Decode from the fully downloaded (and content-decoded) body rather
        # than the raw socket stream, and normalise to three channels so
        # grayscale / palette / RGBA images are accepted by the model.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except OSError as exc:
        return _error_response(f"Could not decode image: {exc}", url, 400)

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    return jsonify(url=url, classes=model.config.id2label[predicted_class_idx])