Steph974 commited on
Commit
35abf20
·
verified ·
1 Parent(s): bca559d

Upload gradio - Copie.ipynb

Browse files
Files changed (1) hide show
  1. gradio - Copie.ipynb +226 -0
gradio - Copie.ipynb ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "<center>\n",
8
+ "\n",
9
+ "## [S. Mussard](https://sites.google.com/view/cv-stphane-mussard/accueil \"Homepage\")\n",
10
+ "\n",
11
+ "# UM6P\n",
12
+ "\n",
13
+ "# Natural Language Processing: LOGIT\n",
14
+ "\n",
15
+ "\n",
16
+ "<center> <a href=\"https://www.fgses-um6p.ma/\"><img src=\"UM6P.png\",style=\"float: left; max-width: 500px; width: 20\" />\n",
17
+ "\n",
18
+ "\n",
19
+ "\n",
20
+ "<div align=\"center\"> \n",
21
+ "<a href=\"https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html\"><img src=\"http://scikit-learn.org/stable/_static/scikit-learn-logo-small.png\" style=\"max-width: 180px; display: inline\" alt=\"Scikit-Learn\"/></a>\n",
22
+ "</div>\n",
23
+ "<div align=\"center\"> <a href=\"https://www.python.org/\"><img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/f/f8/Python_logo_and_wordmark.svg/390px-Python_logo_and_wordmark.svg.png\" style=\"max-width: 150px; display: inline\" alt=\"Python\"/></a> \n",
24
+ "</div>\n",
25
+ " \n"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {},
31
+ "source": [
32
+ "<div align=\"center\">\n",
33
+ "\n",
34
+ "## Sentiment Analysis"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 1,
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "# Importation \n",
44
+ "\n",
45
+ "%matplotlib inline \n",
46
+ "import numpy as np\n",
47
+ "import pandas as pd\n",
48
+ "import matplotlib.pyplot as plt\n",
49
+ "from sklearn import metrics\n",
50
+ "import torch\n",
51
+ "from torch.utils.data import Dataset, DataLoader\n",
52
+ "from transformers import AutoModel, AutoTokenizer\n",
53
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
54
+ "\n",
55
+ "import gradio as gr\n",
56
+ "from gradio.components import Label"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 5,
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "name": "stderr",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "Some weights of the model checkpoint at ./poids were not used when initializing RobertaModel: ['classifier.out_proj.weight', 'classifier.dense.bias', 'classifier.out_proj.bias', 'classifier.dense.weight']\n",
69
+ "- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
70
+ "- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
71
+ "Some weights of RobertaModel were not initialized from the model checkpoint at ./poids and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']\n",
72
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
73
+ ]
74
+ }
75
+ ],
76
+ "source": [
77
+ "path = \"./weights\"\n",
78
+ "model = AutoModel.from_pretrained(path, trust_remote_code=True)\n",
79
+ "class CamembertClass(torch.nn.Module):\n",
80
+ " def __init__(self):\n",
81
+ " super(CamembertClass, self).__init__()\n",
82
+ " self.l1 = model\n",
83
+ " self.dropout = torch.nn.Dropout(0.1)\n",
84
+ " self.pre_classifier = torch.nn.Linear(1024, 1024)\n",
85
+ " self.classifier = torch.nn.Linear(1024, 3)\n",
86
+ "\n",
87
+ " def forward(self, input_ids, attention_mask, token_type_ids):\n",
88
+ " output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n",
89
+ " hidden_state = output_1[0]\n",
90
+ " pooler = hidden_state[:, 0]\n",
91
+ " pooler = self.pre_classifier(pooler)\n",
92
+ " pooler = torch.nn.ReLU()(pooler)\n",
93
+ " pooler = self.dropout(pooler)\n",
94
+ " output = self.classifier(pooler)\n",
95
+ " return output"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 6,
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "#model_gradio = CamembertClass()\n",
105
+ "path = \"./pytorch_model.bin\"\n",
106
+ "model = torch.load(path, map_location=\"cpu\")\n",
107
+ "path_tokenizer = \"./\"\n",
108
+ "tokenizer = AutoTokenizer.from_pretrained(path_tokenizer)\n"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 4,
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "#pip install pydantic==1.10.7"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": 7,
123
+ "metadata": {},
124
+ "outputs": [
125
+ {
126
+ "name": "stdout",
127
+ "output_type": "stream",
128
+ "text": [
129
+ "Running on local URL: http://127.0.0.1:7860\n",
130
+ "Running on public URL: https://93ecddda8853b625c0.gradio.live\n",
131
+ "\n",
132
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
133
+ ]
134
+ },
135
+ {
136
+ "data": {
137
+ "text/html": [
138
+ "<div><iframe src=\"https://93ecddda8853b625c0.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
139
+ ],
140
+ "text/plain": [
141
+ "<IPython.core.display.HTML object>"
142
+ ]
143
+ },
144
+ "metadata": {},
145
+ "output_type": "display_data"
146
+ },
147
+ {
148
+ "data": {
149
+ "text/plain": []
150
+ },
151
+ "execution_count": 7,
152
+ "metadata": {},
153
+ "output_type": "execute_result"
154
+ }
155
+ ],
156
+ "source": [
157
+ "model.eval() # Mettez votre modèle en mode évaluation\n",
158
+ "\n",
159
+ "# Fonction d'inférence pour Gradio\n",
160
+ "def predict(text):\n",
161
+ " inputs = tokenizer(text, return_tensors=\"pt\", padding=True, truncation=True, max_length=512)\n",
162
+ " \n",
163
+ " # Extract necessary inputs for the model\n",
164
+ " input_ids = inputs['input_ids']\n",
165
+ " attention_mask = inputs['attention_mask']\n",
166
+ " token_type_ids = inputs.get('token_type_ids', None) # Some models do not use segment IDs\n",
167
+ " \n",
168
+ " # Make prediction\n",
169
+ " with torch.no_grad():\n",
170
+ " # Directly use outputs if your model returns logits directly\n",
171
+ " logits = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n",
172
+ "\n",
173
+ " \n",
174
+ " # Convert logits to probabilities\n",
175
+ " probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]\n",
176
+ " # Replace the following with your actual classes\n",
177
+ " classes = ['Negative Sentiment', 'Positive Sentiment']\n",
178
+ " return {classes[i]: float(probabilities[i]) for i in range(len(classes))}\n",
179
+ "\n",
180
+ "# Création de l'interface Gradio\n",
181
+ "iface = gr.Interface(fn=predict,\n",
182
+ " inputs=gr.components.Textbox(placeholder=\"Enter your text here...\"),\n",
183
+ " outputs=gr.components.Label(num_top_classes=2))\n",
184
+ "iface.launch(share=True)\n"
185
+ ]
186
+ }
187
+ ],
188
+ "metadata": {
189
+ "hide_input": false,
190
+ "kernelspec": {
191
+ "display_name": "Python 3",
192
+ "language": "python",
193
+ "name": "python3"
194
+ },
195
+ "language_info": {
196
+ "codemirror_mode": {
197
+ "name": "ipython",
198
+ "version": 3
199
+ },
200
+ "file_extension": ".py",
201
+ "mimetype": "text/x-python",
202
+ "name": "python",
203
+ "nbconvert_exporter": "python",
204
+ "pygments_lexer": "ipython3",
205
+ "version": "3.7.8"
206
+ },
207
+ "toc": {
208
+ "base_numbering": 1,
209
+ "nav_menu": {
210
+ "height": "244px",
211
+ "width": "252px"
212
+ },
213
+ "number_sections": true,
214
+ "sideBar": true,
215
+ "skip_h1_title": false,
216
+ "title_cell": "Table of Contents",
217
+ "title_sidebar": "Contents",
218
+ "toc_cell": false,
219
+ "toc_position": {},
220
+ "toc_section_display": "block",
221
+ "toc_window_display": false
222
+ }
223
+ },
224
+ "nbformat": 4,
225
+ "nbformat_minor": 1
226
+ }