Jangai commited on
Commit
ffb04af
·
verified ·
1 Parent(s): 5372355

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -18
app.py CHANGED
@@ -1,18 +1,31 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import DetrImageProcessor, DetrForObjectDetection
4
  from PIL import Image
5
  import numpy as np
 
 
6
  import logging
 
7
 
8
  # Configure logging
9
  logging.basicConfig(level=logging.DEBUG)
10
 
11
  # Load the pre-trained model and feature extractor
12
- model_name = "IDEA-Research/grounding-dino-tiny"
13
  logging.info("Loading image processor and model...")
14
- feature_extractor = DetrImageProcessor.from_pretrained(model_name)
15
- model = DetrForObjectDetection.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Define the prediction function
18
  def predict(image):
@@ -20,15 +33,9 @@ def predict(image):
20
  logging.info("Received image of type: %s", type(image))
21
  logging.debug("Image content: %s", image)
22
 
23
- # Use the 'composite' key to get the final image
24
- if isinstance(image, dict):
25
- image = image['composite']
26
-
27
- logging.debug("Converting to NumPy array...")
28
  image = np.array(image).astype('uint8')
29
- logging.debug("Converting NumPy array to PIL image...")
30
  image = Image.fromarray(image, 'RGBA').convert('RGB')
31
- logging.debug("Image converted successfully.")
32
 
33
  logging.info("Processing image...")
34
  inputs = feature_extractor(images=image, return_tensors="pt")
@@ -36,8 +43,8 @@ def predict(image):
36
  logits = outputs.logits
37
  probs = torch.nn.functional.softmax(logits, dim=-1)
38
  top_probs, top_idxs = probs.topk(3, dim=-1)
39
- top_probs = top_probs.detach().numpy()[0].tolist() # Convert to list
40
- top_idxs = top_idxs.detach().numpy()[0].tolist() # Convert to list
41
  top_classes = [model.config.id2label[idx] for idx in top_idxs]
42
  result = {top_classes[i]: float(top_probs[i]) for i in range(3)}
43
  logging.info("Prediction successful.")
@@ -46,14 +53,103 @@ def predict(image):
46
  logging.error("Error during prediction: %s", e)
47
  return {"error": str(e)}
48
 
49
- # Create the Gradio interface
50
- iface = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  fn=predict,
52
  inputs=gr.Sketchpad(),
53
  outputs=gr.JSON(),
54
  title="Drawing Classifier",
55
- description="Draw something and the model will try to identify it!"
 
 
 
 
 
 
 
 
 
56
  )
57
 
58
- # Launch the interface
59
- iface.launch()
 
 
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification, Trainer, TrainingArguments
4
  from PIL import Image
5
  import numpy as np
6
+ import pandas as pd
7
+ import os
8
  import logging
9
+ from datasets import Dataset
10
 
11
  # Configure logging
12
  logging.basicConfig(level=logging.DEBUG)
13
 
14
  # Load the pre-trained model and feature extractor
15
+ model_name = "google/vit-base-patch16-224"
16
  logging.info("Loading image processor and model...")
17
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
18
+ model = AutoModelForImageClassification.from_pretrained(model_name)
19
+
20
+ # Load or initialize the feedback data
21
+ feedback_data_path = "feedback_data.csv"
22
+ if os.path.exists(feedback_data_path):
23
+ feedback_data = pd.read_csv(feedback_data_path)
24
+ else:
25
+ feedback_data = pd.DataFrame(columns=["image_path", "correct_label"])
26
+
27
+ # Directory to save images
28
+ os.makedirs("images", exist_ok=True)
29
 
30
  # Define the prediction function
31
  def predict(image):
 
33
  logging.info("Received image of type: %s", type(image))
34
  logging.debug("Image content: %s", image)
35
 
36
+ # Convert to NumPy array and then to PIL image
 
 
 
 
37
  image = np.array(image).astype('uint8')
 
38
  image = Image.fromarray(image, 'RGBA').convert('RGB')
 
39
 
40
  logging.info("Processing image...")
41
  inputs = feature_extractor(images=image, return_tensors="pt")
 
43
  logits = outputs.logits
44
  probs = torch.nn.functional.softmax(logits, dim=-1)
45
  top_probs, top_idxs = probs.topk(3, dim=-1)
