import os
import re
import string
import io
import joblib
import uvicorn
from threading import Thread
from fastapi import FastAPI
from pydantic import BaseModel
from flask import Flask, request, jsonify
from transformers import pipeline
from PIL import Image
import huggingface_hub
# ✅ Set Hugging Face cache to a writable directory
cache_dir = "/tmp/huggingface_cache"
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HF_HOME"] = cache_dir
huggingface_hub.utils.HF_HUB_CACHE = cache_dir  # ✅ Force Hugging Face to use the correct cache directory
# ✅ Ensure the cache directory exists
os.makedirs(cache_dir, exist_ok=True)
# ✅ Initialize FastAPI and Flask
app_fastapi = FastAPI()
app_flask = Flask(__name__)
# ✅ Load NSFW Image Classification Model with explicit cache directory
pipe = pipeline("image-classification", model="LukeJacob2023/nsfw-image-detector", cache_dir=cache_dir)
# ✅ Load Toxic Text Classification Model
try:
    model = joblib.load("toxic_classifier.pkl")
    vectorizer = joblib.load("vectorizer.pkl")
    print("✅ Model & Vectorizer Loaded Successfully!")
except Exception as e:
    print(f"❌ Error loading model/vectorizer: {e}")
    exit(1)
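# NOTE: a minimal sketch of how such a model/vectorizer pair could be produced.
# This is an assumption for illustration only; the actual training script is not
# part of this file, and names like `train_texts` / `train_labels` are hypothetical:
#
#   from sklearn.feature_extraction.text import TfidfVectorizer
#   from sklearn.linear_model import LogisticRegression
#   vectorizer = TfidfVectorizer()
#   X = vectorizer.fit_transform(train_texts)            # fit on preprocessed text
#   model = LogisticRegression().fit(X, train_labels)    # labels: 1 = toxic, 0 = safe
#   joblib.dump(model, "toxic_classifier.pkl")
#   joblib.dump(vectorizer, "vectorizer.pkl")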
# 📌 Text Input Data Model
class TextInput(BaseModel):
    text: str
# 🔹 Text Preprocessing Function
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r'\d+', '', text)  # Remove numbers
    text = text.translate(str.maketrans('', '', string.punctuation))  # Remove punctuation
    return text.strip()
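# Illustrative example of the preprocessing above:
#   preprocess_text("Hello, World 123!!")  ->  "hello world"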
# 📌 NSFW Image Classification API (Flask)
@app_flask.route('/classify_image', methods=['POST'])
def classify_image():
    if 'file' not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    file = request.files['file']
    image = Image.open(io.BytesIO(file.read()))
    results = pipe(image)
    classification_label = max(results, key=lambda x: x['score'])['label']
    nsfw_labels = {"sexy", "porn", "hentai"}
    nsfw_status = "NSFW" if classification_label in nsfw_labels else "SFW"
    return jsonify({"status": nsfw_status, "results": results})
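# Example request (a sketch; assumes the Flask server is reachable on
# localhost:5000 and that ./test.jpg is an existing image file):
#   curl -X POST -F "file=@test.jpg" http://localhost:5000/classify_image
# The JSON response contains a "status" field ("NSFW" or "SFW") plus the raw
# label/score pairs returned by the classifier.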
# 📌 Toxic Text Classification API (FastAPI)
@app_fastapi.post("/classify_text/")
async def classify_text(data: TextInput):
    try:
        processed_text = preprocess_text(data.text)
        text_vectorized = vectorizer.transform([processed_text])
        prediction = model.predict(text_vectorized)
        result = "Toxic" if prediction[0] == 1 else "Safe"
        return {"prediction": result}
    except Exception as e:
        return {"error": str(e)}
# 🔥 Start both Flask & FastAPI servers in threads
def run_flask():
    app_flask.run(host="0.0.0.0", port=5000)
def run_fastapi():
    uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)
# Start Flask & FastAPI in separate threads
Thread(target=run_flask).start()
Thread(target=run_fastapi).start()
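# Usage note: running `python app.py` starts both servers, with Flask serving
# /classify_image on port 5000 and Uvicorn serving /classify_text/ on port 8000.
# The threads are non-daemon, so the process stays alive until both servers stop.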