gnosticdev committed
Commit e90622e · verified · 1 Parent(s): 31a1a1a

Update app.py

Files changed (1)
  1. app.py +49 -9
app.py CHANGED
@@ -1,4 +1,5 @@
- from transformers import AutoModelForCausalLM, AutoTokenizer
  import gradio as gr
  import torch

@@ -6,23 +7,62 @@ import torch
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")

  # Initialize the conversation history
  chat_history_ids = None

  def chat_with_bot(user_input):
      global chat_history_ids
-     # Encode the user input
      new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
-
-     # Concatenate the user input with the conversation history
      bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if chat_history_ids is not None else new_user_input_ids
-
-     # Generate a response
      chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
-
-     # Decode and return the response
      return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

  # Create the Gradio interface
- iface = gr.Interface(fn=chat_with_bot, inputs="text", outputs="text", title="Chatbot con DialoGPT")
  iface.launch()
 
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
+ from datasets import load_dataset
  import gradio as gr
  import torch

  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")

+ # Load your dataset; alpaca.csv must provide the 'input_text' and
+ # 'response_text' columns read below
+ dataset = load_dataset('csv', data_files='alpaca.csv')
+
+ # Preprocess the data
+ def preprocess_function(examples):
+     inputs = [ex for ex in examples['input_text']]
+     outputs = [ex for ex in examples['response_text']]
+     model_inputs = tokenizer(inputs, max_length=512, truncation=True)
+
+     # Set up the labels
+     with tokenizer.as_target_tokenizer():
+         labels = tokenizer(outputs, max_length=512, truncation=True)
+
+     model_inputs["labels"] = labels["input_ids"]
+     return model_inputs
+
+ tokenized_dataset = dataset.map(preprocess_function, batched=True)
+
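+ # Note: these labels are tokenized separately, so they are not length-aligned
+ # with input_ids the way a causal LM's Trainer expects; see the sketch after
+ # the diff.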
+ # Set up the training arguments
+ training_args = TrainingArguments(
+     output_dir="./results",
+     evaluation_strategy="epoch",
+     learning_rate=2e-5,
+     per_device_train_batch_size=2,
+     num_train_epochs=3,
+ )
+
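+ # Caution: evaluation_strategy="epoch" makes Trainer expect an eval_dataset,
+ # which is never passed below.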
+ # Create the Trainer
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_dataset['train'],
+ )
+
+ # Train the model
+ trainer.train()
+
+ # Save the trained model
+ model.save_pretrained("./mi_modelo_entrenado")
+ tokenizer.save_pretrained("./mi_modelo_entrenado")
+
+ # Reload the trained model for inference
+ model = AutoModelForCausalLM.from_pretrained("./mi_modelo_entrenado")
+ tokenizer = AutoTokenizer.from_pretrained("./mi_modelo_entrenado")
+
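+ # From here on, chat_with_bot runs on the fine-tuned weights loaded above.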
  # Initialize the conversation history
  chat_history_ids = None

+ # Chat function
  def chat_with_bot(user_input):
      global chat_history_ids
      new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
      bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if chat_history_ids is not None else new_user_input_ids
      chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
      return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

  # Create the Gradio interface
+ iface = gr.Interface(fn=chat_with_bot, inputs="text", outputs="text", title="Chatbot Entrenado")
  iface.launch()
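
As committed, the training section is unlikely to run end to end: evaluation_strategy="epoch" asks Trainer for an eval_dataset that is never provided, DialoGPT's GPT-2 tokenizer has no pad token for batching, and a causal LM expects labels aligned token-for-token with input_ids, which the separately tokenized response_text labels are not. Below is a minimal sketch of one way to address all three. The CSV path and column names are taken from the diff; the EOS-joined prompt/response format and the collator choice are assumptions, not part of this commit.

from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers define no pad token

dataset = load_dataset('csv', data_files='alpaca.csv')  # same file as in the diff

def preprocess_function(examples):
    # Train a causal LM on a single sequence: prompt, EOS, response, EOS
    # (assumed format, not taken from the commit)
    texts = [inp + tokenizer.eos_token + out + tokenizer.eos_token
             for inp, out in zip(examples['input_text'], examples['response_text'])]
    return tokenizer(texts, max_length=512, truncation=True)

tokenized_dataset = dataset.map(preprocess_function, batched=True,
                                remove_columns=dataset['train'].column_names)

# mlm=False: the collator pads each batch and copies input_ids into labels,
# masking padded positions out of the loss
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    num_train_epochs=3,  # no evaluation_strategy: there is no eval_dataset
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    data_collator=data_collator,
)
trainer.train()

The resulting weights can be saved with save_pretrained and dropped into the chat_with_bot interface above unchanged.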