# Test / app.py
# Hugging Face file view by Shakir60 — commit 7136f34 "Update app.py" (verified), 2.62 kB.
from flask import Flask, request, jsonify
import torch
from transformers import AutoProcessor, AutoModelForImageClassification
from sentence_transformers import SentenceTransformer
import sqlite3
app = Flask(__name__)

# --- Image classifier --------------------------------------------------------
# Hugging Face checkpoint, loaded through the transformers Auto* classes.
DETECTION_MODEL_NAME = "microsoft/beit-base-patch16-224-pt22k-ft22k"
processor = AutoProcessor.from_pretrained(DETECTION_MODEL_NAME)
detection_model = AutoModelForImageClassification.from_pretrained(
    DETECTION_MODEL_NAME
)

# --- Remedy catalogue --------------------------------------------------------
# Defect name -> suggested remedy text (written to SQLite by seed_database()).
defects_to_remedies = dict(
    crack="Fill cracks with epoxy. Structural cracks might need professional inspection.",
    spalling="Clean affected area and apply anti-corrosion primer before repairing.",
    leakage="Fix water source, seal with water-proofing compounds.",
    mold="Clean the mold, improve ventilation, and apply mold-resistant paint.",
)

# --- Text embeddings ---------------------------------------------------------
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# --- SQLite storage ----------------------------------------------------------
# check_same_thread=False allows this single module-level connection to be
# used from threads other than the one that opened it.
db_conn = sqlite3.connect("defects.db", check_same_thread=False)
c = db_conn.cursor()
c.execute(
    "CREATE TABLE IF NOT EXISTS defects "
    "(id INTEGER PRIMARY KEY, defect TEXT, remedy TEXT, embedding BLOB)"
)
db_conn.commit()
# Populate the defect-remedies table with any entries it is still missing.
def seed_database():
    """Insert a row for each entry of ``defects_to_remedies`` not yet stored.

    A defect whose name already has a row in ``defects`` is skipped, so running
    this again does not insert duplicates for the same name. For each new
    entry, the remedy text is embedded with ``embedding_model`` and the
    embedding is stored as the ``str()`` of its Python list.
    """
    for name, remedy_text in defects_to_remedies.items():
        c.execute("SELECT * FROM defects WHERE defect=?", (name,))
        already_present = c.fetchone() is not None
        if already_present:
            continue
        vector = embedding_model.encode(remedy_text).tolist()
        row = (name, remedy_text, str(vector))
        c.execute("INSERT INTO defects (defect, remedy, embedding) VALUES (?, ?, ?)", row)
    db_conn.commit()


seed_database()
def _decode_upload(raw):
    """Decode uploaded image bytes into an RGB PIL image.

    Raises:
        ValueError: if the bytes cannot be decoded as an image.
    """
    import base64

    from transformers.image_utils import load_image

    # Image processors do not accept raw bytes. load_image accepts a URL, a
    # file path or a base64 string, so hand it the upload base64-encoded.
    # NOTE(review): base64 input needs a reasonably recent transformers
    # release — confirm against the pinned version.
    return load_image(base64.b64encode(raw).decode("ascii"))


@app.route('/detect', methods=['POST'])
def detect_defect():
    """Classify an uploaded image and look up a stored remedy for its label.

    Expects a multipart form upload in the ``image`` field.

    Returns:
        200 with ``{"detected_defect": <label>, "remedy": <text>}``. ``remedy``
        falls back to a generic message when the label has no row in the
        ``defects`` table.
        400 with ``{"error": ...}`` when no image was uploaded or the upload
        cannot be decoded as an image.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image uploaded."}), 400
    try:
        image = _decode_upload(request.files['image'].read())
    except ValueError:
        return jsonify({"error": "Uploaded file is not a valid image."}), 400

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: don't build an autograd graph for every request.
    with torch.no_grad():
        logits = detection_model(**inputs).logits
    # argmax over the logits selects the same class as argmax over softmax(logits).
    class_name = detection_model.config.id2label[logits.argmax(dim=-1).item()]

    # NOTE(review): id2label comes from the pretrained checkpoint. Its labels
    # presumably do not include the seeded defect names ("crack", "mold", ...),
    # so the fallback remedy is likely the common case — confirm.
    # A fresh cursor per query keeps concurrent requests from interleaving
    # execute()/fetchone() on the shared module-level cursor ``c``.
    row = db_conn.execute(
        "SELECT remedy FROM defects WHERE defect=?", (class_name,)
    ).fetchone()
    remedy = row[0] if row else "No specific remedy available for this defect."
    return jsonify({"detected_defect": class_name, "remedy": remedy})
if __name__ == '__main__':
    # Debug mode turns on the Werkzeug interactive debugger, which can run
    # arbitrary code sent from the browser, so it is no longer forced on.
    # With no debug argument, app.run() honours the FLASK_DEBUG environment
    # variable: set FLASK_DEBUG=1 to opt in during local development.
    app.run()