# Install

In [2]:
%pip install uv

Note: you may need to restart the kernel to use updated packages.


In [None]:
!uv pip install dagshub setuptools accelerate toml torch torchvision transformers mlflow datasets ipywidgets python-dotenv evaluate

# Setup

In [1]:
import os
import toml
import torch
import mlflow
import dagshub
import datasets
import evaluate
from dotenv import load_dotenv
from torchvision.transforms import v2
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer

ENV_PATH = "/Users/andrewmayes/Openclassroom/CanineNet/.env"
CONFIG_PATH = "/Users/andrewmayes/Openclassroom/CanineNet/code/config.toml"
CONFIG = toml.load(CONFIG_PATH)

load_dotenv(ENV_PATH)

dagshub.init(repo_name=os.environ['MLFLOW_TRACKING_PROJECTNAME'], repo_owner=os.environ['MLFLOW_TRACKING_USERNAME'], mlflow=True, dvc=True)

os.environ['MLFLOW_TRACKING_USERNAME'] = "amaye15"

mlflow.set_tracking_uri(f'https://dagshub.com/' + os.environ['MLFLOW_TRACKING_USERNAME']
                         + '/' + os.environ['MLFLOW_TRACKING_PROJECTNAME'] + '.mlflow')

CREATE_DATASET = True
ORIGINAL_DATASET = "Alanox/stanford-dogs"
MODIFIED_DATASET = "amaye15/stanford-dogs"
REMOVE_COLUMNS = ["name", "annotations"]
RENAME_COLUMNS = {"image":"pixel_values", "target":"label"}
SPLIT = 0.2

METRICS = ["accuracy", "f1", "precision", "recall"]
# MODELS = 'google/vit-base-patch16-224'
# MODELS = "google/siglip-base-patch16-224"



# Dataset

In [2]:
if CREATE_DATASET:
    ds = datasets.load_dataset(ORIGINAL_DATASET, token=os.getenv("HF_TOKEN"), split="full", trust_remote_code=True)
    ds = ds.remove_columns(REMOVE_COLUMNS).rename_columns(RENAME_COLUMNS)

    labels = ds.select_columns("label").to_pandas().sort_values("label").get("label").unique().tolist()
    numbers = range(len(labels))
    label2int = dict(zip(labels, numbers))
    int2label = dict(zip(numbers, labels))

    for key, val in label2int.items():
        print(f"{key}: {val}")

    ds = ds.class_encode_column("label")
    ds = ds.align_labels_with_mapping(label2int, "label")

    ds = ds.train_test_split(test_size=SPLIT, stratify_by_column = "label")
    #ds.push_to_hub(MODIFIED_DATASET, token=os.getenv("HF_TOKEN"))

    CONFIG["label2int"] = str(label2int)
    CONFIG["int2label"] = str(int2label)

    # with open("output.toml", "w") as toml_file:
    #     toml.dump(toml.dumps(CONFIG), toml_file)

    #ds = datasets.load_dataset(MODIFIED_DATASET, token=os.getenv("HF_TOKEN"), trust_remote_code=True, streaming=True)

Affenpinscher: 0
Afghan Hound: 1
African Hunting Dog: 2
Airedale: 3
American Staffordshire Terrier: 4
Appenzeller: 5
Australian Terrier: 6
Basenji: 7
Basset: 8
Beagle: 9
Bedlington Terrier: 10
Bernese Mountain Dog: 11
Black And Tan Coonhound: 12
Blenheim Spaniel: 13
Bloodhound: 14
Bluetick: 15
Border Collie: 16
Border Terrier: 17
Borzoi: 18
Boston Bull: 19
Bouvier Des Flandres: 20
Boxer: 21
Brabancon Griffon: 22
Briard: 23
Brittany Spaniel: 24
Bull Mastiff: 25
Cairn: 26
Cardigan: 27
Chesapeake Bay Retriever: 28
Chihuahua: 29
Chow: 30
Clumber: 31
Cocker Spaniel: 32
Collie: 33
Curly Coated Retriever: 34
Dandie Dinmont: 35
Dhole: 36
Dingo: 37
Doberman: 38
English Foxhound: 39
English Setter: 40
English Springer: 41
Entlebucher: 42
Eskimo Dog: 43
Flat Coated Retriever: 44
French Bulldog: 45
German Shepherd: 46
German Short Haired Pointer: 47
Giant Schnauzer: 48
Golden Retriever: 49
Gordon Setter: 50
Great Dane: 51
Great Pyrenees: 52
Greater Swiss Mountain Dog: 53
Groenendael: 54
Ibizan Hou

In [3]:
metrics = {metric: evaluate.load(metric) for metric in METRICS}


