halyn committed on
Commit
bbfcd34
·
1 Parent(s): 7f83f4e

Use GPU and new qa_chain

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import io
2
  import os
 
3
  import streamlit as st
4
  from PyPDF2 import PdfReader
5
  from langchain.text_splitter import CharacterTextSplitter
6
  from langchain.chains.question_answering import load_qa_chain
 
7
  from langchain_huggingface import HuggingFaceEmbeddings
8
  from langchain_community.vectorstores import FAISS
9
  from langchain_community.llms import HuggingFacePipeline
@@ -41,6 +43,7 @@ def load_model():
41
  model_name = "google/gemma-2-2b" # Hugging Face 모델 ID
42
  access_token = os.getenv("HF_TOKEN")
43
  try:
 
44
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token, clean_up_tokenization_spaces=False)
45
  model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=access_token)
46
  return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1)
@@ -57,7 +60,7 @@ def setup_qa_chain():
57
  print(f"Error loading model: {e}")
58
  return
59
  llm = HuggingFacePipeline(pipeline=pipe)
60
- qa_chain = load_qa_chain(llm, chain_type="stuff")
61
 
62
  # 메인 페이지 UI
63
  def main_page():
 
1
  import io
2
  import os
3
+ import torch
4
  import streamlit as st
5
  from PyPDF2 import PdfReader
6
  from langchain.text_splitter import CharacterTextSplitter
7
  from langchain.chains.question_answering import load_qa_chain
8
+ from langchain.chains import load_qa_with_sources_chain
9
  from langchain_huggingface import HuggingFaceEmbeddings
10
  from langchain_community.vectorstores import FAISS
11
  from langchain_community.llms import HuggingFacePipeline
 
43
  model_name = "google/gemma-2-2b" # Hugging Face 모델 ID
44
  access_token = os.getenv("HF_TOKEN")
45
  try:
46
+ device = 0 if torch.cuda.is_available() else -1
47
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token, clean_up_tokenization_spaces=False)
48
  model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=access_token)
49
  return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1)
 
60
  print(f"Error loading model: {e}")
61
  return
62
  llm = HuggingFacePipeline(pipeline=pipe)
63
+ qa_chain = load_qa_with_sources_chain(llm)
64
 
65
  # 메인 페이지 UI
66
  def main_page():