halyn commited on
Commit
daadf81
·
1 Parent(s): e2eb364

move success check code to setup_qa_chain

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -41,11 +41,10 @@ def load_model():
41
  model_name = "google/gemma-2-2b" # Hugging Face 모델 ID
42
  access_token = os.getenv("HF_TOKEN")
43
  try:
44
- device = 0 if torch.cuda.is_available() else -1
45
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token, clean_up_tokenization_spaces=False)
46
  model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)
47
  device = 0 if torch.cuda.is_available() else -1
48
- return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1)
49
  except Exception as e:
50
  print(f"Error loading model: {e}")
51
  return None
@@ -60,6 +59,12 @@ def setup_qa_chain():
60
  return
61
  llm = HuggingFacePipeline(pipeline=pipe)
62
  qa_chain = load_qa_chain(llm, chain_type="map_rerank")
 
 
 
 
 
 
63
 
64
  # 메인 페이지 UI
65
  def main_page():
@@ -101,10 +106,6 @@ def main_page():
101
 
102
  setup_qa_chain()
103
 
104
- st.session_state.paper_name = paper.name[:-4]
105
- st.session_state.page = "chat"
106
- st.success("PDF successfully processed! You can now ask questions.")
107
-
108
  except Exception as e:
109
  st.error(f"Failed to process the PDF: {str(e)}")
110
 
 
41
  model_name = "google/gemma-2-2b" # Hugging Face 모델 ID
42
  access_token = os.getenv("HF_TOKEN")
43
  try:
 
44
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token, clean_up_tokenization_spaces=False)
45
  model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)
46
  device = 0 if torch.cuda.is_available() else -1
47
+ return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1, device=device)
48
  except Exception as e:
49
  print(f"Error loading model: {e}")
50
  return None
 
59
  return
60
  llm = HuggingFacePipeline(pipeline=pipe)
61
  qa_chain = load_qa_chain(llm, chain_type="map_rerank")
62
+
63
+
64
+ st.session_state.paper_name = paper.name[:-4]
65
+ st.session_state.page = "chat"
66
+ st.success("PDF successfully processed! You can now ask questions.")
67
+
68
 
69
  # 메인 페이지 UI
70
  def main_page():
 
106
 
107
  setup_qa_chain()
108
 
 
 
 
 
109
  except Exception as e:
110
  st.error(f"Failed to process the PDF: {str(e)}")
111