Woziii committed
Commit 55f3d52 · verified · 1 Parent(s): 45f8781

Update app.py

Files changed (1):
  1. app.py +61 -140

app.py CHANGED
@@ -11,37 +11,55 @@ import time
 # Authentification
 login(token=os.environ["HF_TOKEN"])
 
-# Liste des modèles et leurs langues supportées
+# Structure hiérarchique des modèles
+model_hierarchy = {
+    "meta-llama": {
+        "Llama-2": ["7B", "13B", "70B"],
+        "Llama-3": ["8B", "3.2B", "3.1B"]
+    },
+    "mistralai": {
+        "Mistral": ["7B-v0.1", "7B-v0.3"],
+        "Mixtral": ["8x7B-v0.1"]
+    },
+    "google": {
+        "Gemma": ["2B", "9B", "27B"]
+    },
+    "croissantllm": {
+        "CroissantLLM": ["Base"]
+    }
+}
+
+# Mise à jour de la liste des modèles et leurs langues supportées
 models_and_languages = {
-    "meta-llama/Llama-2-13b-hf": ["en"],
-    "meta-llama/Llama-2-7b-hf": ["en"],
-    "meta-llama/Llama-2-70b-hf": ["en"],
-    "meta-llama/Meta-Llama-3-8B": ["en"],
-    "meta-llama/Llama-3.2-3B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
-    "meta-llama/Llama-3.1-8B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
+    "meta-llama/Llama-2-7B": ["en"],
+    "meta-llama/Llama-2-13B": ["en"],
+    "meta-llama/Llama-2-70B": ["en"],
+    "meta-llama/Llama-3-8B": ["en"],
+    "meta-llama/Llama-3-3.2B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
+    "meta-llama/Llama-3-3.1B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
     "mistralai/Mistral-7B-v0.1": ["en"],
     "mistralai/Mixtral-8x7B-v0.1": ["en", "fr", "it", "de", "es"],
     "mistralai/Mistral-7B-v0.3": ["en"],
-    "google/gemma-2-2b": ["en"],
-    "google/gemma-2-9b": ["en"],
-    "google/gemma-2-27b": ["en"],
+    "google/Gemma-2B": ["en"],
+    "google/Gemma-9B": ["en"],
+    "google/Gemma-27B": ["en"],
     "croissantllm/CroissantLLMBase": ["en", "fr"]
 }
 
 # Paramètres recommandés pour chaque modèle
 model_parameters = {
-    "meta-llama/Llama-2-13b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-    "meta-llama/Llama-2-7b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-    "meta-llama/Llama-2-70b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-    "meta-llama/Meta-Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
-    "meta-llama/Llama-3.2-3B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
-    "meta-llama/Llama-3.1-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
+    "meta-llama/Llama-2-7B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+    "meta-llama/Llama-2-13B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+    "meta-llama/Llama-2-70B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+    "meta-llama/Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
+    "meta-llama/Llama-3-3.2B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
+    "meta-llama/Llama-3-3.1B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
     "mistralai/Mistral-7B-v0.1": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
     "mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
     "mistralai/Mistral-7B-v0.3": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
-    "google/gemma-2-2b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
-    "google/gemma-2-9b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
-    "google/gemma-2-27b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
+    "google/Gemma-2B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
+    "google/Gemma-9B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
+    "google/Gemma-27B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
    "croissantllm/CroissantLLMBase": {"temperature": 0.8, "top_p": 0.92, "top_k": 50}
 }
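
The three selection levels added later in this commit (société → modèle → variation) simply walk this nested dictionary. A minimal, self-contained sketch (not code from the commit):

# Walking model_hierarchy the way the cascading dropdowns will
model_hierarchy = {
    "meta-llama": {"Llama-2": ["7B", "13B", "70B"], "Llama-3": ["8B", "3.2B", "3.1B"]},
    "mistralai": {"Mistral": ["7B-v0.1", "7B-v0.3"], "Mixtral": ["8x7B-v0.1"]},
    "google": {"Gemma": ["2B", "9B", "27B"]},
    "croissantllm": {"CroissantLLM": ["Base"]},
}
print(list(model_hierarchy))                     # ['meta-llama', 'mistralai', 'google', 'croissantllm']
print(list(model_hierarchy["meta-llama"]))       # ['Llama-2', 'Llama-3']
print(model_hierarchy["meta-llama"]["Llama-2"])  # ['7B', '13B', '70B']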
 
@@ -50,24 +68,32 @@ model = None
 tokenizer = None
 selected_language = None
 
-def load_model(model_name, progress=gr.Progress()):
+def update_model_choices(company):
+    return gr.Dropdown(choices=list(model_hierarchy[company].keys()), value=None)
+
+def update_variation_choices(company, model_name):
+    return gr.Dropdown(choices=model_hierarchy[company][model_name], value=None)
+
+def load_model(company, model_name, variation, progress=gr.Progress()):
     global model, tokenizer
+    full_model_name = f"{company}/{model_name}-{variation}"
+
     try:
         progress(0, desc="Chargement du tokenizer")
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(full_model_name)
         progress(0.5, desc="Chargement du modèle")
 
         # Configurations spécifiques par modèle
-        if "mixtral" in model_name.lower():
+        if "mixtral" in full_model_name.lower():
             model = AutoModelForCausalLM.from_pretrained(
-                model_name,
+                full_model_name,
                 torch_dtype=torch.float16,
                 device_map="auto",
                 load_in_8bit=True
             )
         else:
             model = AutoModelForCausalLM.from_pretrained(
-                model_name,
+                full_model_name,
                 torch_dtype=torch.float16,
                 device_map="auto"
             )
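
As a quick check of the naming convention introduced here, the sketch below (not part of the commit; it assumes model_hierarchy and models_and_languages from this revision are in scope) composes every id the three dropdowns can produce and tests it against the flat dictionaries:

# Reproduce the f"{company}/{model_name}-{variation}" rule used by the new load_model
for company, families in model_hierarchy.items():
    for family, variations in families.items():
        for variation in variations:
            full_model_name = f"{company}/{family}-{variation}"
            print(full_model_name, full_model_name in models_and_languages)
# "mistralai/Mixtral-8x7B-v0.1" -> True, but "croissantllm/CroissantLLM-Base" -> False,
# because the flat dictionaries spell that key "croissantllm/CroissantLLMBase".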
@@ -76,12 +102,12 @@ def load_model(model_name, progress=gr.Progress()):
         tokenizer.pad_token = tokenizer.eos_token
 
         progress(1.0, desc="Modèle chargé")
-        available_languages = models_and_languages[model_name]
+        available_languages = models_and_languages[full_model_name]
 
         # Mise à jour des sliders avec les valeurs recommandées
-        params = model_parameters[model_name]
+        params = model_parameters[full_model_name]
         return (
-            f"Modèle {model_name} chargé avec succès. Langues disponibles : {', '.join(available_languages)}",
+            f"Modèle {full_model_name} chargé avec succès. Langues disponibles : {', '.join(available_languages)}",
             gr.Dropdown(choices=available_languages, value=available_languages[0], visible=True, interactive=True),
             params["temperature"],
             params["top_p"],
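
For one concrete selection, the lookups in this hunk resolve as follows (a sketch using the dictionaries from this revision; not part of the commit):

full_model_name = "mistralai/Mixtral-8x7B-v0.1"
available_languages = models_and_languages[full_model_name]  # ['en', 'fr', 'it', 'de', 'es'] -> choices du dropdown de langue
params = model_parameters[full_model_name]                    # {'temperature': 0.8, 'top_p': 0.95, 'top_k': 50} -> valeurs par défaut des sliders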
@@ -90,123 +116,16 @@ def load_model(model_name, progress=gr.Progress()):
     except Exception as e:
         return f"Erreur lors du chargement du modèle : {str(e)}", gr.Dropdown(visible=False), None, None, None
 
-def set_language(lang):
-    global selected_language
-    selected_language = lang
-    return f"Langue sélectionnée : {lang}"
-
-def ensure_token_display(token):
-    """Assure que le token est affiché correctement."""
-    if token.isdigit() or (token.startswith('-') and token[1:].isdigit()):
-        return tokenizer.decode([int(token)])
-    return token
-
-def analyze_next_token(input_text, temperature, top_p, top_k):
-    global model, tokenizer, selected_language
-
-    if model is None or tokenizer is None:
-        return "Veuillez d'abord charger un modèle.", None, None
-
-    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
-
-    try:
-        with torch.no_grad():
-            outputs = model(**inputs)
-
-        last_token_logits = outputs.logits[0, -1, :]
-        probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
-
-        top_k = 10
-        top_probs, top_indices = torch.topk(probabilities, top_k)
-        top_words = [ensure_token_display(tokenizer.decode([idx.item()])) for idx in top_indices]
-        prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
-
-        prob_text = "Prochains tokens les plus probables :\n\n"
-        for word, prob in prob_data.items():
-            prob_text += f"{word}: {prob:.2%}\n"
-
-        prob_plot = plot_probabilities(prob_data)
-        attention_plot = plot_attention(inputs["input_ids"][0].cpu(), last_token_logits.cpu())
-
-        return prob_text, attention_plot, prob_plot
-    except Exception as e:
-        return f"Erreur lors de l'analyse : {str(e)}", None, None
-
-def generate_text(input_text, temperature, top_p, top_k):
-    global model, tokenizer, selected_language
-
-    if model is None or tokenizer is None:
-        return "Veuillez d'abord charger un modèle."
-
-    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
-
-    try:
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=10,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k
-            )
-
-        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return generated_text
-    except Exception as e:
-        return f"Erreur lors de la génération : {str(e)}"
-
-def plot_probabilities(prob_data):
-    words = list(prob_data.keys())
-    probs = list(prob_data.values())
-
-    fig, ax = plt.subplots(figsize=(12, 6))
-    bars = ax.bar(range(len(words)), probs, color='lightgreen')
-    ax.set_title("Probabilités des tokens suivants les plus probables")
-    ax.set_xlabel("Tokens")
-    ax.set_ylabel("Probabilité")
-
-    ax.set_xticks(range(len(words)))
-    ax.set_xticklabels(words, rotation=45, ha='right')
-
-    for i, (bar, word) in enumerate(zip(bars, words)):
-        height = bar.get_height()
-        ax.text(i, height, f'{height:.2%}',
-                ha='center', va='bottom', rotation=0)
-
-    plt.tight_layout()
-    return fig
-
-def plot_attention(input_ids, last_token_logits):
-    input_tokens = [ensure_token_display(tokenizer.decode([id])) for id in input_ids]
-    attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
-    top_k = min(len(input_tokens), 10)
-    top_attention_scores, _ = torch.topk(attention_scores, top_k)
-
-    fig, ax = plt.subplots(figsize=(14, 7))
-    sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
-    ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
-    ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
-    ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
-
-    cbar = ax.collections[0].colorbar
-    cbar.set_label("Score d'attention", fontsize=12)
-    cbar.ax.tick_params(labelsize=10)
-
-    plt.tight_layout()
-    return fig
-
-def reset():
-    global model, tokenizer, selected_language
-    model = None
-    tokenizer = None
-    selected_language = None
-    return "", 1.0, 1.0, 50, None, None, None, None, gr.Dropdown(visible=False), ""
+# Le reste du code reste inchangé
+# ...
 
 with gr.Blocks() as demo:
     gr.Markdown("# LLM&BIAS")
 
     with gr.Accordion("Sélection du modèle"):
-        model_dropdown = gr.Dropdown(choices=list(models_and_languages.keys()), label="Choisissez un modèle")
+        company_dropdown = gr.Dropdown(choices=list(model_hierarchy.keys()), label="Choisissez une société")
+        model_dropdown = gr.Dropdown(label="Choisissez un modèle", choices=[])
+        variation_dropdown = gr.Dropdown(label="Choisissez une variation", choices=[])
         load_button = gr.Button("Charger le modèle")
         load_output = gr.Textbox(label="Statut du chargement")
         language_dropdown = gr.Dropdown(label="Choisissez une langue", visible=False)
@@ -231,8 +150,10 @@ with gr.Blocks() as demo:
 
     reset_button = gr.Button("Réinitialiser")
 
+    company_dropdown.change(update_model_choices, inputs=[company_dropdown], outputs=[model_dropdown])
+    model_dropdown.change(update_variation_choices, inputs=[company_dropdown, model_dropdown], outputs=[variation_dropdown])
     load_button.click(load_model,
-                      inputs=[model_dropdown],
+                      inputs=[company_dropdown, model_dropdown, variation_dropdown],
                       outputs=[load_output, language_dropdown, temperature, top_p, top_k])
     language_dropdown.change(set_language, inputs=[language_dropdown], outputs=[language_output])
     analyze_button.click(analyze_next_token,
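
One closing note on the 8-bit path taken for Mixtral in the load_model hunk above: passing load_in_8bit=True directly to from_pretrained is the older keyword form; recent transformers releases usually express the same intent through a BitsAndBytesConfig. A hedged sketch (an assumption about current library versions, not part of this commit; it additionally requires the bitsandbytes package):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 8-bit quantized load, equivalent in intent to the load_in_8bit=True branch above
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",  # the checkpoint the mixtral branch targets
    quantization_config=quant_config,
    device_map="auto",
)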