seronk committed
Commit 2e75c99 · verified · 1 Parent(s): 2812524

training distillbert on data (#3)


- training distillbert on data (a1e89098c1a4457b75004ffc0a21c98b10b929a5)

Files changed (1)
  1. tasks/text.py +31 -5
tasks/text.py CHANGED
@@ -8,7 +8,7 @@ from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 #additional imports
-from transformers import pipeline
+from transformers import Trainer, TrainingArguments, DistilBertForSequenceClassification, DistilBertTokenizerFast
 import logging
 
 router = APIRouter()
@@ -63,11 +63,37 @@ async def evaluate_text(request: TextEvaluationRequest):
     # Make random predictions (placeholder for actual model inference)
     true_labels = test_dataset["label"]
 
-    available_pipeline = pipeline(tasks="text_classfication")
-    print(available_pipeline)
-    logging.log(INFO, available_pipeline)
+    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
+    # Tokenize the datasets
+    def tokenize_function(examples):
+        return tokenizer(examples["quote"], padding="max_length", truncation=True)
+
+    train_dataset = dataset["train"].map(tokenize_function, batched=True)
+    test_dataset = dataset["test"].map(tokenize_function, batched=True)
+
+    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)  # Set num_labels for your classification task
+
+    training_args = TrainingArguments(
+        output_dir="./results",
+        eval_strategy="epoch",            # Evaluation strategy (can be "steps" or "epoch")
+        per_device_train_batch_size=16,   # Batch size for training
+        per_device_eval_batch_size=64,    # Batch size for evaluation
+        num_train_epochs=3,               # Number of training epochs
+        logging_dir="./logs",             # Directory for logs
+        logging_steps=10,                 # How often to log
+    )
 
-    predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
+    trainer = Trainer(
+        model=model,                  # The model to train
+        args=training_args,           # The training arguments
+        train_dataset=train_dataset,  # The training dataset
+        eval_dataset=test_dataset     # The evaluation dataset
+    )
+
+
+    trainer.train()
+    predictions = trainer.evaluate()
+
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
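
Note on the last two added lines: `Trainer.evaluate()` returns a dict of metrics (e.g. `eval_loss`), not per-example labels, so the `predictions` it assigns will not line up element-wise with `true_labels`. A minimal sketch of one way to get class ids instead, assuming the `trainer` and tokenized `test_dataset` from this diff, and assuming the scoring code below this point compares `predictions` against `true_labels` directly:

import numpy as np

# Forward pass over the test split; PredictionOutput.predictions holds the raw logits
output = trainer.predict(test_dataset)

# Reduce logits to predicted class ids (one of the 8 classes), comparable to true_labels
predictions = np.argmax(output.predictions, axis=-1)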