Update app.py
Browse files
app.py
CHANGED
@@ -1,18 +1,31 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
-
from transformers import
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
|
|
|
|
6 |
import logging
|
|
|
7 |
|
8 |
# Configure logging
|
9 |
logging.basicConfig(level=logging.DEBUG)
|
10 |
|
11 |
# Load the pre-trained model and feature extractor
|
12 |
-
model_name = "
|
13 |
logging.info("Loading image processor and model...")
|
14 |
-
feature_extractor =
|
15 |
-
model =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# Define the prediction function
|
18 |
def predict(image):
|
@@ -20,15 +33,9 @@ def predict(image):
|
|
20 |
logging.info("Received image of type: %s", type(image))
|
21 |
logging.debug("Image content: %s", image)
|
22 |
|
23 |
-
#
|
24 |
-
if isinstance(image, dict):
|
25 |
-
image = image['composite']
|
26 |
-
|
27 |
-
logging.debug("Converting to NumPy array...")
|
28 |
image = np.array(image).astype('uint8')
|
29 |
-
logging.debug("Converting NumPy array to PIL image...")
|
30 |
image = Image.fromarray(image, 'RGBA').convert('RGB')
|
31 |
-
logging.debug("Image converted successfully.")
|
32 |
|
33 |
logging.info("Processing image...")
|
34 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
@@ -36,8 +43,8 @@ def predict(image):
|
|
36 |
logits = outputs.logits
|
37 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
38 |
top_probs, top_idxs = probs.topk(3, dim=-1)
|
39 |
-
top_probs = top_probs.detach().numpy()[0]
|
40 |
-
top_idxs = top_idxs.detach().numpy()[0]
|
41 |
top_classes = [model.config.id2label[idx] for idx in top_idxs]
|
42 |
result = {top_classes[i]: float(top_probs[i]) for i in range(3)}
|
43 |
logging.info("Prediction successful.")
|
@@ -46,14 +53,103 @@ def predict(image):
|
|
46 |
logging.error("Error during prediction: %s", e)
|
47 |
return {"error": str(e)}
|
48 |
|
49 |
-
#
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
fn=predict,
|
52 |
inputs=gr.Sketchpad(),
|
53 |
outputs=gr.JSON(),
|
54 |
title="Drawing Classifier",
|
55 |
-
description="Draw something and the model will try to identify it!"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
)
|
57 |
|
58 |
-
# Launch the
|
59 |
-
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, Trainer, TrainingArguments
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import os
|
8 |
import logging
|
9 |
+
from datasets import Dataset
|
10 |
|
11 |
# Root logger reports everything from DEBUG upward
logging.basicConfig(level=logging.DEBUG)

# Checkpoint name shared by the feature extractor and the classifier
model_name = "google/vit-base-patch16-224"
logging.info("Loading image processor and model...")
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Reuse the feedback CSV when it is already on disk; otherwise start from
# an empty table with the two expected columns
feedback_data_path = "feedback_data.csv"
feedback_data = (
    pd.read_csv(feedback_data_path)
    if os.path.exists(feedback_data_path)
    else pd.DataFrame(columns=["image_path", "correct_label"])
)

# Make sure the folder that receives saved drawings exists
os.makedirs("images", exist_ok=True)
|
29 |
|
30 |
# Define the prediction function
|
31 |
def predict(image):
|
|
|
33 |
logging.info("Received image of type: %s", type(image))
|
34 |
logging.debug("Image content: %s", image)
|
35 |
|
36 |
+
# Convert to NumPy array and then to PIL image
|
|
|
|
|
|
|
|
|
37 |
image = np.array(image).astype('uint8')
|
|
|
38 |
image = Image.fromarray(image, 'RGBA').convert('RGB')
|
|
|
39 |
|
40 |
logging.info("Processing image...")
|
41 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
|
|
43 |
logits = outputs.logits
|
44 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
45 |
top_probs, top_idxs = probs.topk(3, dim=-1)
|
46 |
+
top_probs = top_probs.detach().numpy()[0]
|
47 |
+
top_idxs = top_idxs.detach().numpy()[0]
|
48 |
top_classes = [model.config.id2label[idx] for idx in top_idxs]
|
49 |
result = {top_classes[i]: float(top_probs[i]) for i in range(3)}
|
50 |
logging.info("Prediction successful.")
|
|
|
53 |
logging.error("Error during prediction: %s", e)
|
54 |
return {"error": str(e)}
|
55 |
|
56 |
+
# Save feedback and retrain if necessary
|
57 |
+
def save_feedback(image, correct_label):
    """Store one labelled drawing and retrain after every fifth sample.

    The drawing is written to ``images/<row index>.png`` and a matching
    ``(image_path, correct_label)`` row is added to the module-level
    ``feedback_data`` table, which is then rewritten to
    ``feedback_data_path``. When the table length becomes a multiple of
    five, ``retrain_model`` is called with the full table.

    Args:
        image: Sketchpad payload; a mapping whose ``'composite'`` entry
            holds the drawing as RGBA pixel data.
        correct_label: Label the user assigned to the drawing.

    Returns:
        A status string on success, or ``{"error": <message>}`` if any
        step raises.
    """
    global feedback_data
    try:
        composite = image['composite']
        image_np = np.array(composite).astype('uint8')
        image_pil = Image.fromarray(image_np, 'RGBA').convert('RGB')
        image_path = f"images/{len(feedback_data)}.png"
        image_pil.save(image_path)

        # DataFrame.append was removed in pandas 2.0; pd.concat is the
        # supported way to add a row and works on older releases too.
        new_row = pd.DataFrame(
            [{"image_path": image_path, "correct_label": correct_label}]
        )
        feedback_data = pd.concat([feedback_data, new_row], ignore_index=True)
        feedback_data.to_csv(feedback_data_path, index=False)

        # Retrain once for every five collected feedback rows
        should_retrain = len(feedback_data) % 5 == 0
        if should_retrain:
            retrain_model(feedback_data)
            return "Feedback saved and model retrained!"
        return "Feedback saved!"
    except Exception as e:
        logging.error("Error saving feedback: %s", e)
        return {"error": str(e)}
