PierreH commited on
Commit
45188a9
·
verified ·
1 Parent(s): d3d3d13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -1
app.py CHANGED
@@ -1,3 +1,50 @@
1
  import gradio as gr
 
 
2
 
3
- gr.load("models/EleutherAI/gpt-neo-1.3B").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import GPTNeoForCausalLM, GPT2Tokenizer, Trainer, TrainingArguments
3
+ from datasets import load_dataset
4
 
5
+ # Charger le modèle GPT-Neo et le tokenizer
6
+ model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
7
+ tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
8
+
9
+ # Charger les données (remplacer par le chemin vers vos propres données)
10
+ dataset = load_dataset("json", data_files={"train": "data.jsonl"})
11
+
12
+ # Tokeniser les données
13
+ def tokenize_function(examples):
14
+ return tokenizer(examples["prompt"], padding="max_length", truncation=True)
15
+
16
+ # Tokenisation des données
17
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
18
+
19
+ # Arguments d'entraînement
20
+ training_args = TrainingArguments(
21
+ output_dir="./results",
22
+ num_train_epochs=3, # Nombre d'époques
23
+ per_device_train_batch_size=4, # Taille du batch, ajustez selon vos ressources
24
+ save_steps=10_000, # Sauvegarder tous les 10 000 steps
25
+ save_total_limit=2, # Conserver seulement 2 checkpoints
26
+ )
27
+
28
+ # Initialiser le Trainer pour fine-tuner GPT-Neo
29
+ trainer = Trainer(
30
+ model=model,
31
+ args=training_args,
32
+ train_dataset=tokenized_datasets["train"],
33
+ )
34
+
35
+ # Lancer le fine-tuning
36
+ trainer.train()
37
+
38
+ # Sauvegarder le modèle fine-tuné
39
+ model.save_pretrained("./fine_tuned_gpt_neo")
40
+ tokenizer.save_pretrained("./fine_tuned_gpt_neo")
41
+
42
+ # Interface Gradio pour tester le modèle fine-tuné
43
+ def generate_text(prompt):
44
+ inputs = tokenizer(prompt, return_tensors="pt")
45
+ outputs = model.generate(inputs["input_ids"], max_length=150)
46
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
47
+
48
+ # Créer une interface avec Gradio pour interagir avec le modèle
49
+ interface = gr.Interface(fn=generate_text, inputs="text", outputs="text")
50
+ interface.launch()