LukeOLuck commited on
Commit
6c038be
·
1 Parent(s): 8b4f499

Remove torchinfo

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. model.py +1 -6
app.py CHANGED
@@ -17,7 +17,7 @@ with open("example_texts.txt", "r") as file:
17
 
18
  ### Model and transforms preparation ###
19
  # Create model and tokenizer
20
- model, tokenizer = create_roberta_model(output_shape=len(class_names), print_summary=False)
21
 
22
  # Load saved weights
23
  model.load_state_dict(
 
17
 
18
  ### Model and transforms preparation ###
19
  # Create model and tokenizer
20
+ model, tokenizer = create_roberta_model(output_shape=len(class_names))
21
 
22
  # Load saved weights
23
  model.load_state_dict(
model.py CHANGED
@@ -1,5 +1,4 @@
1
  import torch
2
- from torchinfo import summary
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -9,7 +8,7 @@ def model_input_wrapper(batch_size, sequence_length, tokenizer):
9
  dummy_attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long)
10
  return {'input_ids': dummy_input_ids, 'attention_mask': dummy_attention_mask}
11
 
12
- def create_roberta_model(output_shape:int=10, device=device, print_summary=True):
13
  """Creates a HuggingFace roberta-base model.
14
 
15
  Args:
@@ -31,8 +30,4 @@ def create_roberta_model(output_shape:int=10, device=device, print_summary=True)
31
 
32
  model.classifier.out_proj = torch.nn.Linear(in_features=768, out_features=output_shape)
33
 
34
- if print_summary:
35
- sample_inputs = model_input_wrapper(1, 128, tokenizer)
36
- print(summary(model, input_data=sample_inputs, verbose=0, col_names=["input_size", "output_size", "num_params", "trainable"], col_width=20, row_settings=["var_names"]))
37
-
38
  return model.to(device), tokenizer
 
1
  import torch
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
 
4
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
8
  dummy_attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long)
9
  return {'input_ids': dummy_input_ids, 'attention_mask': dummy_attention_mask}
10
 
11
+ def create_roberta_model(output_shape:int=10, device=device):
12
  """Creates a HuggingFace roberta-base model.
13
 
14
  Args:
 
30
 
31
  model.classifier.out_proj = torch.nn.Linear(in_features=768, out_features=output_shape)
32
 
 
 
 
 
33
  return model.to(device), tokenizer