0xnu
/

Image Classification
Keras
vision
0xnu committed on
Commit
730ff55
·
verified ·
1 Parent(s): f5b9188

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +107 -0
README.md CHANGED
@@ -7,6 +7,113 @@ datasets:
7
  inference: false
8
  ---
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  ### Dataset
12
 
 
7
  inference: false
8
  ---
9
 
10
+ ### Install Packages
11
+
12
+ ```sh
13
+ pip install numpy opencv-python requests pillow transformers tensorflow
14
+ ```
15
+
16
+ ### Usage
17
+
18
+ ```python
19
+ import numpy as np
20
+ import cv2
21
+ import requests
22
+ from PIL import Image
23
+ from io import BytesIO
24
+ from transformers import TFAutoModelForImageClassification, AutoFeatureExtractor
25
+
26
+ class MNISTPredictor:
27
+ def __init__(self, model_name):
28
+ self.model = TFAutoModelForImageClassification.from_pretrained(model_name)
29
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
30
+
31
+ def extract_features(self, image):
32
+ """Extract features from the image for multiple digits."""
33
+ # Convert to grayscale
34
+ gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
35
+
36
+ # Apply Gaussian blur
37
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
38
+
39
+ # Apply adaptive thresholding
40
+ thresh = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)
41
+
42
+ # Find contours
43
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
44
+
45
+ digit_images = []
46
+ for contour in contours:
47
+ # Filter small contours
48
+ if cv2.contourArea(contour) > 50: # Adjust this threshold as needed
49
+ x, y, w, h = cv2.boundingRect(contour)
50
+ roi = thresh[y:y+h, x:x+w]
51
+ resized = cv2.resize(roi, (28, 28), interpolation=cv2.INTER_AREA)
52
+ digit_images.append(Image.fromarray(resized).convert('RGB'))
53
+
54
+ return digit_images
55
+
56
+ def predict(self, image):
57
+ """Predict digits in the image."""
58
+ try:
59
+ digit_images = self.extract_features(image)
60
+ predictions = []
61
+ for digit_image in digit_images:
62
+ inputs = self.feature_extractor(images=digit_image, return_tensors="tf")
63
+ outputs = self.model(**inputs)
64
+ predicted_class = int(np.argmax(outputs.logits))
65
+ predictions.append(predicted_class)
66
+ return predictions
67
+ except Exception as e:
68
+ print(f"Error during prediction: {e}")
69
+ return None
70
+
71
+ def download_image(url):
72
+ """Download an image from a URL."""
73
+ try:
74
+ response = requests.get(url)
75
+ response.raise_for_status()
76
+ return Image.open(BytesIO(response.content))
77
+ except Exception as e:
78
+ print(f"Error downloading image: {e}")
79
+ return None
80
+
81
+ def save_predictions_to_file(predictions, output_path):
82
+ """Save predictions to a text file."""
83
+ try:
84
+ with open(output_path, 'w') as f:
85
+ f.write(f"Predicted digits are: {', '.join(map(str, predictions))}\n")
86
+ except Exception as e:
87
+ print(f"Error saving predictions to file: {e}")
88
+
89
+ def main(image_url, model_name, output_path):
90
+ try:
91
+ predictor = MNISTPredictor(model_name)
92
+
93
+ # Download image
94
+ image = download_image(image_url)
95
+ if image is None:
96
+ raise Exception("Failed to download image")
97
+
98
+ print(f"Image downloaded successfully.")
99
+
100
+ # Predict digits
101
+ digits = predictor.predict(image)
102
+ print(f"Predicted digits are: {digits}")
103
+
104
+ # Save predictions to file
105
+ save_predictions_to_file(digits, output_path)
106
+ print(f"Predictions saved to {output_path}")
107
+ except Exception as e:
108
+ print(f"An error occurred: {e}")
109
+
110
+ if __name__ == "__main__":
111
+ image_url = "https://miro.medium.com/v2/resize:fit:720/format:webp/1*w7pBsjI3t3ZP-4Gdog-JdQ.png"
112
+ model_name = "0xnu/mnist-ocr"
113
+ output_path = "predictions.txt"
114
+
115
+ main(image_url, model_name, output_path)
116
+ ```
117
 
118
  ### Dataset
119