yeftakun commited on
Commit
b9ed281
·
verified ·
1 Parent(s): 73f8cc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -7
app.py CHANGED
@@ -1,13 +1,59 @@
1
- import gradio as gr
2
  from transformers import ViTImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
  import requests
 
 
 
 
 
5
 
6
- # Load the model and processor
7
  processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
8
  model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
9
 
10
- # Define prediction function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def predict_image(image_url):
12
  try:
13
  # Load image from URL
@@ -26,13 +72,18 @@ def predict_image(image_url):
26
  except Exception as e:
27
  return str(e)
28
 
29
- # Create Gradio interface
 
 
 
30
  iface = gr.Interface(
31
  fn=predict_image,
32
  inputs=gr.Textbox(label="Image URL", placeholder="Enter image URL here"),
33
  outputs=gr.Textbox(label="Predicted Class"),
34
- title="NSFW Image Detection"
 
 
35
  )
36
 
37
- # Launch the interface
38
- iface.launch()
 
1
+ from flask import Flask, request, jsonify, url_for
2
  from transformers import ViTImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
  import requests
5
+ import threading
6
+ import gradio as gr
7
+
8
+ # Initialize the Flask app
9
+ app = Flask(__name__)
10
 
11
+ # Load the processor and model outside of the route to avoid reloading it with each request
12
  processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
13
  model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
14
 
15
+ @app.route('/classify', methods=['POST'])
16
+ def classify_image():
17
+ try:
18
+ # Get the image URL from the POST request
19
+ data = request.get_json()
20
+ image_url = data.get('image_url')
21
+
22
+ if not image_url:
23
+ return jsonify({"error": "Image URL not provided"}), 400
24
+
25
+ # Fetch the image from the URL
26
+ image = Image.open(requests.get(image_url, stream=True).raw)
27
+
28
+ # Preprocess the image
29
+ inputs = processor(images=image, return_tensors="pt")
30
+
31
+ # Run the image through the model
32
+ outputs = model(**inputs)
33
+ logits = outputs.logits
34
+
35
+ # Get the predicted class
36
+ predicted_class_idx = logits.argmax(-1).item()
37
+ predicted_class = model.config.id2label[predicted_class_idx]
38
+
39
+ # Return the classification result
40
+ return jsonify({
41
+ "image_url": image_url,
42
+ "predicted_class": predicted_class
43
+ })
44
+
45
+ except Exception as e:
46
+ return jsonify({"error": str(e)}), 500
47
+
48
+ # Function to run the Flask app in a separate thread
49
+ def run_flask():
50
+ app.run(port=5000, debug=False, use_reloader=False)
51
+
52
+ # Launch Flask in a separate thread
53
+ flask_thread = threading.Thread(target=run_flask)
54
+ flask_thread.start()
55
+
56
+ # Gradio interface
57
  def predict_image(image_url):
58
  try:
59
  # Load image from URL
 
72
  except Exception as e:
73
  return str(e)
74
 
75
+ # Construct API endpoint URL
76
+ api_url = "http://127.0.0.1:5000/classify"
77
+
78
+ # Create Gradio interface with API info
79
  iface = gr.Interface(
80
  fn=predict_image,
81
  inputs=gr.Textbox(label="Image URL", placeholder="Enter image URL here"),
82
  outputs=gr.Textbox(label="Predicted Class"),
83
+ title="NSFW Image Detection",
84
+ description=f"You can get your image classification by sending an API request to: {api_url}. Example:\n"
85
+ f"curl -X POST {api_url} -H 'Content-Type: application/json' -d '{{\"image_url\": \"YOUR_IMAGE_URL\"}}'"
86
  )
87
 
88
+ # Launch Gradio interface
89
+ iface.launch()