probablybots commited on
Commit
7441644
·
verified ·
1 Parent(s): 0071a9b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -8
README.md CHANGED
@@ -73,8 +73,8 @@ mgen test --model SequenceClassification --model.backbone aido_protein_16b --dat
73
  ```python
74
  from modelgenerator.tasks import Embed
75
  model = Embed.from_config({"model.backbone": "aido_protein_16b"}).eval()
76
- collated_batch = model.collate({"sequences": ["HELLQ", "WRLD"]})
77
- embedding = model(collated_batch)
78
  print(embedding.shape)
79
  print(embedding)
80
  ```
@@ -83,8 +83,8 @@ print(embedding)
83
  import torch
84
  from modelgenerator.tasks import SequenceClassification
85
  model = SequenceClassification.from_config({"model.backbone": "aido_protein_16b", "model.n_classes": 2}).eval()
86
- collated_batch = model.collate({"sequences": ["HELLQ", "WRLD"]})
87
- logits = model(collated_batch)
88
  print(logits)
89
  print(torch.argmax(logits, dim=-1))
90
  ```
@@ -93,8 +93,8 @@ print(torch.argmax(logits, dim=-1))
93
  import torch
94
  from modelgenerator.tasks import TokenClassification
95
  model = TokenClassification.from_config({"model.backbone": "aido_protein_16b", "model.n_classes": 3}).eval()
96
- collated_batch = model.collate({"sequences": ["HELLQ", "WRLD"]})
97
- logits = model(collated_batch)
98
  print(logits)
99
  print(torch.argmax(logits, dim=-1))
100
  ```
@@ -102,8 +102,8 @@ print(torch.argmax(logits, dim=-1))
102
  ```python
103
  from modelgenerator.tasks import SequenceRegression
104
  model = SequenceRegression.from_config({"model.backbone": "aido_protein_16b"}).eval()
105
- collated_batch = model.collate({"sequences": ["HELLQ", "WRLD"]})
106
- logits = model(collated_batch)
107
  print(logits)
108
  ```
109
 
 
73
  ```python
74
  from modelgenerator.tasks import Embed
75
  model = Embed.from_config({"model.backbone": "aido_protein_16b"}).eval()
76
+ transformed_batch = model.transform({"sequences": ["HELLQ", "WRLD"]})
77
+ embedding = model(transformed_batch)
78
  print(embedding.shape)
79
  print(embedding)
80
  ```
 
83
  import torch
84
  from modelgenerator.tasks import SequenceClassification
85
  model = SequenceClassification.from_config({"model.backbone": "aido_protein_16b", "model.n_classes": 2}).eval()
86
+ transformed_batch = model.transform({"sequences": ["HELLQ", "WRLD"]})
87
+ logits = model(transformed_batch)
88
  print(logits)
89
  print(torch.argmax(logits, dim=-1))
90
  ```
 
93
  import torch
94
  from modelgenerator.tasks import TokenClassification
95
  model = TokenClassification.from_config({"model.backbone": "aido_protein_16b", "model.n_classes": 3}).eval()
96
+ transformed_batch = model.transform({"sequences": ["HELLQ", "WRLD"]})
97
+ logits = model(transformed_batch)
98
  print(logits)
99
  print(torch.argmax(logits, dim=-1))
100
  ```
 
102
  ```python
103
  from modelgenerator.tasks import SequenceRegression
104
  model = SequenceRegression.from_config({"model.backbone": "aido_protein_16b"}).eval()
105
+ transformed_batch = model.transform({"sequences": ["HELLQ", "WRLD"]})
106
+ logits = model(transformed_batch)
107
  print(logits)
108
  ```
109