|
import base64
import json
import sqlite3

import torch
from flask import Flask, jsonify, request
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForImageClassification, AutoProcessor
from transformers.image_utils import load_image
|
|
|
# The Flask application object; routes below are registered on it.
app = Flask(import_name=__name__)
|
|
|
|
|
# Image classifier whose predicted label is used as the "defect" name.
# NOTE(review): judging by its name this checkpoint is a general ImageNet-22k
# classifier, so its labels may never equal the defect keys seeded into the
# database — confirm, or swap in a model fine-tuned on defect classes.
DETECTION_MODEL_NAME = "microsoft/beit-base-patch16-224-pt22k-ft22k"

# Both objects come from the same checkpoint and are loaded once at import;
# the model and its preprocessing pipeline are independent of each other.
detection_model = AutoModelForImageClassification.from_pretrained(DETECTION_MODEL_NAME)
processor = AutoProcessor.from_pretrained(DETECTION_MODEL_NAME)
|
|
|
# Remedy text for each known defect label. The keys are the values stored in
# the ``defect`` column when the database is seeded.
defects_to_remedies = dict(
    crack="Fill cracks with epoxy. Structural cracks might need professional inspection.",
    spalling="Clean affected area and apply anti-corrosion primer before repairing.",
    leakage="Fix water source, seal with water-proofing compounds.",
    mold="Clean the mold, improve ventilation, and apply mold-resistant paint.",
)
|
|
|
|
|
# Sentence-embedding model; seeding uses it to embed each remedy's text.
# NOTE(review): the stored embeddings are only ever written, never read, in
# the code visible here — confirm they are consumed elsewhere, otherwise this
# model costs startup time and memory for nothing.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embedding_model = SentenceTransformer(model_name_or_path=EMBEDDING_MODEL_NAME)
|
|
|
|
|
# SQLite store of defect label -> remedy text (+ remedy embedding).
# check_same_thread=False allows this single connection to be used from
# threads other than the one that opened it (e.g. request handlers).
db_conn = sqlite3.connect(database='defects.db', check_same_thread=False)

# Module-level cursor; it also remains available to code elsewhere in the file.
c = db_conn.cursor()
c.execute(
    "CREATE TABLE IF NOT EXISTS defects (id INTEGER PRIMARY KEY, defect TEXT, remedy TEXT, embedding BLOB)"
)
db_conn.commit()
|
|
|
|
|
def _encode_with_model(text):
    """Embed *text* with the module-level sentence model as a list of floats."""
    return embedding_model.encode(text).tolist()


def seed_database(conn=None, remedies=None, embed=None):
    """Insert every defect/remedy pair that is not already stored.

    Rows whose ``defect`` value already exists are left untouched, so calling
    this more than once does not create duplicates. All inserts happen in one
    transaction: either every missing row is committed or, if embedding or an
    insert fails part-way, none are.

    Args:
        conn: SQLite connection that has a ``defects`` table. Defaults to the
            module-level ``db_conn``.
        remedies: Mapping of defect label -> remedy text. Defaults to
            ``defects_to_remedies``.
        embed: Callable mapping remedy text to a JSON-serializable list of
            floats. Defaults to encoding with the module-level
            ``embedding_model``.
    """
    conn = db_conn if conn is None else conn
    remedies = defects_to_remedies if remedies is None else remedies
    embed = _encode_with_model if embed is None else embed

    # ``with conn`` commits on success and rolls back on an exception.
    # Each ``conn.execute`` uses its own cursor instead of the shared global one.
    with conn:
        for defect, remedy in remedies.items():
            already_seeded = conn.execute(
                "SELECT 1 FROM defects WHERE defect=? LIMIT 1", (defect,)
            ).fetchone()
            if already_seeded:
                continue
            # Explicit JSON array text; for finite floats this is the same
            # string ``str(list)`` produces, so existing rows stay compatible.
            embedding = json.dumps(embed(remedy))
            conn.execute(
                "INSERT INTO defects (defect, remedy, embedding) VALUES (?, ?, ?)",
                (defect, remedy, embedding),
            )
|
|
|
# Populate the defects table at import time, before any request is served,
# so that remedy lookups in the /detect handler can find the seeded rows.
seed_database()
|
|
|
def _lookup_remedy(label):
    """Return the stored remedy for *label*, or None when nothing matches.

    The full label is tried first, then each comma-separated part of it,
    stripped and lower-cased. NOTE(review): this assumes classifier labels can
    be synonym lists such as "crack, cleft" — confirm against
    ``detection_model.config.id2label``.

    Each query gets a fresh cursor from ``db_conn.execute`` rather than
    reusing the module-level cursor ``c``. If requests are handled on several
    threads, one request therefore cannot fetch another request's result row.
    """
    synonyms = (part.strip().lower() for part in label.split(','))
    # dict.fromkeys de-duplicates while keeping the try-order.
    for candidate in dict.fromkeys([label, *synonyms]):
        if not candidate:
            continue
        row = db_conn.execute(
            "SELECT remedy FROM defects WHERE defect=?", (candidate,)
        ).fetchone()
        if row is not None and row[0] is not None:
            return row[0]
    return None


@app.route('/detect', methods=['POST'])
def detect_defect():
    """Classify an uploaded image and return the remedy for the predicted label.

    Expects a multipart form upload with the file in the ``image`` field.

    Returns:
        200 with ``{"detected_defect": <label>, "remedy": <text>}``.
        400 with ``{"error": <message>}`` when no file (or an empty one) is
        uploaded, or when the upload cannot be decoded as an image.
    """
    upload = request.files.get('image')
    # A form submitted without choosing a file still sends an empty part,
    # so empty content is treated the same as a missing field.
    raw = upload.read() if upload is not None else b''
    if not raw:
        return jsonify({"error": "No image uploaded."}), 400

    # The image processor expects a decoded image (PIL / array / tensor), not
    # encoded file bytes. transformers' ``load_image`` decodes a base64 string
    # into an RGB PIL image, which avoids a direct Pillow import here.
    # NOTE(review): base64 input requires a transformers release whose
    # ``load_image`` supports it; older releases raise ValueError (-> 400).
    try:
        image = load_image(base64.b64encode(raw).decode('ascii'))
    except (ValueError, OSError):
        return jsonify({"error": "Uploaded file is not a readable image."}), 400

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = detection_model(**inputs).logits
    # Softmax is monotonic, so the argmax of the raw logits is the top class.
    class_name = detection_model.config.id2label[logits.argmax(dim=-1).item()]

    remedy = _lookup_remedy(class_name)
    if remedy is None:
        remedy = "No specific remedy available for this defect."

    return jsonify({"detected_defect": class_name, "remedy": remedy})
|
|
|
if __name__ == '__main__':
    # debug=True is no longer hard-coded, for two reasons:
    # - Flask's debug mode turns on the Werkzeug interactive debugger, which
    #   lets anyone who can reach the server execute code.
    # - It also turns on the auto-reloader, which re-runs this module in a
    #   child process and so loads both models a second time.
    # Called without ``debug``, ``app.run`` honours the FLASK_DEBUG
    # environment variable, so local development can opt in with FLASK_DEBUG=1.
    app.run()