hamishivi commited on
Commit
50d00b1
·
verified ·
1 Parent(s): 287142c
Files changed (3) hide show
  1. app.py +1 -1
  2. args.json +15 -0
  3. sdlm/arguments.py +3 -3
app.py CHANGED
@@ -17,7 +17,7 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
17
 
18
 
19
  def main():
20
- model_args, data_args, training_args, diffusion_args = get_args()
21
  tokenizer, model = load_model(model_args, data_args, training_args, diffusion_args, logger)
22
 
23
  model.eval()
 
17
 
18
 
19
  def main():
20
+ model_args, data_args, training_args, diffusion_args = get_args("args.json")
21
  tokenizer, model = load_model(model_args, data_args, training_args, diffusion_args, logger)
22
 
23
  model.eval()
args.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name_or_path": "hamishivi/tess2-v0.3",
3
+ "simplex_value": 5,
4
+ "num_diffusion_steps": 5000,
5
+ "num_inference_diffusion_steps": 100,
6
+ "beta_schedule": "squaredcos_improved_ddpm",
7
+ "top_p": 0.99,
8
+ "self_condition": "logits_mean",
9
+ "self_condition_mix_before_weights": true,
10
+ "is_causal": false,
11
+ "mask_padding_in_loss": false,
12
+ "dataset_name": "c4",
13
+ "streaming": true,
14
+ "dataset_config_name": "en"
15
+ }
sdlm/arguments.py CHANGED
@@ -11,7 +11,7 @@ MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
11
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
12
 
13
 
14
- def get_args():
15
  parser = HfArgumentParser(
16
  (
17
  ModelArguments,
@@ -20,11 +20,11 @@ def get_args():
20
  DiffusionArguments,
21
  )
22
  )
23
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
24
  # If we pass only one argument to the script and it's the path to a json file,
25
  # let's parse it to get our arguments.
26
  model_args, data_args, training_args, diffusion_args = parser.parse_json_file(
27
- json_file=os.path.abspath(sys.argv[1])
28
  )
29
  else:
30
  (
 
11
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
12
 
13
 
14
+ def get_args(filename: str = None):
15
  parser = HfArgumentParser(
16
  (
17
  ModelArguments,
 
20
  DiffusionArguments,
21
  )
22
  )
23
+ if filename is not None:
24
  # If we pass only one argument to the script and it's the path to a json file,
25
  # let's parse it to get our arguments.
26
  model_args, data_args, training_args, diffusion_args = parser.parse_json_file(
27
+ json_file=filename
28
  )
29
  else:
30
  (