{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "87345732-d868-473b-b1a1-5c25839ce25b",
"metadata": {},
"outputs": [],
"source": [
"from fastai.vision.all import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "79b9fbad-7b99-40fd-8768-b0a091bf85cb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/conda/envs/py310-cuda116/lib/python3.10/site-packages/paramiko/transport.py:236: CryptographyDeprecationWarning: Blowfish has been deprecated\n",
" \"class\": algorithms.Blowfish,\n"
]
}
],
"source": [
"import gradio"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5409c6a7-5cae-42bb-8335-587a04471f22",
"metadata": {},
"outputs": [],
"source": [
"MODELS_PATH = Path(\"./models\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4e836799-6858-438a-8d70-d95f98cf54f7",
"metadata": {},
"outputs": [],
"source": [
"EXAMPLES_PATH = Path('./examples')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9ed20c60-9f23-4795-bb4b-79b00af0f6d1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#2) [Path('models/food-101-resnet34.pkl'),Path('models/food-101-resnet50.pkl')]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"MODELS_PATH.ls()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0969ba8e-b0df-4550-a900-5d5a30fb0187",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#9) [Path('examples/pad_thai.jpeg'),Path('examples/takoyaki.jpeg'),Path('examples/momo.jpeg'),Path('examples/falafel.jpeg'),Path('examples/paella.jpeg'),Path('examples/ravioli.jpeg'),Path('examples/huevos_rancheros.jpeg'),Path('examples/edamame.jpeg'),Path('examples/sushi.jpeg')]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"EXAMPLES_PATH.ls()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e9143742-c6bc-44f6-8ecd-3826502c84ac",
"metadata": {},
"outputs": [],
"source": [
"def label_func(filepath):\n",
"    \"\"\"Return the class label of an image: the name of its parent directory.\n",
"\n",
"    NOTE(review): never called in this notebook; presumably kept so that\n",
"    `load_learner` can unpickle the exported Learner, which may reference\n",
"    `label_func` by name -- do not rename or remove without re-exporting.\n",
"    \"\"\"\n",
"    parent_dir = filepath.parent\n",
"    return parent_dir.name"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c6ad64e8-f163-4472-b2f0-c0aa50ead4d8",
"metadata": {},
"outputs": [],
"source": [
"learn = load_learner(MODELS_PATH/'food-101-resnet50.pkl')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d1370d20-fd51-4512-bd28-5f170d216c7b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels = learn.dls.vocab\n",
"labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f666b42-9fdd-45ca-81ca-7e98dd191369",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a360dd6b-75a9-43e5-b91d-c6963ea462ea",
"metadata": {},
"outputs": [],
"source": [
"def predict(img):\n",
"    \"\"\"Classify a food image and return a probability for every class.\n",
"\n",
"    Args:\n",
"        img: value from the Gradio image input (presumably a numpy array);\n",
"            anything `PILImage.create` accepts.\n",
"\n",
"    Returns:\n",
"        dict mapping each label in `labels` to its probability as a Python\n",
"        float -- the format the `Label` output component displays.\n",
"\n",
"    Raises:\n",
"        ValueError: if the model yields a different number of probabilities\n",
"            than there are labels.\n",
"    \"\"\"\n",
"    _decoded, _decoded_idx, probs = learn.predict(PILImage.create(img))\n",
"    # Pair labels with probabilities positionally. strict=True surfaces a\n",
"    # count mismatch instead of silently truncating; float() yields plain\n",
"    # Python floats that Gradio can serialize.\n",
"    return dict(zip(labels, map(float, probs), strict=True))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "febc7266-8587-4530-811b-f2fa9117dcd5",
"metadata": {},
"outputs": [],
"source": [
"# Markdown passed to the Gradio Interface as its `article`. Read as UTF-8\n",
"# explicitly: open() otherwise uses the platform's locale encoding.\n",
"with open('gradio_article.md', encoding='utf-8') as article_file:\n",
"    article = article_file.read()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8fd4ffb4-11ca-4b25-999c-cde2a4e236b4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/conda/envs/py310-cuda116/lib/python3.10/site-packages/gradio/interface.py:419: UserWarning: The `enable_queue` parameter in the `Interface`will be deprecated and may not work properly. Please use the `enable_queue` parameter in `launch()` instead\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on local URL: http://localhost:9999/\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(,\n",
" 'http://localhost:9999/',\n",
" None)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# `enable_queue` is passed only to launch(): also passing it to Interface()\n",
"# is deprecated in this Gradio version and triggers a UserWarning.\n",
"interface_options = {\n",
"    \"title\": \"Food-101 Classifier\",\n",
"    \"description\": \"A food image classifier trained on the Food-101 (https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/) dataset with fastai with a ResNet50 CNN model.\",\n",
"    \"article\": article,\n",
"    # Sorted for a stable gallery order; unlike iterdir(), get_image_files keeps\n",
"    # only image files and skips directories and hidden entries.\n",
"    \"examples\": sorted(str(f) for f in get_image_files(EXAMPLES_PATH, recurse=False)),\n",
"    \"interpretation\": \"default\",\n",
"    \"layout\": \"horizontal\",\n",
"    \"allow_flagging\": \"never\",\n",
"}\n",
"\n",
"demo = gradio.Interface(fn=predict,\n",
"                        inputs=gradio.inputs.Image(shape=(512, 512)),\n",
"                        outputs=gradio.outputs.Label(num_top_classes=5),\n",
"                        **interface_options)\n",
"\n",
"demo_options = {\n",
"    \"inline\": True,\n",
"    \"inbrowser\": False,\n",
"    \"share\": False,\n",
"    \"show_error\": True,\n",
"    \"server_name\": \"0.0.0.0\",  # bind all interfaces, not just localhost\n",
"    \"server_port\": 9999,\n",
"    \"enable_queue\": True,\n",
"}\n",
"\n",
"demo.launch(**demo_options)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "570f8a3c-367e-4a7f-808d-8fa2e925a444",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}