Sakalti commited on
Commit
75b8738
·
verified ·
1 Parent(s): b4e040e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -1
app.py CHANGED
@@ -3,6 +3,7 @@ import requests
3
  import torch
4
  from transformers import AutoModel
5
  from huggingface_hub import HfApi
 
6
 
7
  def convert_and_deploy(url, repo_id, hf_token):
8
  # セーフテンソルファイルをダウンロード
@@ -15,14 +16,28 @@ def convert_and_deploy(url, repo_id, hf_token):
15
  with open(file_path, "wb") as f:
16
  f.write(response.content)
17
 
 
 
 
 
 
 
 
 
 
 
18
  # モデルを読み込み
19
  try:
20
  # セーフテンソルファイルからモデルの状態を読み込み
21
- state_dict = torch.load(file_path)
 
22
 
23
  # モデルを初期化
24
  model = AutoModel.from_pretrained("path_to_model", torch_dtype=torch.float16, token=hf_token)
25
 
 
 
 
26
  # モデルの状態を設定
27
  model.load_state_dict(state_dict)
28
  except Exception as e:
 
3
  import torch
4
  from transformers import AutoModel
5
  from huggingface_hub import HfApi
6
+ import safetensors
7
 
8
  def convert_and_deploy(url, repo_id, hf_token):
9
  # セーフテンソルファイルをダウンロード
 
16
  with open(file_path, "wb") as f:
17
  f.write(response.content)
18
 
19
+ # ファイルの存在を確認
20
+ if not os.path.exists(file_path):
21
+ return "ファイルが正しく保存されませんでした。"
22
+
23
+ # ファイルの内容を確認
24
+ with open(file_path, "rb") as f:
25
+ content = f.read(100) # 先頭100バイトを読み込む
26
+ if not content:
27
+ return "ファイルが空です。"
28
+
29
  # モデルを読み込み
30
  try:
31
  # セーフテンソルファイルからモデルの状態を読み込み
32
+ with safetensors.safe_open(file_path, framework="pt") as f:
33
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
34
 
35
  # モデルを初期化
36
  model = AutoModel.from_pretrained("path_to_model", torch_dtype=torch.float16, token=hf_token)
37
 
38
+ # BF16からFP16に変換
39
+ state_dict = {k: v.to(torch.float16) for k, v in state_dict.items()}
40
+
41
  # モデルの状態を設定
42
  model.load_state_dict(state_dict)
43
  except Exception as e: