windmaple commited on
Commit
143d1fc
·
1 Parent(s): fc36b22

Update quickstart_sst_demo.py

Browse files
Files changed (1) hide show
  1. quickstart_sst_demo.py +89 -30
quickstart_sst_demo.py CHANGED
@@ -1,17 +1,16 @@
1
  # Lint as: python3
2
- r"""Quick-start demo for a sentiment analysis model.
3
 
4
- This demo fine-tunes a small Transformer (BERT-tiny) on the Stanford Sentiment
5
- Treebank (SST-2), and starts a LIT server.
 
6
 
7
- To run locally:
8
- python -m lit_nlp.examples.quickstart_sst_demo --port=5432
9
 
10
- Training should take less than 5 minutes on a single GPU. Once you see the
11
- ASCII-art LIT logo, navigate to localhost:5432 to access the demo UI.
12
  """
13
  import sys
14
- import tempfile
15
 
16
  from absl import app
17
  from absl import flags
@@ -22,21 +21,52 @@ from lit_nlp import server_flags
22
  from lit_nlp.examples.datasets import glue
23
  from lit_nlp.examples.models import glue_models
24
 
 
 
25
  # NOTE: additional flags defined in server_flags.py
26
 
27
  FLAGS = flags.FLAGS
28
 
29
  FLAGS.set_default("development_demo", True)
30
 
