HawkeyeHS commited on
Commit
43d3f36
·
1 Parent(s): 2b0a552

Add application file

Browse files
Files changed (1) hide show
  1. app.py +30 -22
app.py CHANGED
@@ -2,16 +2,11 @@ import os
2
  import warnings
3
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
4
  import torch
5
- warnings.filterwarnings("ignore")
6
-
7
- import json
8
  from flask_cors import CORS
9
- from flask import Flask, request, Response
10
-
11
  import numpy as np
12
  from PIL import Image
13
  import requests
14
-
15
  from io import BytesIO
16
 
17
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -19,36 +14,49 @@ os.environ["CUDA_VISIBLE_DEVICES"] = ""
19
  app = Flask(__name__)
20
  cors = CORS(app)
21
 
22
- global MODEL
23
- global CLASSES
24
-
25
 
26
  @app.route("/", methods=["GET"])
27
  def default():
28
  return json.dumps({"Hello I am Chitti": "Speed 1 Terra Hertz, Memory 1 Zeta Byte"})
29
 
30
-
31
  @app.route("/predict", methods=["GET"])
32
  def predict():
33
- feature_extractor = AutoFeatureExtractor.from_pretrained('carbon225/vit-base-patch16-224-hentai')
34
- model = AutoModelForImageClassification.from_pretrained('carbon225/vit-base-patch16-224-hentai')
35
- src = request.args.get("src")
36
- print(f"{src=}")
37
- response = requests.get(src)
38
- print(f"{response=}")
39
  try:
 
 
 
 
 
 
 
 
40
  image = Image.open(BytesIO(response.content))
41
  image = image.resize((128, 128))
42
- encoding = feature_extractor(image.convert("RGB"), return_tensors="pt")
 
 
 
 
43
  with torch.no_grad():
44
  outputs = model(**encoding)
45
  logits = outputs.logits
46
 
 
47
  predicted_class_idx = logits.argmax(-1).item()
48
- print(model.config.id2label[predicted_class_idx])
49
- # Return the Predictions
50
- return json.dumps({"class": model.config.id2label[predicted_class_idx]})
51
- except Exception as e:
52
- return json.dumps({"Uh oh": f"{str(e)}"})
53
 
 
 
 
 
 
 
 
 
 
54
 
 
 
 
2
  import warnings
3
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
4
  import torch
 
 
 
5
  from flask_cors import CORS
6
+ from flask import Flask, request, json, Response
 
7
  import numpy as np
8
  from PIL import Image
9
  import requests
 
10
  from io import BytesIO
11
 
12
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
14
  app = Flask(__name__)
15
  cors = CORS(app)
16
 
17
+ # Define the model and feature extractor globally
18
+ model = AutoModelForImageClassification.from_pretrained('carbon225/vit-base-patch16-224-hentai')
19
+ feature_extractor = AutoFeatureExtractor.from_pretrained('carbon225/vit-base-patch16-224-hentai')
20
 
21
  @app.route("/", methods=["GET"])
22
  def default():
23
  return json.dumps({"Hello I am Chitti": "Speed 1 Terra Hertz, Memory 1 Zeta Byte"})
24
 
 
25
  @app.route("/predict", methods=["GET"])
26
  def predict():
 
 
 
 
 
 
27
  try:
28
+ src = request.args.get("src")
29
+ print(f"{src=}")
30
+
31
+ # Download image from the provided URL
32
+ response = requests.get(src)
33
+ response.raise_for_status() # Check for HTTP errors
34
+
35
+ # Open and preprocess the image
36
  image = Image.open(BytesIO(response.content))
37
  image = image.resize((128, 128))
38
+
39
+ # Extract features using the pre-trained feature extractor
40
+ encoding = feature_extractor(images=image.convert("RGB"), return_tensors="pt")
41
+
42
+ # Make a prediction using the pre-trained model
43
  with torch.no_grad():
44
  outputs = model(**encoding)
45
  logits = outputs.logits
46
 
47
+ # Get the predicted class index and label
48
  predicted_class_idx = logits.argmax(-1).item()
49
+ predicted_class_label = model.config.id2label[predicted_class_idx]
 
 
 
 
50
 
51
+ print(predicted_class_label)
52
+
53
+ # Return the predictions
54
+ return json.dumps({"class": predicted_class_label})
55
+
56
+ except requests.exceptions.RequestException as e:
57
+ return json.dumps({"error": f"Request error: {str(e)}"})
58
+ except Exception as e:
59
+ return json.dumps({"error": f"An unexpected error occurred: {str(e)}"})
60
 
61
+ if __name__ == "__main__":
62
+ app.run(debug=True)