ai-forever commited on
Commit
d19e0a0
1 Parent(s): 75c52ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -3,8 +3,8 @@ import gradio as gr
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
  tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/mGPT")
5
  model = GPT2LMHeadModel.from_pretrained("sberbank-ai/mGPT")
6
- model.cuda()
7
- model.eval()
8
 
9
  description = "Multilingual generation with mGPT"
10
  title = "Generate your own example"
@@ -17,8 +17,11 @@ article = (
17
  "</p>"
18
  )
19
 
 
 
 
20
  def generate(prompt: str):
21
- input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
22
  out = model.generate(input_ids,
23
  min_length=100,
24
  max_length=200,
 
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
  tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/mGPT")
5
  model = GPT2LMHeadModel.from_pretrained("sberbank-ai/mGPT")
6
+ #model.cuda()
7
+ #model.eval()
8
 
9
  description = "Multilingual generation with mGPT"
10
  title = "Generate your own example"
 
17
  "</p>"
18
  )
19
 
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ fp16 = device != 'cpu'
22
+
23
  def generate(prompt: str):
24
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
25
  out = model.generate(input_ids,
26
  min_length=100,
27
  max_length=200,