sophiamyang committed
Commit 3129503 · 1 Parent(s): 500d38a
Files changed (3)
  1. Dockerfile +1 -1
  2. app.ipynb +0 -173
  3. app.py +122 -0
Dockerfile CHANGED
@@ -8,7 +8,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade -r /code/requirements.txt

  COPY . .

- CMD ["panel", "serve", "/code/app.ipynb", "--address", "0.0.0.0", "--port", "7860", "--allow-websocket-origin", "*"]
+ CMD ["panel", "serve", "/code/app.py", "--address", "0.0.0.0", "--port", "7860", "--allow-websocket-origin", "*"]

  RUN mkdir /.cache
  RUN chmod 777 /.cache
app.ipynb DELETED
@@ -1,173 +0,0 @@
- {
-  "cells": [
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "8cd1e865-53d5-460b-8bae-5658e3aa3d16",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "import panel as pn\n",
-     "pn.extension()\n",
-     "import requests\n",
-     "import random\n",
-     "import PIL\n",
-     "from PIL import Image\n",
-     "import io\n",
-     "from transformers import CLIPProcessor, CLIPModel\n",
-     "import numpy as np"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "e8570053-0b83-421b-95c2-695b6c709ba1",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "pn.extension('texteditor', template=\"bootstrap\", sizing_mode='stretch_width')\n",
-     "\n",
-     "pn.state.template.param.update(\n",
-     "    main_max_width=\"690px\",\n",
-     "    header_background=\"#F08080\",\n",
-     ")"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "ca65cc07-8181-4259-8770-9c780621eb78",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "# File input widget\n",
-     "file_input = pn.widgets.FileInput()\n",
-     "\n",
-     "# Button widget\n",
-     "compute_button = pn.widgets.Button(name=\"Compute\")\n",
-     "\n",
-     "# Text input widget\n",
-     "text_input = pn.widgets.TextInput(name='Possible class names (e.g., cat, dog)', placeholder='cat, dog')"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "f3691594-df8c-4d03-99e8-db4d3b2520c0",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "def normalize_image(value, width=600):\n",
-     "    \"\"\"\n",
-     "    normalize image to RBG channels and to the same size\n",
-     "    \"\"\"\n",
-     "    if value: \n",
-     "        b = io.BytesIO(value)\n",
-     "        image = PIL.Image.open(b).convert(\"RGB\")\n",
-     "    else: \n",
-     "        url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
-     "        image = Image.open(requests.get(url, stream=True).raw)\n",
-     "    aspect = image.size[1] / image.size[0]\n",
-     "    height = int(aspect * width)\n",
-     "    return image.resize((width, height), PIL.Image.LANCZOS)"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "5b139802-c9d6-4493-acb2-5051343c1ecc",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "def image_classification(image):\n",
-     "    model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
-     "    processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
-     "    possible_categories = text_input.value.split(\",\")\n",
-     "    if text_input.value == '':\n",
-     "        possible_categories = ['cat', ' dog']\n",
-     "    inputs = processor(text=possible_categories, images=image, return_tensors=\"pt\", padding=True)\n",
-     "    \n",
-     "    outputs = model(**inputs)\n",
-     "    logits_per_image = outputs.logits_per_image # this is the image-text similarity score\n",
-     "    probs = logits_per_image.softmax(dim=1)\n",
-     "    return probs.detach().numpy()"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "6b6f0ce5-03a5-4a14-b0b7-74c8190ce928",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "def get_result(_):\n",
-     "    image = normalize_image(file_input.value)\n",
-     "\n",
-     "    result = image_classification(image)\n",
-     "    \n",
-     "    possible_categories = text_input.value.split(\",\")\n",
-     "    if text_input.value == '':\n",
-     "        possible_categories = ['cat', ' dog']\n",
-     "\n",
-     "    progress_bars = pn.Column(*[\n",
-     "        pn.Row(\n",
-     "            possible_categories[i], \n",
-     "            pn.indicators.Progress(name='', value=int(j*100), width=500))\n",
-     "        for i, j in enumerate(result[0])\n",
-     "    ])\n",
-     "    return progress_bars\n",
-     "    "
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "6fd5a63f-012a-419c-8386-22b5b8ff243f",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "# Bind the get_image function with the button widget\n",
-     "interactive_result = pn.bind(get_result, compute_button)"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "399189f1-4ff6-4f4b-b050-76e9a46443dd",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "# layout\n",
-     "pn.Column(\n",
-     "    \"## \\U0001F60A Upload an image file and start classifying!\",\n",
-     "    file_input,\n",
-     "    pn.bind(pn.panel, file_input),\n",
-     "    text_input, \n",
-     "    compute_button,\n",
-     "    interactive_result\n",
-     ").servable(title=\"Panel Image Classification Demo\")"
-    ]
-   }
-  ],
-  "metadata": {
-   "kernelspec": {
-    "display_name": "Python 3 (ipykernel)",
-    "language": "python",
-    "name": "python3"
-   },
-   "language_info": {
-    "codemirror_mode": {
-     "name": "ipython",
-     "version": 3
-    },
-    "file_extension": ".py",
-    "mimetype": "text/x-python",
-    "name": "python",
-    "nbconvert_exporter": "python",
-    "pygments_lexer": "ipython3",
-    "version": "3.10.11"
-   }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 5
- }
app.py ADDED
@@ -0,0 +1,122 @@
+ import random
+ import panel as pn
+ import requests
+ from PIL import Image
+
+ from transformers import CLIPProcessor, CLIPModel
+ from typing import List, Tuple
+
+ pn.extension("texteditor", sizing_mode="stretch_width")
+
+
+ def set_random_url(_):
+     if random.randint(0, 1) == 0:
+         api_url = "https://api.thecatapi.com/v1/images/search"
+     else:
+         api_url = "https://api.thedogapi.com/v1/images/search"
+     with requests.get(api_url) as resp:
+         resp.raise_for_status()
+         url = resp.json()[0]["url"]
+     image_url.value = url
+
+
+ @pn.cache
+ def load_processor_model(
+     processor_name: str, model_name: str
+ ) -> Tuple[CLIPProcessor, CLIPModel]:
+     processor = CLIPProcessor.from_pretrained(processor_name)
+     model = CLIPModel.from_pretrained(model_name)
+     return processor, model
+
+
+ @pn.cache
+ def open_image_url(image_url: str) -> Image:
+     with requests.get(image_url, stream=True) as resp:
+         resp.raise_for_status()
+         image = Image.open(resp.raw)
+     return image
+
+
+ def get_similarity_scores(class_items: List[str], image: Image) -> List[float]:
+     processor, model = load_processor_model(
+         "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
+     )
+     inputs = processor(
+         text=class_items,
+         images=[image],
+         return_tensors="pt",  # pytorch tensors
+     )
+     outputs = model(**inputs)
+     logits_per_image = outputs.logits_per_image
+     class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
+     return class_likelihoods[0]
+
+
+ def process_inputs(class_names: List[str], image_url: str):
+     """
+     High level function that takes in the user inputs and returns the
+     classification results as panel objects.
+     """
+     image = open_image_url(image_url)
+     class_items = class_names.split(",")
+     class_likelihoods = get_similarity_scores(class_items, image)
+
+     # build the results column
+     results_column = pn.Column("## 🎉 Here are the results!")
+
+     results_column.append(
+         pn.pane.Image(image, max_width=698, sizing_mode="scale_width")
+     )
+
+     for class_item, class_likelihood in zip(class_items, class_likelihoods):
+         row_label = pn.widgets.StaticText(
+             name=class_item.strip(), value=f"{class_likelihood:.2%}", margin=(0, 10)
+         )
+         row_bar = pn.indicators.Progress(
+             max=100,
+             value=int(class_likelihood * 100),
+             sizing_mode="stretch_width",
+             bar_color="secondary",
+             margin=(0, 10),
+         )
+         row_column = pn.Column(row_label, row_bar)
+         results_column.append(row_column)
+     return results_column
+
+ # create widgets
+ randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
+
+ image_url = pn.widgets.TextInput(
+     name="Image URL to classify",
+     value="https://cdn2.thecatapi.com/images/cct.jpg",
+ )
+ class_names = pn.widgets.TextInput(
+     name="Comma separated class names",
+     placeholder="Enter possible class names, e.g. cat, dog",
+     value="cat, dog, parrot",
+ )
+
+ input_widgets = pn.Column(
+     "## 😊 Click randomize or paste a URL to start classifying!",
+     pn.Row(image_url, randomize_url),
+     class_names,
+ )
+
+ # add interactivity
+ randomize_url.on_click(set_random_url)
+ interactive_result = pn.bind(
+     process_inputs, image_url=image_url, class_names=class_names
+ )
+
+ # create dashboard
+ main = pn.WidgetBox(
+     input_widgets,
+     interactive_result,
+ )
+
+ pn.template.BootstrapTemplate(
+     title="Panel Image Classification Demo",
+     main=main,
+     main_max_width="min(50%, 698px)",
+     header_background="#F08080",
+ ).servable(title="Panel Image Classification Demo")
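
For context on what the new app.py computes, below is a minimal standalone sketch of the CLIP zero-shot scoring that get_similarity_scores wraps. It assumes the packages from requirements.txt (transformers, Pillow, requests) are installed; the image URL and class names are just the widget defaults above, used for illustration.

import requests
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Same checkpoint that load_processor_model() caches inside the app.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Default widget values from app.py, used here purely as example inputs.
image = Image.open(
    requests.get("https://cdn2.thecatapi.com/images/cct.jpg", stream=True).raw
)
class_items = ["cat", "dog", "parrot"]

# Score the image against each class name and normalize to probabilities.
inputs = processor(text=class_items, images=[image], return_tensors="pt", padding=True)
probs = model(**inputs).logits_per_image.softmax(dim=1).detach().numpy()[0]
for name, prob in zip(class_items, probs):
    print(f"{name}: {prob:.2%}")

In the app itself, the expensive steps (load_processor_model and open_image_url) are wrapped in @pn.cache, so model weights and each fetched image are loaded once per server process rather than on every widget change.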