Oussamahajoui commited on
Commit
62a5519
Β·
1 Parent(s): f4ec2f1
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ flagged/image/tmp_kzg5jrp.jpg
3
+ flagged/image/tmpawmdga4z.jpg
4
+ flagged/log.csv
5
+ flagged/image/tmpv22yik0n.jpg
6
+ flagged/image/tmptuiort3g.jpg
7
+ flagged/image/tmpsyptwjk0.jpg
8
+ flagged/image/tmpo70zn1zc.jpg
9
+ flagged/image/tmpm7kl_i0r.jpg
10
+ flagged/image/tmpfjoni2co.jpg
11
+ flagged/image/tmpct6wib32.jpg
12
+ *.jpg
13
+ Checkpoint/baseline_V0.pth.tar
14
+ Data/sample_submission.csv
15
+ Data/solution.csv
16
+ Data/train.csv
17
+ Data/test.zip
18
+ Data/train.zip
Code/Datadownload.ipynb ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Token is valid.\n",
13
+ "Your token has been saved in your configured git credential helpers (manager-core).\n",
14
+ "Your token has been saved to C:\\Users\\Oussama\\.cache\\huggingface\\token\n",
15
+ "Login successful\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "from huggingface_hub import login\n",
21
+ "login()"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 3,
27
+ "metadata": {},
28
+ "outputs": [
29
+ {
30
+ "data": {
31
+ "application/vnd.jupyter.widget-view+json": {
32
+ "model_id": "e5eb52f9282f43cfa4f06b1d9c6dc08b",
33
+ "version_major": 2,
34
+ "version_minor": 0
35
+ },
36
+ "text/plain": [
37
+ "Downloading readme: 0%| | 0.00/1.24k [00:00<?, ?B/s]"
38
+ ]
39
+ },
40
+ "metadata": {},
41
+ "output_type": "display_data"
42
+ },
43
+ {
44
+ "name": "stdout",
45
+ "output_type": "stream",
46
+ "text": [
47
+ "Downloading and preparing dataset None/None to C:/Users/Oussama/.cache/huggingface/datasets/competitions___parquet/competitions--aiornot-759454878caed5d9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n"
48
+ ]
49
+ },
50
+ {
51
+ "data": {
52
+ "application/vnd.jupyter.widget-view+json": {
53
+ "model_id": "37ae63ab36cc448e90aa292609f4e5d4",
54
+ "version_major": 2,
55
+ "version_minor": 0
56
+ },
57
+ "text/plain": [
58
+ "Downloading data files: 0%| | 0/2 [00:00<?, ?it/s]"
59
+ ]
60
+ },
61
+ "metadata": {},
62
+ "output_type": "display_data"
63
+ },
64
+ {
65
+ "data": {
66
+ "application/vnd.jupyter.widget-view+json": {
67
+ "model_id": "e208b22a102847b4862d336b32cce3e1",
68
+ "version_major": 2,
69
+ "version_minor": 0
70
+ },
71
+ "text/plain": [
72
+ "Downloading data: 0%| | 0.00/354M [00:00<?, ?B/s]"
73
+ ]
74
+ },
75
+ "metadata": {},
76
+ "output_type": "display_data"
77
+ },
78
+ {
79
+ "data": {
80
+ "application/vnd.jupyter.widget-view+json": {
81
+ "model_id": "c31d5357bf9c44e99b19820ea10ad70a",
82
+ "version_major": 2,
83
+ "version_minor": 0
84
+ },
85
+ "text/plain": [
86
+ "Downloading data: 0%| | 0.00/356M [00:00<?, ?B/s]"
87
+ ]
88
+ },
89
+ "metadata": {},
90
+ "output_type": "display_data"
91
+ },
92
+ {
93
+ "data": {
94
+ "application/vnd.jupyter.widget-view+json": {
95
+ "model_id": "b7de3daec8a54b28a07cd580734f4bf4",
96
+ "version_major": 2,
97
+ "version_minor": 0
98
+ },
99
+ "text/plain": [
100
+ "Downloading data: 0%| | 0.00/415M [00:00<?, ?B/s]"
101
+ ]
102
+ },
103
+ "metadata": {},
104
+ "output_type": "display_data"
105
+ },
106
+ {
107
+ "data": {
108
+ "application/vnd.jupyter.widget-view+json": {
109
+ "model_id": "67b76cdca655426d9174aa7655345ecb",
110
+ "version_major": 2,
111
+ "version_minor": 0
112
+ },
113
+ "text/plain": [
114
+ "Downloading data: 0%| | 0.00/418M [00:00<?, ?B/s]"
115
+ ]
116
+ },
117
+ "metadata": {},
118
+ "output_type": "display_data"
119
+ },
120
+ {
121
+ "data": {
122
+ "application/vnd.jupyter.widget-view+json": {
123
+ "model_id": "373dd6818e5c4500b9d48de516bcb427",
124
+ "version_major": 2,
125
+ "version_minor": 0
126
+ },
127
+ "text/plain": [
128
+ "Downloading data: 0%| | 0.00/416M [00:00<?, ?B/s]"
129
+ ]
130
+ },
131
+ "metadata": {},
132
+ "output_type": "display_data"
133
+ },
134
+ {
135
+ "data": {
136
+ "application/vnd.jupyter.widget-view+json": {
137
+ "model_id": "9fc356325be34db1844de564ee0395fc",
138
+ "version_major": 2,
139
+ "version_minor": 0
140
+ },
141
+ "text/plain": [
142
+ "Downloading data: 0%| | 0.00/416M [00:00<?, ?B/s]"
143
+ ]
144
+ },
145
+ "metadata": {},
146
+ "output_type": "display_data"
147
+ },
148
+ {
149
+ "data": {
150
+ "application/vnd.jupyter.widget-view+json": {
151
+ "model_id": "8d8d0259ea1840d38f316bfc3809fa44",
152
+ "version_major": 2,
153
+ "version_minor": 0
154
+ },
155
+ "text/plain": [
156
+ "Extracting data files: 0%| | 0/2 [00:00<?, ?it/s]"
157
+ ]
158
+ },
159
+ "metadata": {},
160
+ "output_type": "display_data"
161
+ },
162
+ {
163
+ "data": {
164
+ "application/vnd.jupyter.widget-view+json": {
165
+ "model_id": "1780ace548794a58b318e723ae297b50",
166
+ "version_major": 2,
167
+ "version_minor": 0
168
+ },
169
+ "text/plain": [
170
+ "Generating train split: 0 examples [00:00, ? examples/s]"
171
+ ]
172
+ },
173
+ "metadata": {},
174
+ "output_type": "display_data"
175
+ },
176
+ {
177
+ "data": {
178
+ "application/vnd.jupyter.widget-view+json": {
179
+ "model_id": "fbd137787e1340bfaef5b62f6fbc7556",
180
+ "version_major": 2,
181
+ "version_minor": 0
182
+ },
183
+ "text/plain": [
184
+ "Generating test split: 0 examples [00:00, ? examples/s]"
185
+ ]
186
+ },
187
+ "metadata": {},
188
+ "output_type": "display_data"
189
+ },
190
+ {
191
+ "name": "stdout",
192
+ "output_type": "stream",
193
+ "text": [
194
+ "Dataset parquet downloaded and prepared to C:/Users/Oussama/.cache/huggingface/datasets/competitions___parquet/competitions--aiornot-759454878caed5d9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.\n"
195
+ ]
196
+ },
197
+ {
198
+ "data": {
199
+ "application/vnd.jupyter.widget-view+json": {
200
+ "model_id": "2087104741e1465d9fc4b033d5cd8bdf",
201
+ "version_major": 2,
202
+ "version_minor": 0
203
+ },
204
+ "text/plain": [
205
+ " 0%| | 0/2 [00:00<?, ?it/s]"
206
+ ]
207
+ },
208
+ "metadata": {},
209
+ "output_type": "display_data"
210
+ }
211
+ ],
212
+ "source": [
213
+ "from datasets import load_dataset\n",
214
+ "\n",
215
+ "# If the dataset is gated/private, make sure you have run huggingface-cli login\n",
216
+ "dataset = load_dataset(\"competitions/aiornot\")"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": null,
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": []
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": []
232
+ }
233
+ ],
234
+ "metadata": {
235
+ "kernelspec": {
236
+ "display_name": "Python 3",
237
+ "language": "python",
238
+ "name": "python3"
239
+ },
240
+ "language_info": {
241
+ "codemirror_mode": {
242
+ "name": "ipython",
243
+ "version": 3
244
+ },
245
+ "file_extension": ".py",
246
+ "mimetype": "text/x-python",
247
+ "name": "python",
248
+ "nbconvert_exporter": "python",
249
+ "pygments_lexer": "ipython3",
250
+ "version": "3.10.11"
251
+ },
252
+ "orig_nbformat": 4
253
+ },
254
+ "nbformat": 4,
255
+ "nbformat_minor": 2
256
+ }
Code/Testing NB.ipynb ADDED
@@ -0,0 +1,1056 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ },
15
+ "accelerator": "GPU",
16
+ "gpuClass": "standard"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 32,
22
+ "metadata": {
23
+ "colab": {
24
+ "base_uri": "https://localhost:8080/"
25
+ },
26
+ "id": "L6gytYO-DHMK",
27
+ "outputId": "b0c87fe1-77a4-45c7-8ea4-b8211cc0c4a7"
28
+ },
29
+ "outputs": [
30
+ {
31
+ "output_type": "stream",
32
+ "name": "stdout",
33
+ "text": [
34
+ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
35
+ ]
36
+ }
37
+ ],
38
+ "source": [
39
+ "from google.colab import drive\n",
40
+ "drive.mount('/content/drive')"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "source": [
46
+ "%pip install efficientnet-pytorch"
47
+ ],
48
+ "metadata": {
49
+ "colab": {
50
+ "base_uri": "https://localhost:8080/"
51
+ },
52
+ "id": "OoBBN22XDRNG",
53
+ "outputId": "c63a35aa-a077-44c7-93e5-bc9ba9732770"
54
+ },
55
+ "execution_count": 33,
56
+ "outputs": [
57
+ {
58
+ "output_type": "stream",
59
+ "name": "stdout",
60
+ "text": [
61
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
62
+ "Requirement already satisfied: efficientnet-pytorch in /usr/local/lib/python3.9/dist-packages (0.7.1)\n",
63
+ "Requirement already satisfied: torch in /usr/local/lib/python3.9/dist-packages (from efficientnet-pytorch) (2.0.0+cu118)\n",
64
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch->efficientnet-pytorch) (4.5.0)\n",
65
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.9/dist-packages (from torch->efficientnet-pytorch) (1.11.1)\n",
66
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch->efficientnet-pytorch) (3.11.0)\n",
67
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch->efficientnet-pytorch) (3.1)\n",
68
+ "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/dist-packages (from torch->efficientnet-pytorch) (2.0.0)\n",
69
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch->efficientnet-pytorch) (3.1.2)\n",
70
+ "Requirement already satisfied: lit in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch->efficientnet-pytorch) (16.0.1)\n",
71
+ "Requirement already satisfied: cmake in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch->efficientnet-pytorch) (3.25.2)\n",
72
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->torch->efficientnet-pytorch) (2.1.2)\n",
73
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.9/dist-packages (from sympy->torch->efficientnet-pytorch) (1.3.0)\n"
74
+ ]
75
+ }
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "source": [
81
+ "import numpy as np\n",
82
+ "import pandas as pd\n",
83
+ "import matplotlib.pyplot as plt\n",
84
+ "import os\n",
85
+ "from PIL import Image\n",
86
+ "import torch\n",
87
+ "from torch import nn, optim\n",
88
+ "import torch.nn.functional as F\n",
89
+ "from torch.utils.data import DataLoader, Dataset\n",
90
+ "import albumentations as A\n",
91
+ "from albumentations.pytorch import ToTensorV2 \n",
92
+ "from tqdm import tqdm\n",
93
+ "from torchvision import models\n",
94
+ "from efficientnet_pytorch import EfficientNet\n",
95
+ "from sklearn import metrics"
96
+ ],
97
+ "metadata": {
98
+ "id": "phJgllqcDSuH"
99
+ },
100
+ "execution_count": 34,
101
+ "outputs": []
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "source": [
106
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
107
+ ],
108
+ "metadata": {
109
+ "id": "DyUTFa31DTdp"
110
+ },
111
+ "execution_count": 35,
112
+ "outputs": []
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "source": [
117
+ "class Dataset(Dataset):\n",
118
+ " def __init__(self, root_images, root_file, transform = None):\n",
119
+ " self.root_images = root_images\n",
120
+ " self.root_file = root_file\n",
121
+ " self.transform = transform\n",
122
+ " self.file = pd.read_csv(root_file)\n",
123
+ "\n",
124
+ "\n",
125
+ " def __len__(self):\n",
126
+ " return self.file.shape[0]\n",
127
+ " \n",
128
+ " def __getitem__(self,index):\n",
129
+ " img_path = os.path.join(self.root_images, self.file['id'][index])\n",
130
+ " image = np.array(Image.open(img_path).convert('RGB'))\n",
131
+ " \n",
132
+ " if self.transform is not None:\n",
133
+ " augmentations = self.transform(image = image)\n",
134
+ " image = augmentations['image'] \n",
135
+ " \n",
136
+ " return image"
137
+ ],
138
+ "metadata": {
139
+ "id": "kTk-mXXUDUUA"
140
+ },
141
+ "execution_count": 36,
142
+ "outputs": []
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "source": [
147
+ "learning_rate = 0.0001\n",
148
+ "batch_size = 32\n",
149
+ "epochs = 10\n",
150
+ "height = 224 \n",
151
+ "width = 224\n",
152
+ "IMG = '/content/drive/MyDrive/Colab Notebooks/AI images or Not/test'\n",
153
+ "FILE = '/content/sample_submission.csv'"
154
+ ],
155
+ "metadata": {
156
+ "id": "HXEpa4PlDU85"
157
+ },
158
+ "execution_count": 37,
159
+ "outputs": []
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "source": [
164
+ "def get_loader(image, file, batch_size, test_transform):\n",
165
+ " \n",
166
+ " test_ds = Dataset(image , file, test_transform)\n",
167
+ " test_loader = DataLoader(test_ds, batch_size= batch_size, shuffle= False)\n",
168
+ "\n",
169
+ "\n",
170
+ "\n",
171
+ " return test_loader "
172
+ ],
173
+ "metadata": {
174
+ "id": "i-VOTQp2DVbK"
175
+ },
176
+ "execution_count": 38,
177
+ "outputs": []
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "source": [
182
+ "normalize = A.Normalize(\n",
183
+ " mean = [0.485 , 0.456 , 0.406],\n",
184
+ " std = [0.229 , 0.224, 0.255],\n",
185
+ " max_pixel_value= 255.0\n",
186
+ ")\n",
187
+ "\n",
188
+ "\n",
189
+ "test_transform = A.Compose(\n",
190
+ " [A.Resize(width=width , height= height),\n",
191
+ " normalize,\n",
192
+ " ToTensorV2()\n",
193
+ " ]\n",
194
+ ")\n"
195
+ ],
196
+ "metadata": {
197
+ "id": "RD4GnrT6DVpr"
198
+ },
199
+ "execution_count": 39,
200
+ "outputs": []
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "source": [
205
+ "class Net(nn.Module):\n",
206
+ " def __init__(self):\n",
207
+ " super().__init__()\n",
208
+ " self.model = EfficientNet.from_pretrained('efficientnet-b4')\n",
209
+ " self.fct = nn.Linear(1000,1)\n",
210
+ " \n",
211
+ " def forward(self,img):\n",
212
+ " x = self.model(img)\n",
213
+ " # print(x.shape)\n",
214
+ " x = self.fct(x)\n",
215
+ " return x"
216
+ ],
217
+ "metadata": {
218
+ "id": "HYH0pBe9DV3M"
219
+ },
220
+ "execution_count": 40,
221
+ "outputs": []
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "source": [
226
+ "def load_checkpoint(checkpoint, model, optimizer):\n",
227
+ " print('====> Loading...')\n",
228
+ " model.load_state_dict(checkpoint['state_dict'])\n",
229
+ " optimizer.load_state_dict(checkpoint['optimizer'])"
230
+ ],
231
+ "metadata": {
232
+ "id": "1Ype_u3qDV-n"
233
+ },
234
+ "execution_count": 41,
235
+ "outputs": []
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "source": [
240
+ "test = pd.read_csv(FILE)\n",
241
+ "test"
242
+ ],
243
+ "metadata": {
244
+ "colab": {
245
+ "base_uri": "https://localhost:8080/",
246
+ "height": 424
247
+ },
248
+ "id": "Jf_Is1qDGz-W",
249
+ "outputId": "cf79a4c0-2bca-473c-886e-726d7956015d"
250
+ },
251
+ "execution_count": 42,
252
+ "outputs": [
253
+ {
254
+ "output_type": "execute_result",
255
+ "data": {
256
+ "text/plain": [
257
+ " id label\n",
258
+ "0 0.jpg 0\n",
259
+ "1 1.jpg 0\n",
260
+ "2 10.jpg 0\n",
261
+ "3 100.jpg 0\n",
262
+ "4 1000.jpg 0\n",
263
+ "... ... ...\n",
264
+ "43437 9995.jpg 0\n",
265
+ "43438 9996.jpg 0\n",
266
+ "43439 9997.jpg 0\n",
267
+ "43440 9998.jpg 0\n",
268
+ "43441 9999.jpg 0\n",
269
+ "\n",
270
+ "[43442 rows x 2 columns]"
271
+ ],
272
+ "text/html": [
273
+ "\n",
274
+ " <div id=\"df-e57e96ec-2c2a-4dd2-b93e-600b15eda5bc\">\n",
275
+ " <div class=\"colab-df-container\">\n",
276
+ " <div>\n",
277
+ "<style scoped>\n",
278
+ " .dataframe tbody tr th:only-of-type {\n",
279
+ " vertical-align: middle;\n",
280
+ " }\n",
281
+ "\n",
282
+ " .dataframe tbody tr th {\n",
283
+ " vertical-align: top;\n",
284
+ " }\n",
285
+ "\n",
286
+ " .dataframe thead th {\n",
287
+ " text-align: right;\n",
288
+ " }\n",
289
+ "</style>\n",
290
+ "<table border=\"1\" class=\"dataframe\">\n",
291
+ " <thead>\n",
292
+ " <tr style=\"text-align: right;\">\n",
293
+ " <th></th>\n",
294
+ " <th>id</th>\n",
295
+ " <th>label</th>\n",
296
+ " </tr>\n",
297
+ " </thead>\n",
298
+ " <tbody>\n",
299
+ " <tr>\n",
300
+ " <th>0</th>\n",
301
+ " <td>0.jpg</td>\n",
302
+ " <td>0</td>\n",
303
+ " </tr>\n",
304
+ " <tr>\n",
305
+ " <th>1</th>\n",
306
+ " <td>1.jpg</td>\n",
307
+ " <td>0</td>\n",
308
+ " </tr>\n",
309
+ " <tr>\n",
310
+ " <th>2</th>\n",
311
+ " <td>10.jpg</td>\n",
312
+ " <td>0</td>\n",
313
+ " </tr>\n",
314
+ " <tr>\n",
315
+ " <th>3</th>\n",
316
+ " <td>100.jpg</td>\n",
317
+ " <td>0</td>\n",
318
+ " </tr>\n",
319
+ " <tr>\n",
320
+ " <th>4</th>\n",
321
+ " <td>1000.jpg</td>\n",
322
+ " <td>0</td>\n",
323
+ " </tr>\n",
324
+ " <tr>\n",
325
+ " <th>...</th>\n",
326
+ " <td>...</td>\n",
327
+ " <td>...</td>\n",
328
+ " </tr>\n",
329
+ " <tr>\n",
330
+ " <th>43437</th>\n",
331
+ " <td>9995.jpg</td>\n",
332
+ " <td>0</td>\n",
333
+ " </tr>\n",
334
+ " <tr>\n",
335
+ " <th>43438</th>\n",
336
+ " <td>9996.jpg</td>\n",
337
+ " <td>0</td>\n",
338
+ " </tr>\n",
339
+ " <tr>\n",
340
+ " <th>43439</th>\n",
341
+ " <td>9997.jpg</td>\n",
342
+ " <td>0</td>\n",
343
+ " </tr>\n",
344
+ " <tr>\n",
345
+ " <th>43440</th>\n",
346
+ " <td>9998.jpg</td>\n",
347
+ " <td>0</td>\n",
348
+ " </tr>\n",
349
+ " <tr>\n",
350
+ " <th>43441</th>\n",
351
+ " <td>9999.jpg</td>\n",
352
+ " <td>0</td>\n",
353
+ " </tr>\n",
354
+ " </tbody>\n",
355
+ "</table>\n",
356
+ "<p>43442 rows Γ— 2 columns</p>\n",
357
+ "</div>\n",
358
+ " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-e57e96ec-2c2a-4dd2-b93e-600b15eda5bc')\"\n",
359
+ " title=\"Convert this dataframe to an interactive table.\"\n",
360
+ " style=\"display:none;\">\n",
361
+ " \n",
362
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
363
+ " width=\"24px\">\n",
364
+ " <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
365
+ " <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
366
+ " </svg>\n",
367
+ " </button>\n",
368
+ " \n",
369
+ " <style>\n",
370
+ " .colab-df-container {\n",
371
+ " display:flex;\n",
372
+ " flex-wrap:wrap;\n",
373
+ " gap: 12px;\n",
374
+ " }\n",
375
+ "\n",
376
+ " .colab-df-convert {\n",
377
+ " background-color: #E8F0FE;\n",
378
+ " border: none;\n",
379
+ " border-radius: 50%;\n",
380
+ " cursor: pointer;\n",
381
+ " display: none;\n",
382
+ " fill: #1967D2;\n",
383
+ " height: 32px;\n",
384
+ " padding: 0 0 0 0;\n",
385
+ " width: 32px;\n",
386
+ " }\n",
387
+ "\n",
388
+ " .colab-df-convert:hover {\n",
389
+ " background-color: #E2EBFA;\n",
390
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
391
+ " fill: #174EA6;\n",
392
+ " }\n",
393
+ "\n",
394
+ " [theme=dark] .colab-df-convert {\n",
395
+ " background-color: #3B4455;\n",
396
+ " fill: #D2E3FC;\n",
397
+ " }\n",
398
+ "\n",
399
+ " [theme=dark] .colab-df-convert:hover {\n",
400
+ " background-color: #434B5C;\n",
401
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
402
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
403
+ " fill: #FFFFFF;\n",
404
+ " }\n",
405
+ " </style>\n",
406
+ "\n",
407
+ " <script>\n",
408
+ " const buttonEl =\n",
409
+ " document.querySelector('#df-e57e96ec-2c2a-4dd2-b93e-600b15eda5bc button.colab-df-convert');\n",
410
+ " buttonEl.style.display =\n",
411
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
412
+ "\n",
413
+ " async function convertToInteractive(key) {\n",
414
+ " const element = document.querySelector('#df-e57e96ec-2c2a-4dd2-b93e-600b15eda5bc');\n",
415
+ " const dataTable =\n",
416
+ " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
417
+ " [key], {});\n",
418
+ " if (!dataTable) return;\n",
419
+ "\n",
420
+ " const docLinkHtml = 'Like what you see? Visit the ' +\n",
421
+ " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
422
+ " + ' to learn more about interactive tables.';\n",
423
+ " element.innerHTML = '';\n",
424
+ " dataTable['output_type'] = 'display_data';\n",
425
+ " await google.colab.output.renderOutput(dataTable, element);\n",
426
+ " const docLink = document.createElement('div');\n",
427
+ " docLink.innerHTML = docLinkHtml;\n",
428
+ " element.appendChild(docLink);\n",
429
+ " }\n",
430
+ " </script>\n",
431
+ " </div>\n",
432
+ " </div>\n",
433
+ " "
434
+ ]
435
+ },
436
+ "metadata": {},
437
+ "execution_count": 42
438
+ }
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "source": [
444
+ "model = Net().to(device)\n",
445
+ "optimizer = optim.Adam(model.parameters(), lr= learning_rate)\n",
446
+ "\n",
447
+ "checkpoint_file = '/content/drive/MyDrive/Colab Notebooks/AI images or Not/baseline_V0.pth.tar'\n",
448
+ "test_loader = get_loader(IMG, FILE, batch_size, test_transform)\n",
449
+ "checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))\n",
450
+ "load_checkpoint(checkpoint, model, optimizer)\n",
451
+ "\n",
452
+ "model.eval()\n",
453
+ "k = 0\n",
454
+ "for x in tqdm(test_loader):\n",
455
+ " x = x.to(device).to(torch.float32)\n",
456
+ " p = torch.sigmoid(model(x)).cpu().detach().numpy()\n",
457
+ "\n",
458
+ " for i in range(len(p)):\n",
459
+ " test['label'][k] = (p[i] > 0.75).astype('float')\n",
460
+ " k += 1"
461
+ ],
462
+ "metadata": {
463
+ "id": "qWB6WzrlDWD7",
464
+ "colab": {
465
+ "base_uri": "https://localhost:8080/"
466
+ },
467
+ "outputId": "52e74e4b-96e7-40e7-d1b3-a22c7b70098d"
468
+ },
469
+ "execution_count": 43,
470
+ "outputs": [
471
+ {
472
+ "output_type": "stream",
473
+ "name": "stdout",
474
+ "text": [
475
+ "Loaded pretrained weights for efficientnet-b4\n",
476
+ "====> Loading...\n"
477
+ ]
478
+ },
479
+ {
480
+ "output_type": "stream",
481
+ "name": "stderr",
482
+ "text": [
483
+ " 0%| | 0/1358 [00:00<?, ?it/s]<ipython-input-43-383dee41b09a>:16: SettingWithCopyWarning: \n",
484
+ "A value is trying to be set on a copy of a slice from a DataFrame\n",
485
+ "\n",
486
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
487
+ " test['label'][k] = (p[i] > 0.75).astype('float')\n",
488
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1358/1358 [4:56:02<00:00, 13.08s/it]\n"
489
+ ]
490
+ }
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "source": [
496
+ "test"
497
+ ],
498
+ "metadata": {
499
+ "id": "-zS8tYPBDWG7",
500
+ "colab": {
501
+ "base_uri": "https://localhost:8080/",
502
+ "height": 424
503
+ },
504
+ "outputId": "4a4c0b81-ff75-4ed7-a5df-6f98644f03e2"
505
+ },
506
+ "execution_count": 44,
507
+ "outputs": [
508
+ {
509
+ "output_type": "execute_result",
510
+ "data": {
511
+ "text/plain": [
512
+ " id label\n",
513
+ "0 0.jpg 0\n",
514
+ "1 1.jpg 0\n",
515
+ "2 10.jpg 0\n",
516
+ "3 100.jpg 1\n",
517
+ "4 1000.jpg 0\n",
518
+ "... ... ...\n",
519
+ "43437 9995.jpg 1\n",
520
+ "43438 9996.jpg 0\n",
521
+ "43439 9997.jpg 0\n",
522
+ "43440 9998.jpg 0\n",
523
+ "43441 9999.jpg 1\n",
524
+ "\n",
525
+ "[43442 rows x 2 columns]"
526
+ ],
527
+ "text/html": [
528
+ "\n",
529
+ " <div id=\"df-00389cce-5634-451c-81fb-649bced26029\">\n",
530
+ " <div class=\"colab-df-container\">\n",
531
+ " <div>\n",
532
+ "<style scoped>\n",
533
+ " .dataframe tbody tr th:only-of-type {\n",
534
+ " vertical-align: middle;\n",
535
+ " }\n",
536
+ "\n",
537
+ " .dataframe tbody tr th {\n",
538
+ " vertical-align: top;\n",
539
+ " }\n",
540
+ "\n",
541
+ " .dataframe thead th {\n",
542
+ " text-align: right;\n",
543
+ " }\n",
544
+ "</style>\n",
545
+ "<table border=\"1\" class=\"dataframe\">\n",
546
+ " <thead>\n",
547
+ " <tr style=\"text-align: right;\">\n",
548
+ " <th></th>\n",
549
+ " <th>id</th>\n",
550
+ " <th>label</th>\n",
551
+ " </tr>\n",
552
+ " </thead>\n",
553
+ " <tbody>\n",
554
+ " <tr>\n",
555
+ " <th>0</th>\n",
556
+ " <td>0.jpg</td>\n",
557
+ " <td>0</td>\n",
558
+ " </tr>\n",
559
+ " <tr>\n",
560
+ " <th>1</th>\n",
561
+ " <td>1.jpg</td>\n",
562
+ " <td>0</td>\n",
563
+ " </tr>\n",
564
+ " <tr>\n",
565
+ " <th>2</th>\n",
566
+ " <td>10.jpg</td>\n",
567
+ " <td>0</td>\n",
568
+ " </tr>\n",
569
+ " <tr>\n",
570
+ " <th>3</th>\n",
571
+ " <td>100.jpg</td>\n",
572
+ " <td>1</td>\n",
573
+ " </tr>\n",
574
+ " <tr>\n",
575
+ " <th>4</th>\n",
576
+ " <td>1000.jpg</td>\n",
577
+ " <td>0</td>\n",
578
+ " </tr>\n",
579
+ " <tr>\n",
580
+ " <th>...</th>\n",
581
+ " <td>...</td>\n",
582
+ " <td>...</td>\n",
583
+ " </tr>\n",
584
+ " <tr>\n",
585
+ " <th>43437</th>\n",
586
+ " <td>9995.jpg</td>\n",
587
+ " <td>1</td>\n",
588
+ " </tr>\n",
589
+ " <tr>\n",
590
+ " <th>43438</th>\n",
591
+ " <td>9996.jpg</td>\n",
592
+ " <td>0</td>\n",
593
+ " </tr>\n",
594
+ " <tr>\n",
595
+ " <th>43439</th>\n",
596
+ " <td>9997.jpg</td>\n",
597
+ " <td>0</td>\n",
598
+ " </tr>\n",
599
+ " <tr>\n",
600
+ " <th>43440</th>\n",
601
+ " <td>9998.jpg</td>\n",
602
+ " <td>0</td>\n",
603
+ " </tr>\n",
604
+ " <tr>\n",
605
+ " <th>43441</th>\n",
606
+ " <td>9999.jpg</td>\n",
607
+ " <td>1</td>\n",
608
+ " </tr>\n",
609
+ " </tbody>\n",
610
+ "</table>\n",
611
+ "<p>43442 rows Γ— 2 columns</p>\n",
612
+ "</div>\n",
613
+ " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-00389cce-5634-451c-81fb-649bced26029')\"\n",
614
+ " title=\"Convert this dataframe to an interactive table.\"\n",
615
+ " style=\"display:none;\">\n",
616
+ " \n",
617
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
618
+ " width=\"24px\">\n",
619
+ " <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
620
+ " <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
621
+ " </svg>\n",
622
+ " </button>\n",
623
+ " \n",
624
+ " <style>\n",
625
+ " .colab-df-container {\n",
626
+ " display:flex;\n",
627
+ " flex-wrap:wrap;\n",
628
+ " gap: 12px;\n",
629
+ " }\n",
630
+ "\n",
631
+ " .colab-df-convert {\n",
632
+ " background-color: #E8F0FE;\n",
633
+ " border: none;\n",
634
+ " border-radius: 50%;\n",
635
+ " cursor: pointer;\n",
636
+ " display: none;\n",
637
+ " fill: #1967D2;\n",
638
+ " height: 32px;\n",
639
+ " padding: 0 0 0 0;\n",
640
+ " width: 32px;\n",
641
+ " }\n",
642
+ "\n",
643
+ " .colab-df-convert:hover {\n",
644
+ " background-color: #E2EBFA;\n",
645
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
646
+ " fill: #174EA6;\n",
647
+ " }\n",
648
+ "\n",
649
+ " [theme=dark] .colab-df-convert {\n",
650
+ " background-color: #3B4455;\n",
651
+ " fill: #D2E3FC;\n",
652
+ " }\n",
653
+ "\n",
654
+ " [theme=dark] .colab-df-convert:hover {\n",
655
+ " background-color: #434B5C;\n",
656
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
657
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
658
+ " fill: #FFFFFF;\n",
659
+ " }\n",
660
+ " </style>\n",
661
+ "\n",
662
+ " <script>\n",
663
+ " const buttonEl =\n",
664
+ " document.querySelector('#df-00389cce-5634-451c-81fb-649bced26029 button.colab-df-convert');\n",
665
+ " buttonEl.style.display =\n",
666
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
667
+ "\n",
668
+ " async function convertToInteractive(key) {\n",
669
+ " const element = document.querySelector('#df-00389cce-5634-451c-81fb-649bced26029');\n",
670
+ " const dataTable =\n",
671
+ " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
672
+ " [key], {});\n",
673
+ " if (!dataTable) return;\n",
674
+ "\n",
675
+ " const docLinkHtml = 'Like what you see? Visit the ' +\n",
676
+ " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
677
+ " + ' to learn more about interactive tables.';\n",
678
+ " element.innerHTML = '';\n",
679
+ " dataTable['output_type'] = 'display_data';\n",
680
+ " await google.colab.output.renderOutput(dataTable, element);\n",
681
+ " const docLink = document.createElement('div');\n",
682
+ " docLink.innerHTML = docLinkHtml;\n",
683
+ " element.appendChild(docLink);\n",
684
+ " }\n",
685
+ " </script>\n",
686
+ " </div>\n",
687
+ " </div>\n",
688
+ " "
689
+ ]
690
+ },
691
+ "metadata": {},
692
+ "execution_count": 44
693
+ }
694
+ ]
695
+ },
696
+ {
697
+ "cell_type": "code",
698
+ "source": [
699
+ "test.to_csv('sub.csv', index=False)"
700
+ ],
701
+ "metadata": {
702
+ "id": "nX_vnorKDWKK"
703
+ },
704
+ "execution_count": 45,
705
+ "outputs": []
706
+ },
707
+ {
708
+ "cell_type": "code",
709
+ "source": [],
710
+ "metadata": {
711
+ "id": "JmJa1KolDWM5"
712
+ },
713
+ "execution_count": 45,
714
+ "outputs": []
715
+ },
716
+ {
717
+ "cell_type": "code",
718
+ "source": [
719
+ "def predict(image):\n",
720
+ " image = np.array(image)\n",
721
+ " transform = A.Compose(\n",
722
+ " [A.Resize(width=width, height=height),\n",
723
+ " normalize,\n",
724
+ " ToTensorV2()\n",
725
+ " ]\n",
726
+ " )\n",
727
+ " image = transform(image=image)[\"image\"].unsqueeze(0).to(device).to(torch.float32)\n",
728
+ " with torch.no_grad():\n",
729
+ " model.eval()\n",
730
+ " output = torch.sigmoid(model(image))\n",
731
+ " label = (output > 0.75).item()\n",
732
+ " return \"AI Image\" if label else \"Not AI Image\""
733
+ ],
734
+ "metadata": {
735
+ "id": "TKs8s0TyDWP0"
736
+ },
737
+ "execution_count": 46,
738
+ "outputs": []
739
+ },
740
+ {
741
+ "cell_type": "code",
742
+ "source": [
743
+ "%pip install gradio"
744
+ ],
745
+ "metadata": {
746
+ "colab": {
747
+ "base_uri": "https://localhost:8080/"
748
+ },
749
+ "id": "k7bGi6MqqO-r",
750
+ "outputId": "120d9571-3381-418a-9056-ff8b84199ca7"
751
+ },
752
+ "execution_count": 47,
753
+ "outputs": [
754
+ {
755
+ "output_type": "stream",
756
+ "name": "stdout",
757
+ "text": [
758
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
759
+ "Collecting gradio\n",
760
+ " Downloading gradio-3.27.0-py3-none-any.whl (17.3 MB)\n",
761
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m17.3/17.3 MB\u001b[0m \u001b[31m60.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
762
+ "\u001b[?25hRequirement already satisfied: matplotlib in /usr/local/lib/python3.9/dist-packages (from gradio) (3.7.1)\n",
763
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from gradio) (1.22.4)\n",
764
+ "Requirement already satisfied: markdown-it-py[linkify]>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from gradio) (2.2.0)\n",
765
+ "Collecting aiohttp\n",
766
+ " Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n",
767
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m54.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
768
+ "\u001b[?25hCollecting orjson\n",
769
+ " Downloading orjson-3.8.10-cp39-cp39-manylinux_2_28_x86_64.whl (140 kB)\n",
770
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.5/140.5 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
771
+ "\u001b[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.9/dist-packages (from gradio) (8.4.0)\n",
772
+ "Collecting ffmpy\n",
773
+ " Downloading ffmpy-0.3.0.tar.gz (4.8 kB)\n",
774
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
775
+ "Requirement already satisfied: markupsafe in /usr/local/lib/python3.9/dist-packages (from gradio) (2.1.2)\n",
776
+ "Collecting huggingface-hub>=0.13.0\n",
777
+ " Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)\n",
778
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m200.1/200.1 kB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
779
+ "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from gradio) (2.27.1)\n",
780
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from gradio) (4.5.0)\n",
781
+ "Requirement already satisfied: altair>=4.2.0 in /usr/local/lib/python3.9/dist-packages (from gradio) (4.2.2)\n",
782
+ "Collecting fastapi\n",
783
+ " Downloading fastapi-0.95.1-py3-none-any.whl (56 kB)\n",
784
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.0/57.0 kB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
785
+ "\u001b[?25hCollecting httpx\n",
786
+ " Downloading httpx-0.24.0-py3-none-any.whl (75 kB)\n",
787
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.3/75.3 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
788
+ "\u001b[?25hCollecting uvicorn\n",
789
+ " Downloading uvicorn-0.21.1-py3-none-any.whl (57 kB)\n",
790
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.8/57.8 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
791
+ "\u001b[?25hCollecting websockets>=10.0\n",
792
+ " Downloading websockets-11.0.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB)\n",
793
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m129.7/129.7 kB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
794
+ "\u001b[?25hCollecting pydub\n",
795
+ " Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)\n",
796
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.9/dist-packages (from gradio) (6.0)\n",
797
+ "Collecting mdit-py-plugins<=0.3.3\n",
798
+ " Downloading mdit_py_plugins-0.3.3-py3-none-any.whl (50 kB)\n",
799
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.5/50.5 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
800
+ "\u001b[?25hCollecting python-multipart\n",
801
+ " Downloading python_multipart-0.0.6-py3-none-any.whl (45 kB)\n",
802
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.7/45.7 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
803
+ "\u001b[?25hRequirement already satisfied: pydantic in /usr/local/lib/python3.9/dist-packages (from gradio) (1.10.7)\n",
804
+ "Collecting semantic-version\n",
805
+ " Downloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)\n",
806
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from gradio) (1.5.3)\n",
807
+ "Collecting gradio-client>=0.1.3\n",
808
+ " Downloading gradio_client-0.1.3-py3-none-any.whl (286 kB)\n",
809
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m286.2/286.2 kB\u001b[0m \u001b[31m27.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
810
+ "\u001b[?25hCollecting aiofiles\n",
811
+ " Downloading aiofiles-23.1.0-py3-none-any.whl (14 kB)\n",
812
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from gradio) (3.1.2)\n",
813
+ "Requirement already satisfied: entrypoints in /usr/local/lib/python3.9/dist-packages (from altair>=4.2.0->gradio) (0.4)\n",
814
+ "Requirement already satisfied: toolz in /usr/local/lib/python3.9/dist-packages (from altair>=4.2.0->gradio) (0.12.0)\n",
815
+ "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.9/dist-packages (from altair>=4.2.0->gradio) (4.3.3)\n",
816
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from gradio-client>=0.1.3->gradio) (23.1)\n",
817
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.9/dist-packages (from gradio-client>=0.1.3->gradio) (2023.4.0)\n",
818
+ "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub>=0.13.0->gradio) (4.65.0)\n",
819
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from huggingface-hub>=0.13.0->gradio) (3.11.0)\n",
820
+ "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.9/dist-packages (from markdown-it-py[linkify]>=2.0.0->gradio) (0.1.2)\n",
821
+ "Collecting linkify-it-py<3,>=1\n",
822
+ " Downloading linkify_it_py-2.0.0-py3-none-any.whl (19 kB)\n",
823
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->gradio) (2.8.2)\n",
824
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->gradio) (2022.7.1)\n",
825
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->gradio) (23.1.0)\n",
826
+ "Collecting async-timeout<5.0,>=4.0.0a3\n",
827
+ " Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n",
828
+ "Collecting multidict<7.0,>=4.5\n",
829
+ " Downloading multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)\n",
830
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m114.2/114.2 kB\u001b[0m \u001b[31m15.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
831
+ "\u001b[?25hCollecting frozenlist>=1.1.1\n",
832
+ " Downloading frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (158 kB)\n",
833
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.8/158.8 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
834
+ "\u001b[?25hCollecting yarl<2.0,>=1.0\n",
835
+ " Downloading yarl-1.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (264 kB)\n",
836
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m264.6/264.6 kB\u001b[0m \u001b[31m29.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
837
+ "\u001b[?25hCollecting aiosignal>=1.1.2\n",
838
+ " Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n",
839
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->gradio) (2.0.12)\n",
840
+ "Collecting starlette<0.27.0,>=0.26.1\n",
841
+ " Downloading starlette-0.26.1-py3-none-any.whl (66 kB)\n",
842
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m66.9/66.9 kB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
843
+ "\u001b[?25hCollecting httpcore<0.18.0,>=0.15.0\n",
844
+ " Downloading httpcore-0.17.0-py3-none-any.whl (70 kB)\n",
845
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m70.6/70.6 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
846
+ "\u001b[?25hRequirement already satisfied: certifi in /usr/local/lib/python3.9/dist-packages (from httpx->gradio) (2022.12.7)\n",
847
+ "Requirement already satisfied: sniffio in /usr/local/lib/python3.9/dist-packages (from httpx->gradio) (1.3.0)\n",
848
+ "Requirement already satisfied: idna in /usr/local/lib/python3.9/dist-packages (from httpx->gradio) (3.4)\n",
849
+ "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->gradio) (3.0.9)\n",
850
+ "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->gradio) (1.0.7)\n",
851
+ "Requirement already satisfied: importlib-resources>=3.2.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->gradio) (5.12.0)\n",
852
+ "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->gradio) (4.39.3)\n",
853
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->gradio) (1.4.4)\n",
854
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/dist-packages (from matplotlib->gradio) (0.11.0)\n",
855
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests->gradio) (1.26.15)\n",
856
+ "Collecting h11>=0.8\n",
857
+ " Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n",
858
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
859
+ "\u001b[?25hRequirement already satisfied: click>=7.0 in /usr/local/lib/python3.9/dist-packages (from uvicorn->gradio) (8.1.3)\n",
860
+ "Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.9/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx->gradio) (3.6.2)\n",
861
+ "Requirement already satisfied: zipp>=3.1.0 in /usr/local/lib/python3.9/dist-packages (from importlib-resources>=3.2.0->matplotlib->gradio) (3.15.0)\n",
862
+ "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.9/dist-packages (from jsonschema>=3.0->altair>=4.2.0->gradio) (0.19.3)\n",
863
+ "Collecting uc-micro-py\n",
864
+ " Downloading uc_micro_py-1.0.1-py3-none-any.whl (6.2 kB)\n",
865
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil>=2.8.1->pandas->gradio) (1.16.0)\n",
866
+ "Building wheels for collected packages: ffmpy\n",
867
+ " Building wheel for ffmpy (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
868
+ " Created wheel for ffmpy: filename=ffmpy-0.3.0-py3-none-any.whl size=4707 sha256=030fcfbd0063a8e91f56986ba5a3eaeb8d3d94a5b7f0a2c726f9367cfd7d2fbf\n",
869
+ " Stored in directory: /root/.cache/pip/wheels/91/e2/96/f676aa08bfd789328c6576cd0f1fde4a3d686703bb0c247697\n",
870
+ "Successfully built ffmpy\n",
871
+ "Installing collected packages: pydub, ffmpy, websockets, uc-micro-py, semantic-version, python-multipart, orjson, multidict, h11, frozenlist, async-timeout, aiofiles, yarl, uvicorn, starlette, mdit-py-plugins, linkify-it-py, huggingface-hub, httpcore, aiosignal, httpx, fastapi, aiohttp, gradio-client, gradio\n",
872
+ "Successfully installed aiofiles-23.1.0 aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 fastapi-0.95.1 ffmpy-0.3.0 frozenlist-1.3.3 gradio-3.27.0 gradio-client-0.1.3 h11-0.14.0 httpcore-0.17.0 httpx-0.24.0 huggingface-hub-0.13.4 linkify-it-py-2.0.0 mdit-py-plugins-0.3.3 multidict-6.0.4 orjson-3.8.10 pydub-0.25.1 python-multipart-0.0.6 semantic-version-2.10.0 starlette-0.26.1 uc-micro-py-1.0.1 uvicorn-0.21.1 websockets-11.0.2 yarl-1.8.2\n"
873
+ ]
874
+ }
875
+ ]
876
+ },
877
+ {
878
+ "cell_type": "code",
879
+ "source": [
880
+ "import gradio as gr"
881
+ ],
882
+ "metadata": {
883
+ "id": "Q5a9SQbcqLH7"
884
+ },
885
+ "execution_count": 48,
886
+ "outputs": []
887
+ },
888
+ {
889
+ "cell_type": "code",
890
+ "source": [
891
+ "\n",
892
+ "\n",
893
+ "inputs = gr.inputs.Image()\n",
894
+ "outputs = gr.outputs.Textbox()\n",
895
+ "iface = gr.Interface(fn=predict, inputs=inputs, outputs=outputs, capture_session=True)\n",
896
+ "iface.launch()"
897
+ ],
898
+ "metadata": {
899
+ "colab": {
900
+ "base_uri": "https://localhost:8080/",
901
+ "height": 775
902
+ },
903
+ "id": "sEsxRg9IqLue",
904
+ "outputId": "1ea4931d-4001-4c37-a0f9-97017b2e55a6"
905
+ },
906
+ "execution_count": 49,
907
+ "outputs": [
908
+ {
909
+ "output_type": "stream",
910
+ "name": "stderr",
911
+ "text": [
912
+ "/usr/local/lib/python3.9/dist-packages/gradio/inputs.py:257: UserWarning: Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components\n",
913
+ " warnings.warn(\n",
914
+ "/usr/local/lib/python3.9/dist-packages/gradio/deprecation.py:40: UserWarning: `optional` parameter is deprecated, and it has no effect\n",
915
+ " warnings.warn(value)\n",
916
+ "/usr/local/lib/python3.9/dist-packages/gradio/outputs.py:22: UserWarning: Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components\n",
917
+ " warnings.warn(\n",
918
+ "/usr/local/lib/python3.9/dist-packages/gradio/deprecation.py:40: UserWarning: `capture_session` parameter is deprecated, and it has no effect\n",
919
+ " warnings.warn(value)\n"
920
+ ]
921
+ },
922
+ {
923
+ "output_type": "stream",
924
+ "name": "stdout",
925
+ "text": [
926
+ "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
927
+ "Note: opening Chrome Inspector may crash demo inside Colab notebooks.\n",
928
+ "\n",
929
+ "To create a public link, set `share=True` in `launch()`.\n"
930
+ ]
931
+ },
932
+ {
933
+ "output_type": "display_data",
934
+ "data": {
935
+ "text/plain": [
936
+ "<IPython.core.display.Javascript object>"
937
+ ],
938
+ "application/javascript": [
939
+ "(async (port, path, width, height, cache, element) => {\n",
940
+ " if (!google.colab.kernel.accessAllowed && !cache) {\n",
941
+ " return;\n",
942
+ " }\n",
943
+ " element.appendChild(document.createTextNode(''));\n",
944
+ " const url = await google.colab.kernel.proxyPort(port, {cache});\n",
945
+ "\n",
946
+ " const external_link = document.createElement('div');\n",
947
+ " external_link.innerHTML = `\n",
948
+ " <div style=\"font-family: monospace; margin-bottom: 0.5rem\">\n",
949
+ " Running on <a href=${new URL(path, url).toString()} target=\"_blank\">\n",
950
+ " https://localhost:${port}${path}\n",
951
+ " </a>\n",
952
+ " </div>\n",
953
+ " `;\n",
954
+ " element.appendChild(external_link);\n",
955
+ "\n",
956
+ " const iframe = document.createElement('iframe');\n",
957
+ " iframe.src = new URL(path, url).toString();\n",
958
+ " iframe.height = height;\n",
959
+ " iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n",
960
+ " iframe.width = width;\n",
961
+ " iframe.style.border = 0;\n",
962
+ " element.appendChild(iframe);\n",
963
+ " })(7860, \"/\", \"100%\", 500, false, window.element)"
964
+ ]
965
+ },
966
+ "metadata": {}
967
+ },
968
+ {
969
+ "output_type": "execute_result",
970
+ "data": {
971
+ "text/plain": []
972
+ },
973
+ "metadata": {},
974
+ "execution_count": 49
975
+ }
976
+ ]
977
+ },
978
+ {
979
+ "cell_type": "code",
980
+ "source": [
981
+ "import gradio as gr\n",
982
+ "import torch\n",
983
+ "import numpy as np\n",
984
+ "from PIL import Image\n",
985
+ "\n",
986
+ "# define the predict function\n",
987
+ "def predict(image):\n",
988
+ " # preprocess the image\n",
989
+ " image = np.array(image)\n",
990
+ " image = test_transform(image=image)['image']\n",
991
+ " image = image.unsqueeze(0).to(device)\n",
992
+ "\n",
993
+ " # get the model prediction\n",
994
+ " with torch.no_grad():\n",
995
+ " output = model(image)\n",
996
+ " pred = torch.sigmoid(output).cpu().numpy().squeeze()\n",
997
+ " \n",
998
+ " # return the prediction as a string\n",
999
+ " return f\"This image is {'AI generated' if pred > 0.75 else 'NOT AI generated'}\"\n",
1000
+ "\n",
1001
+ "# define the input interface with examples\n",
1002
+ "inputs = gr.inputs.Image(shape=(224, 224))\n",
1003
+ "outputs = gr.outputs.Textbox()\n",
1004
+ "examples = [\n",
1005
+ " ['/content/drive/MyDrive/Colab Notebooks/AI images or Not/train/3.jpg'],\n",
1006
+ " ['/content/drive/MyDrive/Colab Notebooks/AI images or Not/train/10.jpg'],\n",
1007
+ " ['/content/drive/MyDrive/Colab Notebooks/AI images or Not/train/14.jpg'],\n",
1008
+ " ['/content/drive/MyDrive/Colab Notebooks/AI images or Not/train/4515.jpg']\n",
1009
+ " ['/content/drive/MyDrive/Colab Notebooks/AI images or Not/train/4518.jpg'],\n",
1010
+ "]\n",
1011
+ "iface = gr.Interface(fn=predict, inputs=inputs, outputs=outputs, examples=examples)\n",
1012
+ "\n",
1013
+ "# launch the gradio app\n",
1014
+ "iface.launch()"
1015
+ ],
1016
+ "metadata": {
1017
+ "colab": {
1018
+ "base_uri": "https://localhost:8080/",
1019
+ "height": 428
1020
+ },
1021
+ "id": "nMuNn5FCvEuS",
1022
+ "outputId": "ad4760a5-9458-483a-b9bc-c655f0bf6429"
1023
+ },
1024
+ "execution_count": 55,
1025
+ "outputs": [
1026
+ {
1027
+ "output_type": "stream",
1028
+ "name": "stderr",
1029
+ "text": [
1030
+ "<>:28: SyntaxWarning: list indices must be integers or slices, not str; perhaps you missed a comma?\n",
1031
+ "<>:28: SyntaxWarning: list indices must be integers or slices, not str; perhaps you missed a comma?\n",
1032
+ "/usr/local/lib/python3.9/dist-packages/gradio/inputs.py:257: UserWarning: Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components\n",
1033
+ " warnings.warn(\n",
1034
+ "/usr/local/lib/python3.9/dist-packages/gradio/deprecation.py:40: UserWarning: `optional` parameter is deprecated, and it has no effect\n",
1035
+ " warnings.warn(value)\n",
1036
+ "/usr/local/lib/python3.9/dist-packages/gradio/outputs.py:22: UserWarning: Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components\n",
1037
+ " warnings.warn(\n",
1038
+ "<ipython-input-55-ad9875932060>:28: SyntaxWarning: list indices must be integers or slices, not str; perhaps you missed a comma?\n",
1039
+ " ['/content/drive/MyDrive/Colab Notebooks/AI images or Not/train/4515.jpg']\n"
1040
+ ]
1041
+ },
1042
+ {
1043
+ "output_type": "error",
1044
+ "ename": "TypeError",
1045
+ "evalue": "ignored",
1046
+ "traceback": [
1047
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1048
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
1049
+ "\u001b[0;32m<ipython-input-55-ad9875932060>\u001b[0m in \u001b[0;36m<cell line: 25>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'/content/drive/MyDrive/Colab Notebooks/AI images or Not/train/10.jpg'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'/content/drive/MyDrive/Colab Notebooks/AI images or Not/train/14.jpg'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;34m[\u001b[0m\u001b[0;34m'/content/drive/MyDrive/Colab Notebooks/AI images or Not/train/4515.jpg'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'/content/drive/MyDrive/Colab Notebooks/AI images or Not/train/4518.jpg'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m ]\n",
1050
+ "\u001b[0;31mTypeError\u001b[0m: list indices must be integers or slices, not str"
1051
+ ]
1052
+ }
1053
+ ]
1054
+ }
1055
+ ]
1056
+ }
Code/Training NB.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Code/Training_NB.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Code/app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import albumentations as A
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from albumentations.pytorch import ToTensorV2
11
+ from efficientnet_pytorch import EfficientNet
12
+ from PIL import Image
13
+ from sklearn import metrics
14
+ from torch import nn, optim
15
+ from torch.utils.data import DataLoader, Dataset
16
+ from torchvision import models
17
+ from tqdm import tqdm
18
+
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+
22
+ class Dataset(Dataset):
23
+ def __init__(self, root_images, root_file, transform=None):
24
+ self.root_images = root_images
25
+ self.root_file = root_file
26
+ self.transform = transform
27
+ self.file = pd.read_csv(root_file)
28
+
29
+ def __len__(self):
30
+ return self.file.shape[0]
31
+
32
+ def __getitem__(self, index):
33
+ img_path = os.path.join(self.root_images, self.file["id"][index])
34
+ image = np.array(Image.open(img_path).convert("RGB"))
35
+
36
+ if self.transform is not None:
37
+ augmentations = self.transform(image=image)
38
+ image = augmentations["image"]
39
+
40
+ return image
41
+
42
+
43
+ learning_rate = 0.0001
44
+ batch_size = 32
45
+ epochs = 10
46
+ height = 224
47
+ width = 224
48
+ IMG = "AI images or Not/test"
49
+ FILE = "Data/sample_submission.csv"
50
+
51
+
52
+ def get_loader(image, file, batch_size, test_transform):
53
+
54
+ test_ds = Dataset(image, file, test_transform)
55
+ test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
56
+
57
+ return test_loader
58
+
59
+
60
+ normalize = A.Normalize(
61
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.255], max_pixel_value=255.0
62
+ )
63
+
64
+
65
+ test_transform = A.Compose(
66
+ [A.Resize(width=width, height=height), normalize, ToTensorV2()]
67
+ )
68
+
69
+
70
+ class Net(nn.Module):
71
+ def __init__(self):
72
+ super().__init__()
73
+ self.model = EfficientNet.from_pretrained("efficientnet-b4")
74
+ self.fct = nn.Linear(1000, 1)
75
+
76
+ def forward(self, img):
77
+ x = self.model(img)
78
+ # print(x.shape)
79
+ x = self.fct(x)
80
+ return x
81
+
82
+
83
+ def load_checkpoint(checkpoint, model, optimizer):
84
+ print("====> Loading...")
85
+ model.load_state_dict(checkpoint["state_dict"])
86
+ optimizer.load_state_dict(checkpoint["optimizer"])
87
+
88
+
89
+ # test = pd.read_csv(FILE)
90
+ # test
91
+
92
+ model = Net().to(device)
93
+ optimizer = optim.Adam(model.parameters(), lr=learning_rate)
94
+
95
+ checkpoint_file = "Checkpoint/baseline_V0.pth.tar"
96
+ test_loader = get_loader(IMG, FILE, batch_size, test_transform)
97
+ checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
98
+ load_checkpoint(checkpoint, model, optimizer)
99
+
100
+ model.eval()
101
+
102
+
103
+ # define the predict function
104
+ def predict(image):
105
+ # preprocess the image
106
+ image = np.array(image)
107
+ image = test_transform(image=image)["image"]
108
+ image = image.unsqueeze(0).to(device)
109
+
110
+ # get the model prediction
111
+ with torch.no_grad():
112
+ output = model(image)
113
+ pred = torch.sigmoid(output).cpu().numpy().squeeze()
114
+
115
+ # check if prediction is AI generated, not AI generated, or uncertain
116
+ if pred >= 0.6:
117
+ prediction = "AI generated"
118
+ confidence = pred
119
+ elif pred <= 0.4:
120
+ prediction = "NOT AI generated"
121
+ confidence = 1 - pred
122
+ else:
123
+ prediction = "uncertain"
124
+ confidence = abs(0.5 - pred) * 2
125
+
126
+ # return the prediction and confidence as a string
127
+ return f"This image is {prediction} with {confidence:.2%} confidence."
128
+
129
+
130
+ # define the input interface with examples
131
+ inputs = gr.inputs.Image(shape=(224, 224))
132
+ outputs = gr.outputs.Textbox()
133
+ examples = [
134
+ ["Data/train/3.jpg"],
135
+ ["Data/train/10.jpg"],
136
+ ["Data/train/14.jpg"],
137
+ ["Data/train/4515.jpg"],
138
+ ["Data/train/4518.jpg"],
139
+ ["Data/train/6122.jpg"],
140
+ ["Data/train/6123.jpg"],
141
+ ["Data/train/6124.jpg"],
142
+ ["Data/train/6125.jpg"],
143
+ ["Data/train/7461.jpg"],
144
+ ["Data/train/7462.jpg"],
145
+ ["Data/train/7463.jpg"],
146
+ ["Data/train/7464.jpg"],
147
+ ["Data/train/7465.jpg"],
148
+ ["Data/train/8546.jpg"],
149
+ ["Data/train/8543.jpg"],
150
+ ["Data/train/9120.jpg"],
151
+ ["Data/train/10120.jpg"],
152
+ ]
153
+ iface = gr.Interface(
154
+ fn=predict,
155
+ inputs=inputs,
156
+ outputs=outputs,
157
+ title="AI image detector πŸ”Ž",
158
+ description="Check if an image is AI generated or real.",
159
+ examples=examples,
160
+ )
161
+
162
+ # launch the gradio app
163
+ iface.launch()
Code/test.ipynb ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Token is valid.\n",
13
+ "Your token has been saved in your configured git credential helpers (manager-core).\n",
14
+ "Your token has been saved to C:\\Users\\Oussama\\.cache\\huggingface\\token\n",
15
+ "Login successful\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "from huggingface_hub import login\n",
21
+ "login()"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 3,
27
+ "metadata": {},
28
+ "outputs": [
29
+ {
30
+ "data": {
31
+ "application/vnd.jupyter.widget-view+json": {
32
+ "model_id": "e5eb52f9282f43cfa4f06b1d9c6dc08b",
33
+ "version_major": 2,
34
+ "version_minor": 0
35
+ },
36
+ "text/plain": [
37
+ "Downloading readme: 0%| | 0.00/1.24k [00:00<?, ?B/s]"
38
+ ]
39
+ },
40
+ "metadata": {},
41
+ "output_type": "display_data"
42
+ },
43
+ {
44
+ "name": "stdout",
45
+ "output_type": "stream",
46
+ "text": [
47
+ "Downloading and preparing dataset None/None to C:/Users/Oussama/.cache/huggingface/datasets/competitions___parquet/competitions--aiornot-759454878caed5d9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n"
48
+ ]
49
+ },
50
+ {
51
+ "data": {
52
+ "application/vnd.jupyter.widget-view+json": {
53
+ "model_id": "37ae63ab36cc448e90aa292609f4e5d4",
54
+ "version_major": 2,
55
+ "version_minor": 0
56
+ },
57
+ "text/plain": [
58
+ "Downloading data files: 0%| | 0/2 [00:00<?, ?it/s]"
59
+ ]
60
+ },
61
+ "metadata": {},
62
+ "output_type": "display_data"
63
+ },
64
+ {
65
+ "data": {
66
+ "application/vnd.jupyter.widget-view+json": {
67
+ "model_id": "e208b22a102847b4862d336b32cce3e1",
68
+ "version_major": 2,
69
+ "version_minor": 0
70
+ },
71
+ "text/plain": [
72
+ "Downloading data: 0%| | 0.00/354M [00:00<?, ?B/s]"
73
+ ]
74
+ },
75
+ "metadata": {},
76
+ "output_type": "display_data"
77
+ },
78
+ {
79
+ "data": {
80
+ "application/vnd.jupyter.widget-view+json": {
81
+ "model_id": "c31d5357bf9c44e99b19820ea10ad70a",
82
+ "version_major": 2,
83
+ "version_minor": 0
84
+ },
85
+ "text/plain": [
86
+ "Downloading data: 0%| | 0.00/356M [00:00<?, ?B/s]"
87
+ ]
88
+ },
89
+ "metadata": {},
90
+ "output_type": "display_data"
91
+ },
92
+ {
93
+ "data": {
94
+ "application/vnd.jupyter.widget-view+json": {
95
+ "model_id": "b7de3daec8a54b28a07cd580734f4bf4",
96
+ "version_major": 2,
97
+ "version_minor": 0
98
+ },
99
+ "text/plain": [
100
+ "Downloading data: 0%| | 0.00/415M [00:00<?, ?B/s]"
101
+ ]
102
+ },
103
+ "metadata": {},
104
+ "output_type": "display_data"
105
+ },
106
+ {
107
+ "data": {
108
+ "application/vnd.jupyter.widget-view+json": {
109
+ "model_id": "67b76cdca655426d9174aa7655345ecb",
110
+ "version_major": 2,
111
+ "version_minor": 0
112
+ },
113
+ "text/plain": [
114
+ "Downloading data: 0%| | 0.00/418M [00:00<?, ?B/s]"
115
+ ]
116
+ },
117
+ "metadata": {},
118
+ "output_type": "display_data"
119
+ },
120
+ {
121
+ "data": {
122
+ "application/vnd.jupyter.widget-view+json": {
123
+ "model_id": "373dd6818e5c4500b9d48de516bcb427",
124
+ "version_major": 2,
125
+ "version_minor": 0
126
+ },
127
+ "text/plain": [
128
+ "Downloading data: 0%| | 0.00/416M [00:00<?, ?B/s]"
129
+ ]
130
+ },
131
+ "metadata": {},
132
+ "output_type": "display_data"
133
+ },
134
+ {
135
+ "data": {
136
+ "application/vnd.jupyter.widget-view+json": {
137
+ "model_id": "9fc356325be34db1844de564ee0395fc",
138
+ "version_major": 2,
139
+ "version_minor": 0
140
+ },
141
+ "text/plain": [
142
+ "Downloading data: 0%| | 0.00/416M [00:00<?, ?B/s]"
143
+ ]
144
+ },
145
+ "metadata": {},
146
+ "output_type": "display_data"
147
+ },
148
+ {
149
+ "data": {
150
+ "application/vnd.jupyter.widget-view+json": {
151
+ "model_id": "8d8d0259ea1840d38f316bfc3809fa44",
152
+ "version_major": 2,
153
+ "version_minor": 0
154
+ },
155
+ "text/plain": [
156
+ "Extracting data files: 0%| | 0/2 [00:00<?, ?it/s]"
157
+ ]
158
+ },
159
+ "metadata": {},
160
+ "output_type": "display_data"
161
+ },
162
+ {
163
+ "data": {
164
+ "application/vnd.jupyter.widget-view+json": {
165
+ "model_id": "1780ace548794a58b318e723ae297b50",
166
+ "version_major": 2,
167
+ "version_minor": 0
168
+ },
169
+ "text/plain": [
170
+ "Generating train split: 0 examples [00:00, ? examples/s]"
171
+ ]
172
+ },
173
+ "metadata": {},
174
+ "output_type": "display_data"
175
+ },
176
+ {
177
+ "data": {
178
+ "application/vnd.jupyter.widget-view+json": {
179
+ "model_id": "fbd137787e1340bfaef5b62f6fbc7556",
180
+ "version_major": 2,
181
+ "version_minor": 0
182
+ },
183
+ "text/plain": [
184
+ "Generating test split: 0 examples [00:00, ? examples/s]"
185
+ ]
186
+ },
187
+ "metadata": {},
188
+ "output_type": "display_data"
189
+ },
190
+ {
191
+ "name": "stdout",
192
+ "output_type": "stream",
193
+ "text": [
194
+ "Dataset parquet downloaded and prepared to C:/Users/Oussama/.cache/huggingface/datasets/competitions___parquet/competitions--aiornot-759454878caed5d9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.\n"
195
+ ]
196
+ },
197
+ {
198
+ "data": {
199
+ "application/vnd.jupyter.widget-view+json": {
200
+ "model_id": "2087104741e1465d9fc4b033d5cd8bdf",
201
+ "version_major": 2,
202
+ "version_minor": 0
203
+ },
204
+ "text/plain": [
205
+ " 0%| | 0/2 [00:00<?, ?it/s]"
206
+ ]
207
+ },
208
+ "metadata": {},
209
+ "output_type": "display_data"
210
+ }
211
+ ],
212
+ "source": [
213
+ "from datasets import load_dataset\n",
214
+ "\n",
215
+ "# If the dataset is gated/private, make sure you have run huggingface-cli login\n",
216
+ "dataset = load_dataset(\"competitions/aiornot\")"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": null,
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": []
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": []
232
+ }
233
+ ],
234
+ "metadata": {
235
+ "kernelspec": {
236
+ "display_name": "Python 3",
237
+ "language": "python",
238
+ "name": "python3"
239
+ },
240
+ "language_info": {
241
+ "codemirror_mode": {
242
+ "name": "ipython",
243
+ "version": 3
244
+ },
245
+ "file_extension": ".py",
246
+ "mimetype": "text/x-python",
247
+ "name": "python",
248
+ "nbconvert_exporter": "python",
249
+ "pygments_lexer": "ipython3",
250
+ "version": "3.10.11"
251
+ },
252
+ "orig_nbformat": 4
253
+ },
254
+ "nbformat": 4,
255
+ "nbformat_minor": 2
256
+ }