MeghanaVeerannagari committed on
Commit
406c1a0
·
1 Parent(s): b6f4cbd

text-classification

Browse files
Files changed (1) hide show
  1. text-classification.py +77 -0
text-classification.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Example script: submit several AutoTrain text-classification jobs on IMDB.
#
# The script loads the "imdb" dataset, hands its train/test splits to
# AutoTrainDataset as pandas DataFrames, and then creates and approves a
# Project containing three hyperparameter configurations.
#
# Required environment variables (read at module load, KeyError if missing):
#   AUTOTRAIN_USERNAME, HF_TOKEN
import os
from uuid import uuid4

from datasets import load_dataset

from autotrain.dataset import AutoTrainDataset
from autotrain.project import Project


# A fresh UUID per run keeps the project name unique.
RANDOM_ID = str(uuid4())
DATASET = "imdb"
PROJECT_NAME = f"imdb_{RANDOM_ID}"
TASK = "text_binary_classification"
MODEL = "bert-base-uncased"

USERNAME = os.environ["AUTOTRAIN_USERNAME"]
TOKEN = os.environ["HF_TOKEN"]


if __name__ == "__main__":
    splits = load_dataset(DATASET)

    # AutoTrainDataset is given pandas DataFrames; the "test" split is used
    # as the validation data.
    train_df = splits["train"].to_pandas()
    validation_df = splits["test"].to_pandas()

    # Prepare the dataset for AutoTrain.
    dset = AutoTrainDataset(
        train_data=[train_df],
        valid_data=[validation_df],
        task=TASK,
        token=TOKEN,
        project_name=PROJECT_NAME,
        username=USERNAME,
        column_mapping={"text": "text", "label": "label"},
        percent_valid=None,
    )
    dset.prepare()

    #
    # How to get params for a task:
    #
    # from autotrain.params import Params
    # params = Params(task=TASK, training_type="hub_model").get()
    # print(params) to get full list of params for the task

    # Each row is (learning_rate, optimizer, scheduler); every job shares the
    # same task and epoch count.
    sweep = [
        (1e-5, "adamw_torch", "linear"),
        (3e-5, "adamw_torch", "cosine"),
        (5e-5, "sgd", "cosine"),
    ]
    jobs = [
        {
            "task": TASK,
            "learning_rate": learning_rate,
            "optimizer": optimizer,
            "scheduler": scheduler,
            "epochs": 5,
        }
        for learning_rate, optimizer, scheduler in sweep
    ]

    # create() returns the project id, which approve() then takes.
    project = Project(dataset=dset, hub_model=MODEL, job_params=jobs)
    project_id = project.create()
    project.approve(project_id)