# for lr in [5e-3, 5e-4, 5e-5]: # 5e-5
#     for batch in [64]: # 32
#         for model_name in ["google/vit-base-patch16-224", "microsoft/swinv2-base-patch4-window16-256", "google/siglip-base-patch16-224"]: # "facebook/dinov2-base"

lr = 5e-5
batch = 64
model_name = "google/siglip-base-patch16-224"

image_processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(
model_name,
num_labels=len(label2int),
id2label=int2label,
label2id=label2int,
ignore_mismatched_sizes=True,
)

# Then, in your transformations:
def train_transform(examples, num_ops=10, magnitude=9, num_magnitude_bins=31):

    transformation = v2.Compose(
        [
            v2.RandAugment(
                num_ops=num_ops,
                magnitude=magnitude,
                num_magnitude_bins=num_magnitude_bins,
            )
        ]
    )
    # Ensure each image has three dimensions (in this case, ensure it's RGB)
    examples["pixel_values"] = [
        image.convert("RGB") for image in examples["pixel_values"]
    ]
    # Apply transformations
    examples["pixel_values"] = [
        image_processor(transformation(image), return_tensors="pt")[
            "pixel_values"
        ].squeeze()
        for image in examples["pixel_values"]
    ]
    return examples


def test_transform(examples):
    # Ensure each image is RGB
    examples["pixel_values"] = [
        image.convert("RGB") for image in examples["pixel_values"]
    ]
    # Apply processing
    examples["pixel_values"] = [
        image_processor(image, return_tensors="pt")["pixel_values"].squeeze()
        for image in examples["pixel_values"]
    ]
    return examples


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # predictions = np.argmax(logits, axis=-1)
    results = {}
    for key, val in metrics.items():
        if "accuracy" == key:
            result = next(
                iter(val.compute(predictions=predictions, references=labels).items())
            )
        if "accuracy" != key:
            result = next(
                iter(
                    val.compute(
                        predictions=predictions, references=labels, average="macro"
                    ).items()
                )
            )
        results[result[0]] = result[1]
    return results


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids

ds["train"].set_transform(train_transform)
ds["test"].set_transform(test_transform)

training_args = TrainingArguments(**CONFIG["training_args"])
training_args.per_device_train_batch_size = batch
training_args.per_device_eval_batch_size = batch
training_args.hub_model_id = f"amaye15/{model_name.replace('/','-')}-batch{batch}-lr{lr}-standford-dogs"

mlflow.start_run(run_name=f"{model_name.replace('/','-')}-batch{batch}-lr{lr}")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    tokenizer=image_processor,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # callbacks=[early_stopping_callback],
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

# Train the model
trainer.train()

trainer.push_to_hub()

mlflow.end_run()

Some weights of SiglipForImageClassification were not initialized from the model checkpoint at google/siglip-base-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
max_steps is given, it will override any value given in num_train_epochs


  0%|          | 0/1000 [00:00<?, ?it/s]



{'loss': 4.822, 'grad_norm': 11.180054664611816, 'learning_rate': 4.9500000000000004e-05, 'epoch': 0.16}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.254875183105469, 'eval_accuracy': 0.0782312925170068, 'eval_f1': 0.04927852996179247, 'eval_precision': 0.09874043278607707, 'eval_recall': 0.07264375052644872, 'eval_runtime': 55.3923, 'eval_samples_per_second': 74.306, 'eval_steps_per_second': 1.173, 'epoch': 0.16}




{'loss': 4.236, 'grad_norm': 17.628389358520508, 'learning_rate': 4.9e-05, 'epoch': 0.31}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 3.5278525352478027, 'eval_accuracy': 0.19071914480077745, 'eval_f1': 0.15072037668109098, 'eval_precision': 0.22011886017337598, 'eval_recall': 0.18300531751649823, 'eval_runtime': 55.5125, 'eval_samples_per_second': 74.145, 'eval_steps_per_second': 1.171, 'epoch': 0.31}




{'loss': 3.5066, 'grad_norm': 19.224912643432617, 'learning_rate': 4.85e-05, 'epoch': 0.47}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 2.531590223312378, 'eval_accuracy': 0.33187560738581146, 'eval_f1': 0.2941424042839124, 'eval_precision': 0.4180298856360352, 'eval_recall': 0.320509389455932, 'eval_runtime': 56.5067, 'eval_samples_per_second': 72.841, 'eval_steps_per_second': 1.15, 'epoch': 0.47}




{'loss': 2.8064, 'grad_norm': 22.580602645874023, 'learning_rate': 4.8e-05, 'epoch': 0.62}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 2.1243278980255127, 'eval_accuracy': 0.4361030126336249, 'eval_f1': 0.409040351489377, 'eval_precision': 0.5324247354698377, 'eval_recall': 0.4282087854976091, 'eval_runtime': 56.8186, 'eval_samples_per_second': 72.441, 'eval_steps_per_second': 1.144, 'epoch': 0.62}




