Guchyos commited on
Commit
2ac2435
·
verified ·
1 Parent(s): dfd22a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -23
app.py CHANGED
@@ -1,47 +1,49 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
 
 
 
 
 
 
 
4
 
5
  def load_model():
6
  model_name = "Guchyos/gemma-2b-elyza-task"
7
 
8
  print("Loading tokenizer...")
9
- tokenizer = AutoTokenizer.from_pretrained(model_name)
10
 
11
  print("Loading model...")
12
  model = AutoModelForCausalLM.from_pretrained(
13
  model_name,
14
  device_map="auto",
15
- torch_dtype=torch.float16
 
16
  )
17
  return model, tokenizer
18
 
19
- # モデルをグローバルに1回だけロード
20
- try:
21
- model, tokenizer = load_model()
22
- print("Model loaded successfully!")
23
- except Exception as e:
24
- print(f"Error loading model: {str(e)}")
25
-
26
  def predict(message, history):
 
 
 
27
  try:
28
- # 入力の準備
29
- prompt = f"質問: {message}\n\n回答:"
30
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
31
 
32
- # 生成
33
  outputs = model.generate(
34
  **inputs,
35
  max_new_tokens=512,
36
  temperature=0.7,
37
  top_p=0.9,
38
- do_sample=True,
39
- repetition_penalty=1.1
40
  )
41
 
42
- # 応答の生成
43
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
44
- return response.replace(prompt, "").strip()
45
 
46
  except Exception as e:
47
  return f"エラーが発生しました: {str(e)}"
@@ -56,7 +58,6 @@ demo = gr.ChatInterface(
56
  ## 使い方
57
  - 質問を入力してEnterキーを押してください
58
  - 生成には数秒かかります
59
- - 結果が気に入らない場合は「再生成」ボタンを押してください
60
 
61
  ## 特徴
62
  - 4bit量子化により最適化
@@ -67,11 +68,7 @@ demo = gr.ChatInterface(
67
  "日本の四季について、それぞれの特徴を説明してください。",
68
  "人工知能の発展における倫理的な課題について説明してください。",
69
  "東京の主要な観光スポットを3つ挙げて、それぞれ説明してください。"
70
- ],
71
- retry_btn="🔄 再生成",
72
- undo_btn="↩️ 取り消し",
73
- clear_btn="🗑️ クリア",
74
- theme=gr.themes.Soft()
75
  )
76
 
77
  # アプリの起動
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
+ import os
5
+ from huggingface_hub import login
6
+
7
+ # Hugging Face トークンを環境変数から取得して認証
8
+ try:
9
+ login(token=os.environ.get("HUGGINGFACE_TOKEN"))
10
+ except:
11
+ print("Warning: HUGGINGFACE_TOKEN not found")
12
 
13
  def load_model():
14
  model_name = "Guchyos/gemma-2b-elyza-task"
15
 
16
  print("Loading tokenizer...")
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
18
 
19
  print("Loading model...")
20
  model = AutoModelForCausalLM.from_pretrained(
21
  model_name,
22
  device_map="auto",
23
+ torch_dtype=torch.float16,
24
+ use_auth_token=True
25
  )
26
  return model, tokenizer
27
 
 
 
 
 
 
 
 
28
  def predict(message, history):
29
+ # 履歴がある場合は考慮
30
+ full_prompt = f"質問: {message}\n\n回答:"
31
+
32
  try:
33
+ # モデルとトークナイザーをロード(毎回ロード)
34
+ model, tokenizer = load_model()
 
35
 
36
+ inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
37
  outputs = model.generate(
38
  **inputs,
39
  max_new_tokens=512,
40
  temperature=0.7,
41
  top_p=0.9,
42
+ do_sample=True
 
43
  )
44
 
 
45
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
46
+ return response.replace(full_prompt, "").strip()
47
 
48
  except Exception as e:
49
  return f"エラーが発生しました: {str(e)}"
 
58
  ## 使い方
59
  - 質問を入力してEnterキーを押してください
60
  - 生成には数秒かかります
 
61
 
62
  ## 特徴
63
  - 4bit量子化により最適化
 
68
  "日本の四季について、それぞれの特徴を説明してください。",
69
  "人工知能の発展における倫理的な課題について説明してください。",
70
  "東京の主要な観光スポットを3つ挙げて、それぞれ説明してください。"
71
+ ]
 
 
 
 
72
  )
73
 
74
  # アプリの起動