Jangai commited on
Commit
3dd9291
·
verified ·
1 Parent(s): a56fdb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -144
app.py CHANGED
@@ -1,155 +1,62 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification, Trainer, TrainingArguments
4
- from PIL import Image
5
  import numpy as np
6
- import pandas as pd
7
  import os
8
- import logging
9
- from datasets import Dataset, DatasetDict
 
10
 
11
- # Configure logging
12
- logging.basicConfig(level=logging.DEBUG)
13
 
14
- # Load the pre-trained model and feature extractor
15
- model_name = "google/vit-base-patch16-224"
16
- logging.info("Loading image processor and model...")
17
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
18
- model = AutoModelForImageClassification.from_pretrained(model_name)
19
 
20
- # Load or initialize the feedback data
21
- feedback_data_path = "feedback_data.csv"
22
- if os.path.exists(feedback_data_path):
23
- feedback_data = pd.read_csv(feedback_data_path)
24
- else:
25
- feedback_data = pd.DataFrame(columns=["image_path", "correct_label"])
26
 
27
- # Directory to save images
28
- os.makedirs("images", exist_ok=True)
 
 
 
29
 
30
- # Define the prediction function
31
  def predict(image):
32
- try:
33
- logging.info("Received image of type: %s", type(image))
34
- logging.debug("Image content: %s", image)
35
-
36
- # Convert to NumPy array and then to PIL image
37
- image = np.array(image).astype('uint8')
38
- image = Image.fromarray(image, 'RGBA').convert('RGB')
39
-
40
- logging.info("Processing image...")
41
- inputs = feature_extractor(images=image, return_tensors="pt")
42
- outputs = model(**inputs)
43
- logits = outputs.logits
44
- probs = torch.nn.functional.softmax(logits, dim=-1)
45
- top_probs, top_idxs = probs.topk(3, dim=-1)
46
- top_probs = top_probs.detach().numpy()[0]
47
- top_idxs = top_idxs.detach().numpy()[0]
48
- top_classes = [model.config.id2label[idx] for idx in top_idxs]
49
- result = {top_classes[i]: float(top_probs[i]) for i in range(3)}
50
- logging.info("Prediction successful.")
51
- return result
52
- except Exception as e:
53
- logging.error("Error during prediction: %s", e)
54
- return {"error": str(e)}
55
 
56
- # Save feedback and retrain if necessary
57
- def save_feedback(image, correct_label):
58
  global feedback_data