{'loss': 2.441, 'grad_norm': 17.738447189331055, 'learning_rate': 4.75e-05, 'epoch': 0.78}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.5798275470733643, 'eval_accuracy': 0.5510204081632653, 'eval_f1': 0.5250154943185481, 'eval_precision': 0.6242284529813324, 'eval_recall': 0.5437767896591994, 'eval_runtime': 57.0171, 'eval_samples_per_second': 72.189, 'eval_steps_per_second': 1.14, 'epoch': 0.78}




{'loss': 2.0985, 'grad_norm': 18.94181251525879, 'learning_rate': 4.7e-05, 'epoch': 0.93}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.42424476146698, 'eval_accuracy': 0.5843051506316812, 'eval_f1': 0.557705493250987, 'eval_precision': 0.6400162236362443, 'eval_recall': 0.5768419977409593, 'eval_runtime': 57.2333, 'eval_samples_per_second': 71.916, 'eval_steps_per_second': 1.136, 'epoch': 0.93}




{'loss': 1.8689, 'grad_norm': 15.593049049377441, 'learning_rate': 4.6500000000000005e-05, 'epoch': 1.09}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.1481006145477295, 'eval_accuracy': 0.6625364431486881, 'eval_f1': 0.6455514206859728, 'eval_precision': 0.7142859368225736, 'eval_recall': 0.6564757487305617, 'eval_runtime': 54.7373, 'eval_samples_per_second': 75.196, 'eval_steps_per_second': 1.187, 'epoch': 1.09}




{'loss': 1.6588, 'grad_norm': 18.39203453063965, 'learning_rate': 4.600000000000001e-05, 'epoch': 1.24}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.1937264204025269, 'eval_accuracy': 0.6465014577259475, 'eval_f1': 0.6361000380324133, 'eval_precision': 0.7061715448588218, 'eval_recall': 0.6438641849166267, 'eval_runtime': 55.5191, 'eval_samples_per_second': 74.137, 'eval_steps_per_second': 1.171, 'epoch': 1.24}




{'loss': 1.5807, 'grad_norm': 15.319233894348145, 'learning_rate': 4.55e-05, 'epoch': 1.4}


  0%|          | 0/65 [00:00<?, ?it/s]

{'eval_loss': 0.9817520976066589, 'eval_accuracy': 0.70578231292517, 'eval_f1': 0.6890227220667341, 'eval_precision': 0.7438497507413404, 'eval_recall': 0.6980582780473442, 'eval_runtime': 54.4988, 'eval_samples_per_second': 75.525, 'eval_steps_per_second': 1.193, 'epoch': 1.4}




{'loss': 1.4851, 'grad_norm': 15.890103340148926, 'learning_rate': 4.5e-05, 'epoch': 1.55}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.0180633068084717, 'eval_accuracy': 0.6999514091350826, 'eval_f1': 0.6838587523312375, 'eval_precision': 0.7373019639077568, 'eval_recall': 0.6959074888023662, 'eval_runtime': 55.2973, 'eval_samples_per_second': 74.434, 'eval_steps_per_second': 1.175, 'epoch': 1.55}




{'loss': 1.5033, 'grad_norm': 17.170801162719727, 'learning_rate': 4.4500000000000004e-05, 'epoch': 1.71}


  0%|          | 0/65 [00:00<?, ?it/s]

{'eval_loss': 1.0169070959091187, 'eval_accuracy': 0.6914480077745384, 'eval_f1': 0.6845415929736886, 'eval_precision': 0.7489788726612852, 'eval_recall': 0.6883375806361393, 'eval_runtime': 54.525, 'eval_samples_per_second': 75.488, 'eval_steps_per_second': 1.192, 'epoch': 1.71}




{'loss': 1.3022, 'grad_norm': 15.557647705078125, 'learning_rate': 4.4000000000000006e-05, 'epoch': 1.86}


  0%|          | 0/65 [00:00<?, ?it/s]

{'eval_loss': 0.9087187051773071, 'eval_accuracy': 0.7276482021379981, 'eval_f1': 0.7169827898093813, 'eval_precision': 0.7642639115410531, 'eval_recall': 0.722171618202087, 'eval_runtime': 54.7556, 'eval_samples_per_second': 75.17, 'eval_steps_per_second': 1.187, 'epoch': 1.86}




{'loss': 1.3106, 'grad_norm': 15.203620910644531, 'learning_rate': 4.35e-05, 'epoch': 2.02}


  0%|          | 0/65 [00:00<?, ?it/s]

