yeftakun commited on
Commit
b75787e
·
verified ·
1 Parent(s): 2f8f091

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -6
app.py CHANGED
@@ -1,18 +1,17 @@
1
  import gradio as gr
2
  from transformers import ViTImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
 
4
  import requests
 
5
 
6
  # Load the model and processor
7
  processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
8
  model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
9
 
10
  # Define prediction function
11
- def predict_image(image_url):
12
  try:
13
- # Load image from URL
14
- image = Image.open(requests.get(image_url, stream=True).raw)
15
-
16
  # Process the image and make prediction
17
  inputs = processor(images=image, return_tensors="pt")
18
  outputs = model(**inputs)
@@ -29,10 +28,43 @@ def predict_image(image_url):
29
  # Create Gradio interface
30
  iface = gr.Interface(
31
  fn=predict_image,
32
- inputs=gr.Textbox(label="Image URL", placeholder="Enter image URL here"),
33
  outputs=gr.Textbox(label="Predicted Class"),
34
  title="NSFW Image Classifier"
35
  )
36
 
37
- # Launch the interface
38
  iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import ViTImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
+ import io
5
  import requests
6
+ from flask import Flask, request, jsonify
7
 
8
  # Load the model and processor
9
  processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
10
  model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
11
 
12
  # Define prediction function
13
+ def predict_image(image):
14
  try:
 
 
 
15
  # Process the image and make prediction
16
  inputs = processor(images=image, return_tensors="pt")
17
  outputs = model(**inputs)
 
28
  # Create Gradio interface
29
  iface = gr.Interface(
30
  fn=predict_image,
31
+ inputs=gr.Image(type="pil", label="Upload Image"),
32
  outputs=gr.Textbox(label="Predicted Class"),
33
  title="NSFW Image Classifier"
34
  )
35
 
36
+ # Launch the Gradio interface
37
  iface.launch()
38
+
39
+ # Flask app for API endpoint
40
+ app = Flask(__name__)
41
+
42
+ @app.route('/predict', methods=['POST'])
43
+ def predict():
44
+ if 'file' not in request.files:
45
+ return jsonify({'error': 'No file part'}), 400
46
+
47
+ file = request.files['file']
48
+ if file.filename == '':
49
+ return jsonify({'error': 'No selected file'}), 400
50
+
51
+ try:
52
+ # Load image from the uploaded file
53
+ image = Image.open(file.stream)
54
+
55
+ # Process the image and make prediction
56
+ inputs = processor(images=image, return_tensors="pt")
57
+ outputs = model(**inputs)
58
+ logits = outputs.logits
59
+
60
+ # Get predicted class
61
+ predicted_class_idx = logits.argmax(-1).item()
62
+ predicted_label = model.config.id2label[predicted_class_idx]
63
+
64
+ return jsonify({'predicted_class': predicted_label})
65
+ except Exception as e:
66
+ return jsonify({'error': str(e)}), 500
67
+
68
+ # Run Flask app
69
+ if __name__ == '__main__':
70
+ app.run(port=5000)