tamatwi committed on
Commit
c0c9825
1 Parent(s): 7e35d0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -1,12 +1,3 @@
1
- import gradio as gr
2
- from transformers import pipeline, AutoTokenizer
3
-
4
- gr.load("models/rinna/japanese-gpt2-medium").launch()
5
-
6
-
7
- # 日本語モデルを指定
8
- model_name = "rinna/japanese-gpt2-medium"
9
-
10
  from spaces import GPU
11
 
12
  @GPU(duration=120)
@@ -14,10 +5,13 @@ def generate_text(prompt, max_length):
14
  result = generator(prompt, max_length=max_length, num_return_sequences=1)
15
  return result[0]['generated_text']
16
 
 
 
17
 
18
  # トークナイザーとパイプラインの設定
19
  tokenizer = AutoTokenizer.from_pretrained(model_name)
20
- generator = pipeline('text-generation', model=model_name, tokenizer=tokenizer, device=0) # device=0はGPUを使用する設定
 
21
 
22
  def generate_text(prompt, max_length):
23
  result = generator(prompt, max_length=max_length, num_return_sequences=1)
@@ -32,4 +26,4 @@ iface = gr.Interface(
32
  outputs=gr.Textbox(label="生成されたテキスト")
33
  )
34
 
35
- iface.launch()
 
 
 
 
 
 
 
 
 
 
1
  from spaces import GPU
2
 
3
  @GPU(duration=120)
 
5
  result = generator(prompt, max_length=max_length, num_return_sequences=1)
6
  return result[0]['generated_text']
7
 
8
+ # 日本語モデルを指定
9
+ model_name = "rinna/japanese-gpt2-medium"
10
 
11
  # トークナイザーとパイプラインの設定
12
  tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+ generator = pipeline('text-generation', model=model_name, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
15
 
16
  def generate_text(prompt, max_length):
17
  result = generator(prompt, max_length=max_length, num_return_sequences=1)
 
26
  outputs=gr.Textbox(label="生成されたテキスト")
27
  )
28
 
29
+ iface.launch(share=True)