halyn committed on
Commit
2c73e16
·
1 Parent(s): 6a062bb

modify model load

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -41,12 +41,24 @@ def load_model():
41
  try:
42
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token, clean_up_tokenization_spaces=False)
43
  model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)
44
- device = 0 if torch.cuda.is_available() else -1
 
 
 
 
 
 
 
 
 
 
 
45
  return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1, device=device)
46
  except Exception as e:
47
  print(f"Error loading model: {e}")
48
  return None
49
 
 
50
  # ํŽ˜์ด์ง€ UI
51
  def main():
52
  st.title("Welcome to GemmaPaperQA")
 
41
  try:
42
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token, clean_up_tokenization_spaces=False)
43
  model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)
44
+
45
+ # ๋””๋ฒ„๊น…: GPU/CPU ํ™•์ธ ๋ฐ ์ถœ๋ ฅ
46
+ if torch.cuda.is_available():
47
+ print("Using GPU")
48
+ device = 0
49
+ else:
50
+ print("Using CPU")
51
+ device = -1
52
+
53
+ # ๋””๋ฒ„๊น…: device ์ถœ๋ ฅ
54
+ print(f"Device: {device}")
55
+
56
  return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1, device=device)
57
  except Exception as e:
58
  print(f"Error loading model: {e}")
59
  return None
60
 
61
+
62
  # ํŽ˜์ด์ง€ UI
63
  def main():
64
  st.title("Welcome to GemmaPaperQA")