RishuD7 commited on
Commit
4612b54
·
1 Parent(s): b8ca476

fixed device detect issue

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -11,13 +11,13 @@ if __name__ == '__main__':
11
  import gradio as gr
12
  import os
13
 
14
- device = f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu'
15
  config_path = os.path.join('models_file', 'config.pth')
16
  model_path = os.path.join('models_file', 'microsoft-deberta-base_0.9449373420387531_8_best.pth')
17
  tokenizer = AutoTokenizer.from_pretrained('models_file/tokenizer')
18
  model = CustomModel(CFG, config_path=config_path, pretrained=False)
19
  state = torch.load(model_path,
20
- map_location=torch.device('cuda'))
21
  model.load_state_dict(state['model'])
22
 
23
  def get_answer(context, feature):
 
11
  import gradio as gr
12
  import os
13
 
14
+ device = f'cuda' if torch.cuda.is_available() else 'cpu'
15
  config_path = os.path.join('models_file', 'config.pth')
16
  model_path = os.path.join('models_file', 'microsoft-deberta-base_0.9449373420387531_8_best.pth')
17
  tokenizer = AutoTokenizer.from_pretrained('models_file/tokenizer')
18
  model = CustomModel(CFG, config_path=config_path, pretrained=False)
19
  state = torch.load(model_path,
20
+ map_location=device)
21
  model.load_state_dict(state['model'])
22
 
23
  def get_answer(context, feature):