Hugging Face

FlowerNet

Нейросеть для многоклассовой классификации цветов.

одуванчик

Введение

Цель данной работы заключается в разработке нейронной сети для многоклассовой классификации, обладающей высокой устойчивостью к переобучению.

Набор данных (Dataset)

Для решения задачи многоклассовой классификации цветов, я использовал набор данных tf_flowers из tensorflow. Набор имеет 5 классов цветов: 'Одуванчик', 'Ромашка', 'Тюльпаны', 'Подсолнухи' и 'Розы'. Поэтому на конечном слое Dense 5 нейронов. Теперь про выборки. Я разбил набор данных на три выборки: от 0 до 80% - тренировочная, от 80% до 90% - проверочная(валидационная) и от 90% до 100% - тестовая.

Архитектура сети

К качестве архитектуры я использовал xception. Схема архитектуры получилась большая, поэтому я решил не вставлять ей сюда, а загрузить в файлы проекта. Нейронная сеть предназначена для работы на тензорных процессорах (TPU), это позволяет повысить количество эпох и мощность.

Оптимизатор и функция потерь

image Моей целью было создать крепкую нейронную сеть, которая обладала бы высокой устойчивостью к переобучению. И тут начинается настройка. Если использовать оптимизатор Adam, который я использовал ранее, то точность будет 90%, но при этом будет переобучение. Поэтому я решил зайти с другого бока, и использовать оптимизатор Adagrad(Adaptive Gradient) - его точность на 10 эпохе была 40%, но чем больше эпох, тем лучше его точность, и при этом точность проверочной выборки будет всегда выше чем тренировочной, и переобучения не будет. В качестве функции потерь я использую SparseCategoricalCrossentropy, так как именно её нужно использовать на TPU моделях. Так как модель моя модель использует тензорный процессор и быстро проходит эпохи, я решил увеличить количество эпох до тысячи. Adagrad начал с 40%, постепенно его точность увеличивалась, и в конечном итоге я получил точность 89.65% на проверочных данных и 0.87% на тестовых. При этом на графике можно увидеть, что модель не подвергается переобучению.

Результат

image

Задача выполнена. Я создал модель которая имеет устойчивую защиту от переобучения и хорошую точность 87%. В файлах проекта модель называется FlowerNet.h5

Страница на github: https://github.com/laf3r/FlowerNet

Программа предоставляется в виде открытого исходного кода.

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.