### **🩺 ResNet-18 Cataract Detection System**
This repository hosts a quantized version of a **ResNet-18-based** model optimized for **cataract detection**. The model classifies each input image into one of two labels: normal or cataract.
---
## **📌 Model Details**
- **Model Architecture**: ResNet-18
- **Task**: Cataract detection (binary image classification)
- **Dataset**: Cataract Dataset ([Kaggle](https://www.kaggle.com/datasets/nandanp6/cataract-image-dataset))
- **Framework**: PyTorch
- **Input Image Size**: 224x224
- **Number of Classes**: 2
---
## **🚀 Usage**
### **Installation**
```bash
pip install torch torchvision pillow huggingface_hub
```
### **Loading the Model**
```python
import json

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the weights and the class-name file from the Hub
repo_id = "AventIQ-AI/resnet18-cataract-detection-system"
weights_path = hf_hub_download(repo_id=repo_id, filename="cataract_detection_resnet18_quantized.pth")
labels_path = hf_hub_download(repo_id=repo_id, filename="class_names.json")

with open(labels_path, "r") as f:
    class_labels = json.load(f)

# Rebuild the architecture and resize the final layer to the number of classes
model = models.resnet18(weights=None)  # replaces the deprecated `pretrained=False` (torchvision >= 0.13)
num_classes = len(class_labels)
model.fc = torch.nn.Linear(in_features=512, out_features=num_classes)

# Load the weights on CPU, then move the model to the same device as the inputs
model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
model.to(device)
model.eval()
print("Model loaded successfully!")
```
---
### **🔍 Perform Classification**
```python
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet channel means and standard deviations
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)  # Add batch dimension

    with torch.no_grad():
        outputs = model(image)
        _, predicted_class = torch.max(outputs, 1)

    predicted_label = class_labels[predicted_class.item()]
    print(f"Predicted Class: {predicted_label}")
    return predicted_label


# Example usage:
image_path = "your_image_path"
predict_image(image_path)
```
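The prediction step above reports only the top class. Where a confidence score is also useful, the logits can be passed through a softmax first. The following is a minimal sketch, not part of this repository: the helper name `top_prediction`, the example logits, and the label order `["cataract", "normal"]` are all hypothetical, and it assumes the labels form a list indexed by class position, as the snippet above uses them.

```python
import torch
import torch.nn.functional as F


def top_prediction(logits, class_labels):
    """Return (label, probability) for a single-image batch of logits.

    `logits` is expected to have shape (1, num_classes), which is what
    `model(image)` returns in the classification snippet.
    """
    probs = F.softmax(logits, dim=1)  # convert logits to probabilities
    confidence, index = torch.max(probs, dim=1)
    return class_labels[index.item()], confidence.item()


# Illustrative logits and label order only; real values come from
# `model(image)` and `class_names.json`.
example_logits = torch.tensor([[0.2, 1.7]])
label, confidence = top_prediction(example_logits, ["cataract", "normal"])
print(f"{label}: {confidence:.2%}")
```

A softmax probability is not a calibrated clinical likelihood, so it is best read as a relative score.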
## **📊 Evaluation Results**
After fine-tuning, the model was evaluated on the **Cataract Dataset**, achieving the following performance:
| **Metric** | **Score** |
|------------------|----------|
| **Accuracy** | 97.52% |
| **Precision** | 98.31% |
| **Recall** | 96.67% |
| **F1-Score** | 97.48% |
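
As a consistency check on the table above, F1 is the harmonic mean of precision and recall. The short sketch below recomputes it from the reported precision and recall; the function name `f1_score` is illustrative and not taken from the training code.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (both as fractions in [0, 1])."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Precision and recall as reported in the table above.
f1 = f1_score(0.9831, 0.9667)
print(f"F1-Score: {f1:.2%}")
```

The result rounds to 97.48%, which agrees with the reported F1-Score.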
---
## **🔧 Fine-Tuning Details**
### **Dataset**
The model was fine-tuned on the **Cataract Dataset**, which has two labels: normal and cataract.
### **Training Configuration**
- **Number of epochs**: 10
- **Batch size**: 32
- **Optimizer**: Adam
- **Learning rate**: 1e-4
- **Loss Function**: Cross-Entropy
- **Evaluation Strategy**: Validation at each epoch
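
The configuration above can be sketched as a standard PyTorch fine-tuning loop. This is a minimal sketch under stated assumptions, not the authors' training script: the function name `fine_tune`, the data loaders, and the use of accuracy as the validation metric are hypothetical. Only the optimizer, learning rate, loss function, epoch count, and per-epoch validation come from the list above; the batch size of 32 would be set when constructing each `DataLoader`.

```python
import torch
from torch import nn


def fine_tune(model, train_loader, val_loader, epochs=10, lr=1e-4, device="cpu"):
    """Train with Adam and cross-entropy, validating after every epoch.

    Returns a list with one validation accuracy per epoch.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []

    for epoch in range(epochs):
        model.train()
        for images, targets in train_loader:
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()

        # Validation at each epoch, as listed in the configuration.
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, targets in val_loader:
                images, targets = images.to(device), targets.to(device)
                predictions = model(images).argmax(dim=1)
                correct += (predictions == targets).sum().item()
                total += targets.size(0)
        history.append(correct / max(total, 1))
        print(f"epoch {epoch + 1}/{epochs}: val accuracy {history[-1]:.4f}")

    return history
```

A call might look like `fine_tune(model, train_loader, val_loader, device=device)`, where `train_loader` and `val_loader` are placeholder `DataLoader` objects built with `batch_size=32`.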
---
## **⚠️ Limitations**
- **Misclassification risk**: The model may produce **false positives or false negatives**. Always verify results with an ophthalmologist.
- **Dataset bias**: Performance may be affected by **dataset distribution**. It may not generalize well to **different populations**.
- **Black-box nature**: Like most deep learning models, it does not explain why a prediction was made.
---