59
- try:
60
- image_np = np.array(image).astype('uint8')
61
- image_pil = Image.fromarray(image_np, 'RGBA').convert('RGB')
62
- image_path = f"images/{len(feedback_data)}.png"
63
- image_pil.save(image_path)
64
-
65
- # Add the feedback to the DataFrame
66
- feedback_data = feedback_data.append({"image_path": image_path, "correct_label": correct_label}, ignore_index=True)
67
- feedback_data.to_csv(feedback_data_path, index=False)
68
-
69
- # Retrain if we have collected 5 new feedbacks
70
- if len(feedback_data) % 5 == 0:
71
- retrain_model(feedback_data)
72
-
73
- return "Feedback saved and model retrained!" if len(feedback_data) % 5 == 0 else "Feedback saved!"
74
- except Exception as e:
75
- logging.error("Error saving feedback: %s", e)
76
- return {"error": str(e)}
77
-
78
- # Retrain the model with the feedback data
79
- def retrain_model(feedback_data):
80
- try:
81
- logging.info("Retraining the model with feedback data...")
82
-
83
- # Load images and labels into a Hugging Face dataset
84
- def load_image(file_path):
85
- return Image.open(file_path).convert("RGB")
86
-
87
- dataset_dict = {
88
- "image": [load_image(f) for f in feedback_data["image_path"]],
89
- "label": feedback_data["correct_label"].astype(int).tolist() # Ensure labels are integers
90
- }
91
-
92
- dataset = Dataset.from_dict(dataset_dict)
93
- dataset = dataset.train_test_split(test_size=0.1)
94
-
95
- # Preprocess the dataset
96
- def preprocess(examples):
97
- inputs = feature_extractor(images=examples["image"], return_tensors="pt")
98
- inputs["labels"] = examples["label"]
99
- return inputs
100
-
101
- dataset = dataset.with_transform(preprocess)
102
-
103
- # Set up the training arguments
104
- training_args = TrainingArguments(
105
- output_dir="./results",
106
- evaluation_strategy="epoch",
107
- per_device_train_batch_size=4,
108
- per_device_eval_batch_size=4,
109
- num_train_epochs=3,
110
- save_strategy="epoch",
111
- save_total_limit=2,
112
- remove_unused_columns=False,
113
- )
114
-
115
- # Initialize the Trainer
116
- trainer = Trainer(
117
- model=model,
118
- args=training_args,
119
- train_dataset=dataset["train"],
120
- eval_dataset=dataset["test"],
121
- )
122
-
123
- # Train the model
124
- trainer.train()
125
-
126
- # Save the model
127
- model.save_pretrained("./fine_tuned_model")
128
- feature_extractor.save_pretrained("./fine_tuned_model")
129
- logging.info("Model retrained and saved successfully.")
130
- except Exception as e:
131
- logging.error("Error during model retraining: %s", e)
132
-
133
- # Create the Gradio interfaces
134
- predict_interface = gr.Interface(
135
- fn=predict,
136
- inputs=gr.Sketchpad(label="Draw something"),
137
- outputs=gr.JSON(),
138
- title="Drawing Classifier",
139
- description="Draw something and the model will try to identify it!",
140
- live=False
141
- )
142
-
143
- feedback_interface = gr.Interface(
144
- fn=save_feedback,
145
- inputs=[gr.Sketchpad(label="Draw something"), gr.Textbox(label="Enter the correct label")],
146
- outputs="text",
147
- title="Save Feedback",
148
- description="Draw something and provide the correct label to improve the model."
149
- )
150
-
151
- # Launch the interfaces together
152
- gr.TabbedInterface(
153
- [predict_interface, feedback_interface],
154
- ["Predict", "Provide Feedback"]
155
- ).launch(share=True)
 
1
  import gradio as gr
 
 
 
2
  import numpy as np
 
3
  import os
4
+ import pickle
5
+ from PIL import Image
6
+ from transformers import pipeline
7
 
8
+ # Load the classifier model
9
+ classifier = pipeline("image-classification", model="google/vit-base-patch16-224")
10
 
11
+ # Initialize a dictionary to store feedback
12
+ feedback_data = {"images": [], "labels": []}
 
 
 
13
 
14
+ def save_feedback():
15
+ with open("feedback_data.pkl", "wb") as f:
16
+ pickle.dump(feedback_data, f)
 
 
 
17
 
18
+ def load_feedback():
19
+ global feedback_data
20
+ if os.path.exists("feedback_data.pkl"):
21
+ with open("feedback_data.pkl", "rb") as f:
22
+ feedback_data = pickle.load(f)
23
 
 
24
  def predict(image):
25
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
26
+ prediction = classifier(image)
27
+ return {pred["label"]: pred["score"] for pred in prediction}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ def provide_feedback(image, label):
 
30
  global feedback_data
31
+ feedback_data["images"].append(image)
32
+ feedback_data["labels"].append(label)
33
+ save_feedback()
34
+
35
+ if len(feedback_data["images"]) % 5 == 0:
36
+ retrain_model()
37
+ return "Feedback saved. Thank you!"
38
+
39
+ def retrain_model():
40
+ global classifier
41
+ # Here, include the retraining logic using the feedback_data
42
+ # This is a placeholder for actual retraining logic
43
+ print("Retraining the model with new data...")
44
+
45
+ # Load existing feedback data
46
+ load_feedback()
47
+
48
+ with gr.Blocks() as demo:
49
+ with gr.Tab("Predict"):
50
+ image_input = gr.Sketchpad()
51
+ output = gr.JSON()
52
+ image_input.change(fn=predict, inputs=image_input, outputs=output)
53
+
54
+ with gr.Tab("Provide Feedback"):
55
+ image_feedback = gr.Sketchpad()
56
+ label_feedback = gr.Textbox(label="Enter the correct label")
57
+ feedback_button = gr.Button("Submit Feedback")
58
+ feedback_output = gr.Textbox()
59
+
60
+ feedback_button.click(fn=provide_feedback, inputs=[image_feedback, label_feedback], outputs=feedback_output)
61
+
62
+ demo.launch()