Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -62,7 +62,7 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
|
|
62 |
|
63 |
try:
|
64 |
with torch.no_grad():
|
65 |
-
outputs = model(**inputs)
|
66 |
|
67 |
last_token_logits = outputs.logits[0, -1, :]
|
68 |
probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
|
@@ -76,13 +76,14 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
|
|
76 |
# Créer un texte explicatif
|
77 |
prob_text = "Prochains tokens les plus probables :\n\n"
|
78 |
for word, prob in prob_data.items():
|
79 |
-
|
|
|
80 |
|
81 |
# Créer les visualisations
|
82 |
prob_plot = plot_probabilities(prob_data)
|
83 |
-
|
84 |
|
85 |
-
return prob_text,
|
86 |
except Exception as e:
|
87 |
return f"Erreur lors de l'analyse : {str(e)}", None, None
|
88 |
|
@@ -113,52 +114,33 @@ def plot_probabilities(prob_data):
|
|
113 |
words = list(prob_data.keys())
|
114 |
probs = list(prob_data.values())
|
115 |
|
116 |
-
fig, ax = plt.subplots(figsize=(
|
117 |
-
bars = ax.bar(words, probs, color='lightgreen')
|
118 |
ax.set_title("Probabilités des tokens suivants les plus probables")
|
119 |
ax.set_xlabel("Tokens")
|
120 |
ax.set_ylabel("Probabilité")
|
121 |
-
plt.xticks(rotation=45, ha='right')
|
122 |
|
123 |
-
|
124 |
-
|
|
|
|
|
125 |
height = bar.get_height()
|
126 |
-
ax.text(
|
127 |
-
|
128 |
-
ha='center', va='bottom')
|
129 |
|
130 |
plt.tight_layout()
|
131 |
return fig
|
132 |
|
133 |
-
def
|
134 |
input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
135 |
|
136 |
-
#
|
137 |
-
|
138 |
-
importances = importances.repeat(len(input_tokens))
|
139 |
-
|
140 |
-
# Normaliser les importances
|
141 |
-
importances = importances / importances.sum()
|
142 |
-
|
143 |
-
# Créer la figure
|
144 |
-
fig, ax = plt.subplots(figsize=(12, 3))
|
145 |
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
# Ajouter les labels et le titre
|
150 |
-
ax.set_xticks(range(len(input_tokens)))
|
151 |
-
ax.set_xticklabels(input_tokens, rotation=45, ha='right')
|
152 |
-
ax.set_ylabel('Importance relative')
|
153 |
-
ax.set_title('Importance des tokens d\'entrée pour la prédiction')
|
154 |
-
|
155 |
-
# Ajouter les valeurs sur les barres
|
156 |
-
for bar in bars:
|
157 |
-
height = bar.get_height()
|
158 |
-
ax.text(bar.get_x() + bar.get_width()/2., height,
|
159 |
-
f'{height:.2%}',
|
160 |
-
ha='center', va='bottom')
|
161 |
|
|
|
162 |
plt.tight_layout()
|
163 |
return fig
|
164 |
|
@@ -169,7 +151,7 @@ def reset():
|
|
169 |
return "", 1.0, 1.0, 50, None, None, None, None
|
170 |
|
171 |
with gr.Blocks() as demo:
|
172 |
-
gr.Markdown("#
|
173 |
|
174 |
with gr.Accordion("Sélection du modèle"):
|
175 |
model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
|
@@ -187,7 +169,7 @@ with gr.Blocks() as demo:
|
|
187 |
next_token_probs = gr.Textbox(label="Probabilités du prochain token")
|
188 |
|
189 |
with gr.Row():
|
190 |
-
|
191 |
prob_plot = gr.Plot(label="Probabilités des tokens suivants")
|
192 |
|
193 |
generate_button = gr.Button("Générer le prochain mot")
|
@@ -198,12 +180,12 @@ with gr.Blocks() as demo:
|
|
198 |
load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
|
199 |
analyze_button.click(analyze_next_token,
|
200 |
inputs=[input_text, temperature, top_p, top_k],
|
201 |
-
outputs=[next_token_probs,
|
202 |
generate_button.click(generate_text,
|
203 |
inputs=[input_text, temperature, top_p, top_k],
|
204 |
outputs=[generated_text])
|
205 |
reset_button.click(reset,
|
206 |
-
outputs=[input_text, temperature, top_p, top_k, next_token_probs,
|
207 |
|
208 |
if __name__ == "__main__":
|
209 |
demo.launch()
|
|
|
62 |
|
63 |
try:
|
64 |
with torch.no_grad():
|
65 |
+
outputs = model(**inputs, output_attentions=True)
|
66 |
|
67 |
last_token_logits = outputs.logits[0, -1, :]
|
68 |
probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
|
|
|
76 |
# Créer un texte explicatif
|
77 |
prob_text = "Prochains tokens les plus probables :\n\n"
|
78 |
for word, prob in prob_data.items():
|
79 |
+
escaped_word = word.replace("<", "<").replace(">", ">")
|
80 |
+
prob_text += f"{escaped_word}: {prob:.2%}\n"
|
81 |
|
82 |
# Créer les visualisations
|
83 |
prob_plot = plot_probabilities(prob_data)
|
84 |
+
attention_plot = plot_attention(inputs["input_ids"][0], outputs.attentions)
|
85 |
|
86 |
+
return prob_text, attention_plot, prob_plot
|
87 |
except Exception as e:
|
88 |
return f"Erreur lors de l'analyse : {str(e)}", None, None
|
89 |
|
|
|
114 |
words = list(prob_data.keys())
|
115 |
probs = list(prob_data.values())
|
116 |
|
117 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
118 |
+
bars = ax.bar(range(len(words)), probs, color='lightgreen')
|
119 |
ax.set_title("Probabilités des tokens suivants les plus probables")
|
120 |
ax.set_xlabel("Tokens")
|
121 |
ax.set_ylabel("Probabilité")
|
|
|
122 |
|
123 |
+
ax.set_xticks(range(len(words)))
|
124 |
+
ax.set_xticklabels(words, rotation=45, ha='right')
|
125 |
+
|
126 |
+
for i, (bar, word) in enumerate(zip(bars, words)):
|
127 |
height = bar.get_height()
|
128 |
+
ax.text(i, height, f'{word}\n{height:.2%}',
|
129 |
+
ha='center', va='bottom', rotation=0)
|
|
|
130 |
|
131 |
plt.tight_layout()
|
132 |
return fig
|
133 |
|
134 |
+
def plot_attention(input_ids, attention_outputs):
|
135 |
input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
136 |
|
137 |
+
# Prendre la moyenne des attentions sur toutes les couches et têtes
|
138 |
+
attention = torch.mean(torch.cat(attention_outputs), dim=(0, 1)).cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
+
fig, ax = plt.subplots(figsize=(12, 10))
|
141 |
+
sns.heatmap(attention, annot=True, cmap="YlOrRd", xticklabels=input_tokens, yticklabels=input_tokens, ax=ax)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
+
ax.set_title("Carte d'attention moyenne")
|
144 |
plt.tight_layout()
|
145 |
return fig
|
146 |
|
|
|
151 |
return "", 1.0, 1.0, 50, None, None, None, None
|
152 |
|
153 |
with gr.Blocks() as demo:
|
154 |
+
gr.Markdown("# Analyse et génération de texte")
|
155 |
|
156 |
with gr.Accordion("Sélection du modèle"):
|
157 |
model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
|
|
|
169 |
next_token_probs = gr.Textbox(label="Probabilités du prochain token")
|
170 |
|
171 |
with gr.Row():
|
172 |
+
attention_plot = gr.Plot(label="Visualisation de l'attention")
|
173 |
prob_plot = gr.Plot(label="Probabilités des tokens suivants")
|
174 |
|
175 |
generate_button = gr.Button("Générer le prochain mot")
|
|
|
180 |
load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
|
181 |
analyze_button.click(analyze_next_token,
|
182 |
inputs=[input_text, temperature, top_p, top_k],
|
183 |
+
outputs=[next_token_probs, attention_plot, prob_plot])
|
184 |
generate_button.click(generate_text,
|
185 |
inputs=[input_text, temperature, top_p, top_k],
|
186 |
outputs=[generated_text])
|
187 |
reset_button.click(reset,
|
188 |
+
outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])
|
189 |
|
190 |
if __name__ == "__main__":
|
191 |
demo.launch()
|