46
+ top_probs = top_probs.detach().numpy()[0]
47
+ top_idxs = top_idxs.detach().numpy()[0]
48
  top_classes = [model.config.id2label[idx] for idx in top_idxs]
49
  result = {top_classes[i]: float(top_probs[i]) for i in range(3)}
50
  logging.info("Prediction successful.")
 
53
  logging.error("Error during prediction: %s", e)
54
  return {"error": str(e)}
55
 
56
+ # Save feedback and retrain if necessary
57
+ def save_feedback(image, correct_label):
58
+ global feedback_data
59
+ try:
60
+ image_np = np.array(image['composite']).astype('uint8')
61
+ image_pil = Image.fromarray(image_np, 'RGBA').convert('RGB')
62
+ image_path = f"images/{len(feedback_data)}.png"
63
+ image_pil.save(image_path)
64
+
65
+ # Add the feedback to the DataFrame
66
+ feedback_data = feedback_data.append({"image_path": image_path, "correct_label": correct_label}, ignore_index=True)
67
+ feedback_data.to_csv(feedback_data_path, index=False)
68
+
69
+ # Retrain if we have collected 5 new feedbacks
70
+ if len(feedback_data) % 5 == 0:
71
+ retrain_model(feedback_data)
72
+
73
+ return "Feedback saved and model retrained!" if len(feedback_data) % 5 == 0 else "Feedback saved!"
74
+ except Exception as e:
75
+ logging.error("Error saving feedback: %s", e)
76
+ return {"error": str(e)}
77
+
78
+ # Retrain the model with the feedback data
79
+ def retrain_model(feedback_data):
80
+ try:
81
+ logging.info("Retraining the model with feedback data...")
82
+
83
+ # Load images and labels into a Hugging Face dataset
84
+ def load_image(file_path):
85
+ return Image.open(file_path).convert("RGB")
86
+
87
+ dataset_dict = {
88
+ "image": [load_image(f) for f in feedback_data["image_path"]],
89
+ "label": feedback_data["correct_label"].tolist()
90
+ }
91
+
92
+ dataset = Dataset.from_dict(dataset_dict)
93
+ dataset = dataset.train_test_split(test_size=0.1)
94
+
95
+ # Preprocess the dataset
96
+ def preprocess(examples):
97
+ inputs = feature_extractor(images=examples["image"], return_tensors="pt")
98
+ inputs["labels"] = examples["label"]
99
+ return inputs
100
+
101
+ dataset = dataset.with_transform(preprocess)
102
+
103
+ # Set up the training arguments
104
+ training_args = TrainingArguments(
105
+ output_dir="./results",
106
+ evaluation_strategy="epoch",
107
+ per_device_train_batch_size=4,
108
+ per_device_eval_batch_size=4,
109
+ num_train_epochs=3,
110
+ save_strategy="epoch",
111
+ save_total_limit=2,
112
+ remove_unused_columns=False,
113
+ )
114
+
115
+ # Initialize the Trainer
116
+ trainer = Trainer(
117
+ model=model,
118
+ args=training_args,
119
+ train_dataset=dataset["train"],
120
+ eval_dataset=dataset["test"],
121
+ )
122
+
123
+ # Train the model
124
+ trainer.train()
125
+
126
+ # Save the model
127
+ model.save_pretrained("./fine_tuned_model")
128
+ feature_extractor.save_pretrained("./fine_tuned_model")
129
+ logging.info("Model retrained and saved successfully.")
130
+ except Exception as e:
131
+ logging.error("Error during model retraining: %s", e)
132
+
133
+ # Create the Gradio interfaces
134
+ predict_interface = gr.Interface(
135
  fn=predict,
136
  inputs=gr.Sketchpad(),
137
  outputs=gr.JSON(),
138
  title="Drawing Classifier",
139
+ description="Draw something and the model will try to identify it!",
140
+ live=True
141
+ )
142
+
143
+ feedback_interface = gr.Interface(
144
+ fn=save_feedback,
145
+ inputs=[gr.Sketchpad(), gr.Textbox(label="Correct Label")],
146
+ outputs="text",
147
+ title="Save Feedback",
148
+ description="Draw something and provide the correct label to improve the model."
149
  )
150
 
151
+ # Launch the interfaces together
152
+ gr.TabbedInterface(
153
+ [predict_interface, feedback_interface],
154
+ ["Predict", "Provide Feedback"]
155
+ ).launch(share=True)