{'eval_loss': 0.8385488986968994, 'eval_accuracy': 0.7431972789115646, 'eval_f1': 0.7352483871752059, 'eval_precision': 0.7666810806987456, 'eval_recall': 0.7363282855094594, 'eval_runtime': 57.5486, 'eval_samples_per_second': 71.522, 'eval_steps_per_second': 1.129, 'epoch': 2.02}




{'loss': 1.1721, 'grad_norm': 18.051284790039062, 'learning_rate': 4.3e-05, 'epoch': 2.17}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 0.8956524133682251, 'eval_accuracy': 0.7128279883381924, 'eval_f1': 0.7025737877793609, 'eval_precision': 0.7591947211938203, 'eval_recall': 0.7074780847115492, 'eval_runtime': 58.6317, 'eval_samples_per_second': 70.201, 'eval_steps_per_second': 1.109, 'epoch': 2.17}




{'loss': 1.131, 'grad_norm': 16.522109985351562, 'learning_rate': 4.25e-05, 'epoch': 2.33}


  0%|          | 0/65 [00:00<?, ?it/s]

{'eval_loss': 0.8729854226112366, 'eval_accuracy': 0.7259475218658892, 'eval_f1': 0.7148538617252097, 'eval_precision': 0.7687155689784482, 'eval_recall': 0.719605645045331, 'eval_runtime': 54.6147, 'eval_samples_per_second': 75.364, 'eval_steps_per_second': 1.19, 'epoch': 2.33}




{'loss': 1.1223, 'grad_norm': 16.727994918823242, 'learning_rate': 4.2e-05, 'epoch': 2.48}


  0%|          | 0/65 [00:00<?, ?it/s]

{'eval_loss': 0.8132386803627014, 'eval_accuracy': 0.7546161321671526, 'eval_f1': 0.7457409451548778, 'eval_precision': 0.7855075954723149, 'eval_recall': 0.7482023394172116, 'eval_runtime': 54.8217, 'eval_samples_per_second': 75.08, 'eval_steps_per_second': 1.186, 'epoch': 2.48}




{'loss': 1.0688, 'grad_norm': 14.611897468566895, 'learning_rate': 4.15e-05, 'epoch': 2.64}


  0%|          | 0/65 [00:00<?, ?it/s]

{'eval_loss': 0.7485197186470032, 'eval_accuracy': 0.7704081632653061, 'eval_f1': 0.7600821180493263, 'eval_precision': 0.7863249503261968, 'eval_recall': 0.7631296317667979, 'eval_runtime': 56.4165, 'eval_samples_per_second': 72.957, 'eval_steps_per_second': 1.152, 'epoch': 2.64}




{'loss': 1.0686, 'grad_norm': 17.756242752075195, 'learning_rate': 4.1e-05, 'epoch': 2.79}


  0%|          | 0/65 [00:00<?, ?it/s]

{'eval_loss': 0.7559003233909607, 'eval_accuracy': 0.7650631681243926, 'eval_f1': 0.7586751052497263, 'eval_precision': 0.7920018070685718, 'eval_recall': 0.7609226412898984, 'eval_runtime': 53.1768, 'eval_samples_per_second': 77.402, 'eval_steps_per_second': 1.222, 'epoch': 2.79}




{'loss': 0.9733, 'grad_norm': 14.432697296142578, 'learning_rate': 4.05e-05, 'epoch': 2.95}


  0%|          | 0/65 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 0.7778576612472534, 'eval_accuracy': 0.7553449951409135, 'eval_f1': 0.7458481776644565, 'eval_precision': 0.7797152587168623, 'eval_recall': 0.7521461869682271, 'eval_runtime': 54.2043, 'eval_samples_per_second': 75.935, 'eval_steps_per_second': 1.199, 'epoch': 2.95}




In [None]:
mlflow.end_run()

NameError: name 'mlflow' is not defined

In [None]:
# training_args = TrainingArguments(**CONFIG["training_args"])

# image_processor = AutoImageProcessor.from_pretrained(MODELS)
# model = AutoModelForImageClassification.from_pretrained(
# MODELS,
# num_labels=len(CONFIG["label2int"]),
# id2label=CONFIG["label2int"],
# label2id=CONFIG["int2label"],
# ignore_mismatched_sizes=True,
# )


# training_args = TrainingArguments(**CONFIG["training_args"])

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=ds["train"],
#     eval_dataset=ds["test"],
#     tokenizer=image_processor,
#     data_collator=collate_fn,
#     compute_metrics=compute_metrics,
#     # callbacks=[early_stopping_callback],
#     preprocess_logits_for_metrics=preprocess_logits_for_metrics,
# )

# # Train the model
# trainer.train()

# mlflow.end_run()