31
- flags.DEFINE_string(
32
- "encoder_name", "google/bert_uncased_L-2_H-128_A-2",
33
- "Encoder name to use for fine-tuning. See https://huggingface.co/models.")
34
-
35
- flags.DEFINE_string("model_path", None, "Path to save trained model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  def get_wsgi_app():
39
- """Returns a LitApp instance for consumption by gunicorn."""
40
  FLAGS.set_default("server_type", "external")
41
  FLAGS.set_default("demo_mode", True)
42
  # Parse flags without calling app.run(main), to avoid conflict with
@@ -45,23 +75,52 @@ def get_wsgi_app():
45
  return main(unused)
46
 
47
 
48
- def run_finetuning(train_path):
49
- """Fine-tune a transformer model."""
50
- train_data = glue.SST2Data("train")
51
- val_data = glue.SST2Data("validation")
52
- model = glue_models.SST2Model(FLAGS.encoder_name)
53
- model.train(train_data.examples, validation_inputs=val_data.examples)
54
- model.save(train_path)
55
-
56
-
57
  def main(_):
58
- model_path = FLAGS.model_path or tempfile.mkdtemp()
59
- logging.info("Working directory: %s", model_path)
60
- run_finetuning(model_path)
61
-
62
- # Load our trained model.
63
- models = {"sst": glue_models.SST2Model(model_path)}
64
- datasets = {"sst_dev": glue.SST2Data("validation")}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Start the LIT server. See server_flags.py for server options.
67
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
 
1
  # Lint as: python3
2
+ r"""Example demo loading a handful of GLUE models.
3
 
4
+ For a quick-start set of models, run:
5
+ python -m lit_nlp.examples.glue_demo \
6
+ --quickstart --port=5432
7
 
8
+ To run with the 'normal' defaults, including full-size BERT models:
9
+ python -m lit_nlp.examples.glue_demo --port=5432
10
 
11
+ Then navigate to localhost:5432 to access the demo UI.
 
12
  """
13
  import sys
 
14
 
15
  from absl import app
16
  from absl import flags
 
21
  from lit_nlp.examples.datasets import glue
22
  from lit_nlp.examples.models import glue_models
23
 
24
+ import transformers # for path caching
25
+
26
  # NOTE: additional flags defined in server_flags.py
27
 
28
  FLAGS = flags.FLAGS
29
 
30
  FLAGS.set_default("development_demo", True)
31
 
32
+ flags.DEFINE_bool(
33
+ "quickstart", False,
34
+ "Quick-start mode, loads smaller models and a subset of the full data.")
35
+
36
+ flags.DEFINE_list(
37
+ "models", [
38
+ "sst2-tiny:sst2:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz",
39
+ "sst2-base:sst2:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_base.tar.gz",
40
+ "stsb:stsb:https://storage.googleapis.com/what-if-tool-resources/lit-models/stsb_base.tar.gz",
41
+ "mnli:mnli:https://storage.googleapis.com/what-if-tool-resources/lit-models/mnli_base.tar.gz",
42
+ ], "List of models to load, as <name>:<task>:<path>. "
43
+ "See MODELS_BY_TASK for available tasks. Path should be the output of "
44
+ "saving a transformers model, e.g. model.save_pretrained(path) and "
45
+ "tokenizer.save_pretrained(path). Remote .tar.gz files will be downloaded "
46
+ "and cached locally.")
47
+
48
+ flags.DEFINE_integer(
49
+ "max_examples", None, "Maximum number of examples to load into LIT. "
50
+ "Note: MNLI eval set is 10k examples, so will take a while to run and may "
51
+ "be slow on older machines. Set --max_examples=200 for a quick start.")
52
+
53
+ MODELS_BY_TASK = {
54
+ "sst2": glue_models.SST2Model,
55
+ "stsb": glue_models.STSBModel,
56
+ "mnli": glue_models.MNLIModel,
57
+ }
58
+
59
+ # Pre-specified set of small models, which will load and run much faster.
60
+ QUICK_START_MODELS = (
61
+ "sst2-tiny:sst2:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz",
62
+ "sst2-small:sst2:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_small.tar.gz",
63
+ "stsb-tiny:stsb:https://storage.googleapis.com/what-if-tool-resources/lit-models/stsb_tiny.tar.gz",
64
+ "mnli-small:mnli:https://storage.googleapis.com/what-if-tool-resources/lit-models/mnli_small.tar.gz",
65
+ )
66
 
67
 
68
  def get_wsgi_app():
69
+ """Return WSGI app for container-hosted demos."""
70
  FLAGS.set_default("server_type", "external")
71
  FLAGS.set_default("demo_mode", True)
72
  # Parse flags without calling app.run(main), to avoid conflict with
 
75
  return main(unused)
76
 
77
 
 
 
 
 
 
 
 
 
 
78
  def main(_):
79
+ # Quick-start mode.
80
+ if FLAGS.quickstart:
81
+ FLAGS.models = QUICK_START_MODELS # smaller, faster models
82
+ if FLAGS.max_examples is None or FLAGS.max_examples > 1000:
83
+ FLAGS.max_examples = 1000 # truncate larger eval sets
84
+ logging.info("Quick-start mode; overriding --models and --max_examples.")
85
+
86
+ models = {}
87
+ datasets = {}
88
+
89
+ tasks_to_load = set()
90
+ for model_string in FLAGS.models:
91
+ # Only split on the first two ':', because path may be a URL
92
+ # containing 'https://'
93
+ name, task, path = model_string.split(":", 2)
94
+ logging.info("Loading model '%s' for task '%s' from '%s'", name, task, path)
95
+ # Normally path is a directory; if it's an archive file, download and
96
+ # extract to the transformers cache.
97
+ if path.endswith(".tar.gz"):
98
+ path = transformers.file_utils.cached_path(
99
+ path, extract_compressed_file=True)
100
+ # Load the model from disk.
101
+ models[name] = MODELS_BY_TASK[task](path)
102
+ tasks_to_load.add(task)
103
+
104
+ ##
105
+ # Load datasets for each task that we have a model for
106
+ if "sst2" in tasks_to_load:
107
+ logging.info("Loading data for SST-2 task.")
108
+ datasets["sst_dev"] = glue.SST2Data("validation")
109
+
110
+ if "stsb" in tasks_to_load:
111
+ logging.info("Loading data for STS-B task.")
112
+ datasets["stsb_dev"] = glue.STSBData("validation")
113
+
114
+ if "mnli" in tasks_to_load:
115
+ logging.info("Loading data for MultiNLI task.")
116
+ datasets["mnli_dev"] = glue.MNLIData("validation_matched")
117
+ datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")
118
+
119
+ # Truncate datasets if --max_examples is set.
120
+ for name in datasets:
121
+ logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
122
+ datasets[name] = datasets[name].slice[:FLAGS.max_examples]
123
+ logging.info(" truncated to %d examples", len(datasets[name]))
124
 
125
  # Start the LIT server. See server_flags.py for server options.
126
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())