m3 commited on
Commit
a3b447c
·
1 Parent(s): aad22ec

chore: add new safetensors model

Browse files
Files changed (2) hide show
  1. model.safetensors +3 -0
  2. src/init_onnx.py +7 -1
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
3
+ size 16
src/init_onnx.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import torch.nn as nn
3
- import torch.onnx
4
  class BaseModel(nn.Module):
5
  def __init__(self):
6
  super(BaseModel, self).__init__()
@@ -13,6 +13,12 @@ model = BaseModel()
13
 
14
  dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
15
 
 
 
 
 
 
 
16
  onnx_file_path = "model.onnx"
17
  torch.onnx.export(model, dummy_input, onnx_file_path,
18
  input_names=['input'], output_names=['output'],
 
1
  import torch
2
  import torch.nn as nn
3
+
4
  class BaseModel(nn.Module):
5
  def __init__(self):
6
  super(BaseModel, self).__init__()
 
13
 
14
  dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
15
 
16
+ safetensors_file_path = "model.safetensors"
17
+
18
+ from safetensors.torch import save_file
19
+ save_file(model.state_dict(), 'model.safetensors')
20
+
21
+ import torch.onnx
22
  onnx_file_path = "model.onnx"
23
  torch.onnx.export(model, dummy_input, onnx_file_path,
24
  input_names=['input'], output_names=['output'],