anon5's picture
Update app.py
c55b8a7 verified
raw
history blame
1.1 kB
import io
import logging

import requests
import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
# Checkpoint name shared by the feature extractor and the classifier.
_CHECKPOINT = 'google/vit-base-patch32-384'

print("Loading models...")
# Both objects are built once at import time, not per request.
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
model = ViTForImageClassification.from_pretrained(_CHECKPOINT)

print("Starting webapp...")
app = Flask(__name__)

# Disable the werkzeug logger and the Flask application logger.
log = logging.getLogger('werkzeug')
app.logger.disabled = True
log.disabled = True

print("Ready")
@app.route("/")
def hello_world():
    """Classify the image found at the ``url`` query parameter.

    Returns:
        A JSON response. On success it carries ``url`` and ``classes``, the
        label that ``model.config.id2label`` holds for the top-scoring class
        index. When ``url`` is missing, or the image cannot be fetched or
        decoded, it carries ``error``, ``url`` and an empty ``classes`` list;
        fetch failures use status 502 and decode failures status 400.
    """
    url = request.args.get('url')
    if url is None:
        # Status is left at the default so existing clients see no change.
        return jsonify(error="Url is required", url=None, classes=[])

    # NOTE(security): ``url`` is caller-controlled, so this endpoint fetches
    # arbitrary addresses on the server's behalf (SSRF). Restrict the allowed
    # schemes/hosts before exposing it to untrusted clients.
    try:
        # Without a timeout a stalled remote host would block this worker
        # indefinitely. The body is read eagerly so the connection is released
        # as soon as the call returns.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
    except requests.RequestException:
        return jsonify(error="Could not fetch image from url", url=url, classes=[]), 502

    try:
        # convert() forces a full decode and normalizes grayscale, RGBA and
        # palette images to three channels before preprocessing.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except OSError:
        # Covers PIL.UnidentifiedImageError and truncated/corrupt files.
        return jsonify(error="Url does not point to a readable image", url=url, classes=[]), 400

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    return jsonify(url=url, classes=model.config.id2label[predicted_class_idx])