Image Quality Regression Model
This model is trained on the dataset yigagilbert/image_quality_dataset and performs regression to predict image quality scores.
Model Details
- Dataset: yigagilbert/image_quality_dataset
- Target Column: quality_score
- Test Split: 20% test data
- Training Epochs: 3
- Learning Rate: 5e-5
- Max Value in Dataset: 54.02
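The model details list a maximum dataset value of 54.02 but do not say how that value is used. One common approach, assumed here and not confirmed by this card, is to divide the targets by the dataset maximum during training and multiply predictions back at inference. A minimal sketch; the helper names `normalize_score` and `denormalize_score` are hypothetical and are not part of ImageRegression:

```python
# Assumption: quality_score targets are scaled into [0, 1] by the dataset maximum.
MAX_QUALITY_SCORE = 54.02  # "Max Value in Dataset" from the model details


def normalize_score(score: float, max_value: float = MAX_QUALITY_SCORE) -> float:
    """Map a raw quality_score to the [0, 1] range used as a training target."""
    if max_value <= 0:
        raise ValueError("max_value must be positive")
    return score / max_value


def denormalize_score(prediction: float, max_value: float = MAX_QUALITY_SCORE) -> float:
    """Map a model output back to the original quality_score scale.

    The output is clipped to [0, 1] first, because a regression head can
    produce values slightly outside the range it was trained on.
    """
    clipped = min(max(prediction, 0.0), 1.0)
    return clipped * max_value


if __name__ == "__main__":
    print(normalize_score(27.01))
    print(denormalize_score(0.5))
```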
This model fine-tunes the google/vit-base-patch16-224 Vision Transformer using PyTorch and Hugging Face's 🤗 Transformers library. It predicts a numerical quality score for an input image.
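As a rough worked example of what the backbone sees, assuming the standard ViT-Base/16 configuration and a single-output linear head on the [CLS] token (the setup 🤗 Transformers' `ViTForImageClassification` uses with `num_labels=1`; this card does not confirm the head design), a 224×224 input is split into 16×16 patches:

```latex
N = \left(\frac{224}{16}\right)^2 = 14^2 = 196 \text{ patches}, \qquad
\text{sequence length} = N + 1 = 197 \text{ tokens (including [CLS])}

\hat{y} = \mathbf{w}^\top \mathbf{h}_{\text{CLS}} + b, \qquad
\mathbf{w} \in \mathbb{R}^{768},\; b \in \mathbb{R}
```

Here $\mathbf{h}_{\text{CLS}}$ is the 768-dimensional final hidden state of the [CLS] token and $\hat{y}$ is the predicted quality score.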
Image Regression Model
This repository contains a model for image regression tasks, where the goal is to predict a numerical value from an input image. The model fine-tunes the google/vit-base-patch16-224 Vision Transformer using PyTorch and 🤗 Hugging Face tools. You can train the model, upload it to the 🤗 Model Hub, and run inference through a simple API.
Installation
Install the required packages by running:
pip install -r requirements.txt
Usage
Import Functions
from ImageRegression import train_model, upload_model, predict
Train Model
Train the model using the train_model() function. Below are the key parameters:
- dataset_id: Hugging Face dataset identifier or path to your local dataset.
- value_column_name: Column in the dataset containing the target regression values.
- test_split: Proportion of data to use for testing (e.g., 0.2 for 20% test data).
- output_dir: Directory where model checkpoints will be saved.
- num_train_epochs: Number of training epochs.
- learning_rate: Learning rate for the optimizer.
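The test_split parameter is a proportion of the dataset. The card does not say how ImageRegression performs the split, so the following is only an illustration of a seeded, shuffled split in which 0.2 holds out 20% of the rows; the function name `split_train_test` is hypothetical. 🤗 Datasets offers a comparable built-in, `Dataset.train_test_split(test_size=0.2, seed=...)`.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_test(
    rows: Sequence[T], test_split: float = 0.2, seed: int = 42
) -> Tuple[List[T], List[T]]:
    """Shuffle rows with a fixed seed and hold out a test_split fraction for testing."""
    if not 0.0 < test_split < 1.0:
        raise ValueError("test_split must be strictly between 0 and 1")
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # same seed -> same split on every run
    n_test = round(len(rows) * test_split)
    test = [rows[i] for i in indices[:n_test]]
    train = [rows[i] for i in indices[n_test:]]
    return train, test


if __name__ == "__main__":
    train, test = split_train_test(list(range(100)), test_split=0.2)
    print(len(train), len(test))
```

Fixing the seed keeps the held-out images identical across runs, so evaluation numbers from different training runs stay comparable.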
train_model(dataset_id='yigagilbert/image_quality_dataset',
value_column_name='quality_score',
test_split=0.2,
output_dir='./model_output',
num_train_epochs=10,
learning_rate=1e-4)
Training progress will be logged, and checkpoints will be saved in output_dir. These checkpoints can be used for model inference and uploaded to the 🤗 Hub.
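The card does not list which metrics appear in the training logs. When a 🤗 Transformers image-classification model is configured with `problem_type="regression"` and a single output, its loss is mean squared error; assuming a similar setup here, the sketch below computes MSE together with RMSE and MAE, which are easier to read in the units of the target. The function name `regression_metrics` is hypothetical.

```python
import math
from typing import Dict, Sequence


def regression_metrics(predictions: Sequence[float], targets: Sequence[float]) -> Dict[str, float]:
    """Compute MSE, RMSE and MAE between predicted and true quality scores."""
    if len(predictions) != len(targets) or not predictions:
        raise ValueError("predictions and targets must be non-empty and the same length")
    errors = [p - t for p, t in zip(predictions, targets)]
    mse = sum(e * e for e in errors) / len(errors)   # mean squared error
    mae = sum(abs(e) for e in errors) / len(errors)  # mean absolute error
    return {"mse": mse, "rmse": math.sqrt(mse), "mae": mae}


if __name__ == "__main__":
    print(regression_metrics([30.0, 42.0, 10.0], [32.0, 40.0, 10.0]))
```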
Upload Model to Hugging Face Hub
To upload your trained model to the 🤗 Hub, use the upload_model() function:
- model_id: The name of the model repository on the 🤗 Hub.
- token: Hugging Face authentication token with write access (create one under Access Tokens in your Hugging Face account settings).
- checkpoint_dir: Directory where the trained model checkpoints are located.
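The checkpoint_dir value in the upload example points at a folder named checkpoint-940, which looks like the checkpoint-<step> naming the 🤗 Transformers Trainer uses for intermediate saves. Rather than hard-coding the step number, a small helper can pick the most recent checkpoint. This is a standard-library sketch under that naming assumption; `latest_checkpoint` is a hypothetical name, and Transformers ships a comparable `transformers.trainer_utils.get_last_checkpoint`.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed Trainer-style folder names, e.g. "checkpoint-940".
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint-<step> folder with the highest step, or None if there is none."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    for child in root.iterdir():
        match = _CHECKPOINT_RE.match(child.name)
        if child.is_dir() and match:
            step = int(match.group(1))
            if step > best_step:  # numeric comparison, so 940 beats 94
                best_step, best_path = step, child
    return str(best_path) if best_path is not None else None


if __name__ == "__main__":
    print(latest_checkpoint("./model_output"))
```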
upload_model(model_id='yigagilbert/image-qaulity-model',
token='your_HF_token',
checkpoint_dir='./model_output/checkpoint-940')
Once uploaded, the model can be used for inference directly from the Hub.
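The upload example passes the token as a string literal, which is easy to commit by accident. A safer pattern, sketched here, reads it from the `HF_TOKEN` environment variable (which `huggingface_hub` also checks by default) and does a rough check that `model_id` has the `namespace/name` shape before calling upload_model(). The helper `upload_settings` is hypothetical and is not part of ImageRegression.

```python
import os
import re
from typing import Optional, Tuple

# Rough check for Hub repo ids of the form "namespace/name".
_REPO_ID_RE = re.compile(r"^[\w.-]+/[\w.-]+$")


def upload_settings(model_id: str, token: Optional[str] = None) -> Tuple[str, str]:
    """Validate model_id and resolve the token, falling back to the HF_TOKEN env var."""
    if not _REPO_ID_RE.match(model_id):
        raise ValueError(f"model_id should look like 'namespace/name', got {model_id!r}")
    resolved = token or os.environ.get("HF_TOKEN")
    if not resolved:
        raise RuntimeError("No token given; pass one or set the HF_TOKEN environment variable")
    return model_id, resolved


# Possible usage (assumes upload_model is imported from ImageRegression):
# model_id, token = upload_settings("yigagilbert/image-qaulity-model")
# upload_model(model_id=model_id, token=token,
#              checkpoint_dir="./model_output/checkpoint-940")
```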
Model Inference (Prediction)
You can perform inference using the predict() function:
- repo_id: The repository identifier of the uploaded model.
- image_path: Path to the image file you want to run predictions on.
predict(repo_id='yigagilbert/image-qaulity-model',
image_path='path_to_image.jpg')
The first time you run inference, the model will be downloaded from the Hugging Face Hub. Subsequent inferences will run faster as the model is cached locally.
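Assuming predict() downloads through the standard `huggingface_hub` cache, the local copy lives under the Hub cache directory: `HF_HUB_CACHE` if set, otherwise `$HF_HOME/hub`, otherwise `~/.cache/huggingface/hub`, with each model in a `models--<namespace>--<name>` folder. The sketch below approximates that lookup; it ignores `XDG_CACHE_HOME` and legacy variables, so treat `huggingface_hub.constants.HF_HUB_CACHE` as the authoritative value. Both function names are hypothetical.

```python
import os
from pathlib import Path


def default_hub_cache() -> Path:
    """Approximate huggingface_hub's default cache directory (simplified)."""
    if os.environ.get("HF_HUB_CACHE"):
        return Path(os.environ["HF_HUB_CACHE"]).expanduser()
    hf_home = os.environ.get("HF_HOME", os.path.join("~", ".cache", "huggingface"))
    return Path(hf_home).expanduser() / "hub"


def cached_model_dir(repo_id: str) -> Path:
    """Cache folder for a model repo, e.g. models--yigagilbert--image-qaulity-model."""
    return default_hub_cache() / ("models--" + repo_id.replace("/", "--"))


if __name__ == "__main__":
    path = cached_model_dir("yigagilbert/image-qaulity-model")
    print(path, "exists" if path.exists() else "not downloaded yet")
```

Deleting that folder forces a fresh download on the next predict() call.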