justinj92 committed
Commit 213c70e · verified · 1 parent: 54a2331

Update app.py

Files changed (1): app.py (+10, -9)
app.py CHANGED
@@ -86,10 +86,16 @@ if not os.path.exists(CFG.Embeddings_path + '/index.faiss'):
     embeddings = HuggingFaceInstructEmbeddings(model_name = CFG.embeddings_model_repo, model_kwargs={"device":"cuda"})
     vectordb = FAISS.load_local(CFG.Output_folder + '/faiss_index_ml_papers', embeddings, allow_dangerous_deserialization=True)
 
-@spaces.GPU
+
 def build_model(model_repo = CFG.model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_repo)
     model = AutoModelForCausalLM.from_pretrained(model_repo, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        print(f"Using GPU: {torch.cuda.get_device_name(device)}")
+    else:
+        device = torch.device("cpu")
+        print("Using CPU")
     device = torch.device("cuda")
     model = model.to(device)
 
@@ -107,12 +113,7 @@ terminators = [
 ]
 
 
-# if torch.cuda.is_available():
-#     device = torch.device("cuda")
-#     print(f"Using GPU: {torch.cuda.get_device_name(device)}")
-# else:
-#     device = torch.device("cpu")
-#     print("Using CPU")
+
 
 pipe = pipeline(task="text-generation", model=model, tokenizer=tok, eos_token_id=terminators, do_sample=True, max_new_tokens=CFG.max_new_tokens, temperature=CFG.temperature, top_p=CFG.top_p, repetition_penalty=CFG.repetition_penalty)
 
@@ -167,7 +168,7 @@ qa_chain = RetrievalQA.from_chain_type(
     verbose = False
 )
 
-@spaces.GPU
+
 def wrap_text_preserve_newlines(text, width=1500):
     # Split the input text into lines based on newline characters
     lines = text.split('\n')
@@ -180,7 +181,7 @@ def wrap_text_preserve_newlines(text, width=1500):
 
     return wrapped_text
 
-@spaces.GPU
+
 def process_llm_response(llm_response):
     ans = wrap_text_preserve_newlines(llm_response['result'])
 
 
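On the removed decorators: @spaces.GPU comes from the spaces package used on Hugging Face ZeroGPU Spaces, where it requests a GPU for the duration of each decorated call. A minimal sketch of the pattern this commit removes (the function name ping_gpu and its body are hypothetical):

import spaces
import torch

@spaces.GPU  # on ZeroGPU Spaces, a GPU is attached while this call runs
def ping_gpu() -> str:
    # Hypothetical helper: report the device visible inside the decorated call.
    return torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"

Dropping the decorator from pure string helpers such as wrap_text_preserve_newlines and process_llm_response is consistent with only GPU-touching code needing the allocation.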
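For how the pieces visible in the hunk headers fit together, a sketch of the plausible wiring from pipe to qa_chain (chain_type and return_source_documents are assumptions; only the llm, the retriever, and verbose = False are visible in the diff context):

from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

llm = HuggingFacePipeline(pipeline=pipe)    # pipe is the transformers pipeline from the diff
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",                     # assumption; not visible in the diff
    retriever=vectordb.as_retriever(),      # vectordb is the FAISS index loaded above
    return_source_documents=True,           # assumption; process_llm_response reads ['result']
    verbose=False,                          # matches the "verbose = False" context line
)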