Woziii commited on
Commit
bd87014
·
verified ·
1 Parent(s): e1ef0ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -41
app.py CHANGED
@@ -62,7 +62,7 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
62
 
63
  try:
64
  with torch.no_grad():
65
- outputs = model(**inputs)
66
 
67
  last_token_logits = outputs.logits[0, -1, :]
68
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
@@ -76,13 +76,14 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
76
  # Créer un texte explicatif
77
  prob_text = "Prochains tokens les plus probables :\n\n"
78
  for word, prob in prob_data.items():
79
- prob_text += f"{word}: {prob:.2%}\n"
 
80
 
81
  # Créer les visualisations
82
  prob_plot = plot_probabilities(prob_data)
83
- importance_plot = plot_token_importance(inputs["input_ids"][0], last_token_logits)
84
 
85
- return prob_text, importance_plot, prob_plot
86
  except Exception as e:
87
  return f"Erreur lors de l'analyse : {str(e)}", None, None
88
 
@@ -113,52 +114,33 @@ def plot_probabilities(prob_data):
113
  words = list(prob_data.keys())
114
  probs = list(prob_data.values())
115
 
116
- fig, ax = plt.subplots(figsize=(10, 5))
117
- bars = ax.bar(words, probs, color='lightgreen')
118
  ax.set_title("Probabilités des tokens suivants les plus probables")
119
  ax.set_xlabel("Tokens")
120
  ax.set_ylabel("Probabilité")
121
- plt.xticks(rotation=45, ha='right')
122
 
123
- # Ajouter les valeurs sur les barres
124
- for bar in bars:
 
 
125
  height = bar.get_height()
126
- ax.text(bar.get_x() + bar.get_width()/2., height,
127
- f'{height:.2%}',
128
- ha='center', va='bottom')
129
 
130
  plt.tight_layout()
131
  return fig
132
 
133
- def plot_token_importance(input_ids, last_token_logits):
134
  input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
135
 
136
- # Calculer l'importance de chaque token
137
- importances = torch.abs(last_token_logits).sum() / len(last_token_logits)
138
- importances = importances.repeat(len(input_tokens))
139
-
140
- # Normaliser les importances
141
- importances = importances / importances.sum()
142
-
143
- # Créer la figure
144
- fig, ax = plt.subplots(figsize=(12, 3))
145
 
146
- # Créer un graphique à barres
147
- bars = ax.bar(range(len(input_tokens)), importances, color='skyblue')
148
-
149
- # Ajouter les labels et le titre
150
- ax.set_xticks(range(len(input_tokens)))
151
- ax.set_xticklabels(input_tokens, rotation=45, ha='right')
152
- ax.set_ylabel('Importance relative')
153
- ax.set_title('Importance des tokens d\'entrée pour la prédiction')
154
-
155
- # Ajouter les valeurs sur les barres
156
- for bar in bars:
157
- height = bar.get_height()
158
- ax.text(bar.get_x() + bar.get_width()/2., height,
159
- f'{height:.2%}',
160
- ha='center', va='bottom')
161
 
 
162
  plt.tight_layout()
163
  return fig
164
 
@@ -169,7 +151,7 @@ def reset():
169
  return "", 1.0, 1.0, 50, None, None, None, None
170
 
171
  with gr.Blocks() as demo:
172
- gr.Markdown("# LLM & Bias ")
173
 
174
  with gr.Accordion("Sélection du modèle"):
175
  model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
@@ -187,7 +169,7 @@ with gr.Blocks() as demo:
187
  next_token_probs = gr.Textbox(label="Probabilités du prochain token")
188
 
189
  with gr.Row():
190
- importance_plot = gr.Plot(label="Importance des tokens d'entrée")
191
  prob_plot = gr.Plot(label="Probabilités des tokens suivants")
192
 
193
  generate_button = gr.Button("Générer le prochain mot")