|
77 |
+
|
78 |
+
# Retrain the model with the feedback data
|
79 |
+
def retrain_model(feedback_data):
    """Fine-tune the module-level ``model`` on the collected feedback.

    Each row's image is loaded from ``image_path`` and its
    ``correct_label`` is translated to an integer class id through
    ``model.config.label2id``. Rows whose label is not a key of that
    mapping are skipped with a warning. The remaining samples are split
    90/10 into train and eval sets, trained for three epochs, and the
    model plus feature extractor are saved to ``./fine_tuned_model``.

    Args:
        feedback_data: Table with ``image_path`` and ``correct_label``
            columns.

    Returns:
        None. Any exception is logged and swallowed so a failed retrain
        does not propagate to the caller.
    """
    try:
        logging.info("Retraining the model with feedback data...")

        if len(feedback_data) == 0:
            logging.info("No feedback data available; skipping retraining.")
            return

        # The Trainer needs integer class ids: raw text labels cannot be
        # collated into a tensor, so translate them up front.
        label2id = model.config.label2id
        image_paths = []
        label_ids = []
        for path, label in zip(
            feedback_data["image_path"], feedback_data["correct_label"]
        ):
            key = str(label).strip()
            if key not in label2id:
                logging.warning("Skipping feedback with unknown label: %s", label)
                continue
            image_paths.append(path)
            label_ids.append(int(label2id[key]))

        # train_test_split needs at least one sample on each side
        if len(image_paths) < 2:
            logging.warning(
                "Not enough usable feedback samples to retrain (%d).",
                len(image_paths),
            )
            return

        # Load images and labels into a Hugging Face dataset
        def load_image(file_path):
            # convert() returns an independent copy, so the file handle
            # can be released as soon as the block exits
            with Image.open(file_path) as img:
                return img.convert("RGB")

        dataset = Dataset.from_dict(
            {
                "image": [load_image(f) for f in image_paths],
                "label": label_ids,
            }
        )
        dataset = dataset.train_test_split(test_size=0.1)

        # Preprocess the dataset lazily, batch by batch
        def preprocess(examples):
            inputs = feature_extractor(images=examples["image"], return_tensors="pt")
            inputs["labels"] = examples["label"]
            return inputs

        dataset = dataset.with_transform(preprocess)

        # Set up the training arguments
        # NOTE(review): newer transformers releases rename
        # `evaluation_strategy` to `eval_strategy` - confirm against the
        # version this app is pinned to.
        training_args = TrainingArguments(
            output_dir="./results",
            evaluation_strategy="epoch",
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            num_train_epochs=3,
            save_strategy="epoch",
            save_total_limit=2,
            # Keep the "image" column so the transform above can read it
            remove_unused_columns=False,
        )

        # Initialize the Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
        )

        # Train the model
        trainer.train()

        # Save the model together with its preprocessing configuration
        model.save_pretrained("./fine_tuned_model")
        feature_extractor.save_pretrained("./fine_tuned_model")
        logging.info("Model retrained and saved successfully.")
    except Exception as e:
        logging.error("Error during model retraining: %s", e)
|
132 |
+
|
133 |
+
# Prediction tab: a sketchpad whose drawing is classified as the user draws
_predict_canvas = gr.Sketchpad()
_predict_result = gr.JSON()
predict_interface = gr.Interface(
    predict,
    _predict_canvas,
    _predict_result,
    title="Drawing Classifier",
    description="Draw something and the model will try to identify it!",
    live=True,
)

# Feedback tab: a sketchpad plus a text field for the correct label
_feedback_inputs = [gr.Sketchpad(), gr.Textbox(label="Correct Label")]
feedback_interface = gr.Interface(
    save_feedback,
    _feedback_inputs,
    "text",
    title="Save Feedback",
    description="Draw something and provide the correct label to improve the model.",
)

# Serve both interfaces as tabs of a single app
_tab_interfaces = [predict_interface, feedback_interface]
_tab_titles = ["Predict", "Provide Feedback"]
gr.TabbedInterface(_tab_interfaces, _tab_titles).launch(share=True)
|