# Spaces: Sleeping
# (Page-status text captured along with this file; it is not part of the program.)
import functools
import json
import os
import warnings
from io import BytesIO

# Installed ahead of the third-party imports so warnings raised while they
# are imported are silenced as well as warnings raised at request time.
warnings.filterwarnings("ignore")

import numpy as np
import requests
import torch
from flask import Flask, request, Response
from flask_cors import CORS
from PIL import Image
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
# An empty CUDA_VISIBLE_DEVICES exposes no GPUs to this process, so model
# inference falls back to the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# The Flask application object and its CORS wrapper (flask_cors defaults).
app = Flask(__name__)
cors = CORS(app)
def default():
    """Return the service greeting as a JSON-encoded string.

    NOTE(review): no route decorator is visible on this function in the
    captured source -- confirm how it is registered with ``app``.
    """
    return json.dumps({"Hello I am Chitti": "Speed 1 Terra Hertz, Memory 1 Zeta Byte"})
_MODEL_NAME = 'carbon225/vit-base-patch16-224-hentai'
# Upper bound (seconds) on the remote image download so that a stalled host
# cannot block a worker indefinitely.
_FETCH_TIMEOUT_SECONDS = 10


@functools.lru_cache(maxsize=1)
def _load_classifier():
    """Load the feature extractor and classification model once per process.

    ``lru_cache`` keeps the loaded pair for later calls, so the weights are
    not re-read on every request. A failed load is not cached and is retried
    on the next call.

    Returns:
        A ``(feature_extractor, model)`` tuple.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(_MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(_MODEL_NAME)
    return feature_extractor, model


def predict():
    """Classify the image at the URL given by the ``src`` query parameter.

    Downloads the image, resizes it to 128x128, runs it through the
    classifier and returns the predicted label.

    Returns:
        A JSON-encoded string: ``{"class": <label>}`` on success, or
        ``{"Uh oh": "We are down"}`` when ``src`` is missing or the image
        cannot be fetched, decoded or classified.

    NOTE(security): ``src`` is caller-controlled, so this function fetches
    arbitrary URLs on the caller's behalf (SSRF risk). Restrict the allowed
    schemes/hosts before exposing it to untrusted clients.

    NOTE(review): no route decorator is visible on this function in the
    captured source -- confirm how it is registered with ``app``.
    """
    src = request.args.get("src")
    print(f"{src=}")
    if not src:
        # Without a URL there is nothing to fetch; answer with the usual
        # error payload instead of letting requests raise an unhandled error.
        print("An error occurred: missing 'src' query parameter")
        return json.dumps({"Uh oh": "We are down"})

    feature_extractor, model = _load_classifier()
    try:
        response = requests.get(src, timeout=_FETCH_TIMEOUT_SECONDS)
        print(f"{response=}")
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Convert to RGB before saving: JPEG cannot store RGBA or palette
        # images, so saving them unconverted would raise for PNG/GIF inputs.
        image = image.resize((128, 128)).convert("RGB")
        # Debug copy of the last processed image (overwritten on each call).
        image.save("new.jpg")
        encoding = feature_extractor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**encoding)
        predicted_class_idx = outputs.logits.argmax(-1).item()
        label = model.config.id2label[predicted_class_idx]
        print(label)
        # Return the predictions
        return json.dumps({"class": label})
    except Exception as e:
        # Request boundary: report the failure and return the error payload.
        print(f"An error occurred: {str(e)}")
        return json.dumps({"Uh oh": "We are down"})