@@ -198,12 +180,12 @@ with gr.Blocks() as demo:
198
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
199
  analyze_button.click(analyze_next_token,
200
  inputs=[input_text, temperature, top_p, top_k],
201
- outputs=[next_token_probs, importance_plot, prob_plot])
202
  generate_button.click(generate_text,
203
  inputs=[input_text, temperature, top_p, top_k],
204
  outputs=[generated_text])
205
  reset_button.click(reset,
206
- outputs=[input_text, temperature, top_p, top_k, next_token_probs, importance_plot, prob_plot, generated_text])
207
 
208
  if __name__ == "__main__":
209
  demo.launch()
 
62
 
63
  try:
64
  with torch.no_grad():
65
+ outputs = model(**inputs, output_attentions=True)
66
 
67
  last_token_logits = outputs.logits[0, -1, :]
68
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
 
76
  # Créer un texte explicatif
77
  prob_text = "Prochains tokens les plus probables :\n\n"
78
  for word, prob in prob_data.items():
79
+ escaped_word = word.replace("<", "&lt;").replace(">", "&gt;")
80
+ prob_text += f"{escaped_word}: {prob:.2%}\n"
81
 
82
  # Créer les visualisations
83
  prob_plot = plot_probabilities(prob_data)
84
+ attention_plot = plot_attention(inputs["input_ids"][0], outputs.attentions)
85
 
86
+ return prob_text, attention_plot, prob_plot
87
  except Exception as e:
88
  return f"Erreur lors de l'analyse : {str(e)}", None, None
89
 
 
114
  words = list(prob_data.keys())
115
  probs = list(prob_data.values())
116
 
117
+ fig, ax = plt.subplots(figsize=(12, 6))
118
+ bars = ax.bar(range(len(words)), probs, color='lightgreen')
119
  ax.set_title("Probabilités des tokens suivants les plus probables")
120
  ax.set_xlabel("Tokens")
121
  ax.set_ylabel("Probabilité")
 
122
 
123
+ ax.set_xticks(range(len(words)))
124
+ ax.set_xticklabels(words, rotation=45, ha='right')
125
+
126
+ for i, (bar, word) in enumerate(zip(bars, words)):
127
  height = bar.get_height()
128
+ ax.text(i, height, f'{word}\n{height:.2%}',
129
+ ha='center', va='bottom', rotation=0)
 
130
 
131
  plt.tight_layout()
132
  return fig
133
 
134
+ def plot_attention(input_ids, attention_outputs):
135
  input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
136
 
137
+ # Prendre la moyenne des attentions sur toutes les couches et têtes
138
+ attention = torch.mean(torch.cat(attention_outputs), dim=(0, 1)).cpu().numpy()
 
 
 
 
 
 
 
139
 
140
+ fig, ax = plt.subplots(figsize=(12, 10))
141
+ sns.heatmap(attention, annot=True, cmap="YlOrRd", xticklabels=input_tokens, yticklabels=input_tokens, ax=ax)
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ ax.set_title("Carte d'attention moyenne")
144
  plt.tight_layout()
145
  return fig
146
 
 
151
  return "", 1.0, 1.0, 50, None, None, None, None
152
 
153
  with gr.Blocks() as demo:
154
+ gr.Markdown("# Analyse et génération de texte")
155
 
156
  with gr.Accordion("Sélection du modèle"):
157
  model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
 
169
  next_token_probs = gr.Textbox(label="Probabilités du prochain token")
170
 
171
  with gr.Row():
172
+ attention_plot = gr.Plot(label="Visualisation de l'attention")
173
  prob_plot = gr.Plot(label="Probabilités des tokens suivants")
174
 
175
  generate_button = gr.Button("Générer le prochain mot")
 
180
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
181
  analyze_button.click(analyze_next_token,
182
  inputs=[input_text, temperature, top_p, top_k],
183
+ outputs=[next_token_probs, attention_plot, prob_plot])
184
  generate_button.click(generate_text,
185
  inputs=[input_text, temperature, top_p, top_k],
186
  outputs=[generated_text])
187
  reset_button.click(reset,
188
+ outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])
189
 
190
  if __name__ == "__main__":
191
  demo.launch()