Yoxas committed
Commit be520f8 · verified · 1 Parent(s): 49dab71

Update app.py

Files changed (1): app.py +12 -4
app.py CHANGED
@@ -5,8 +5,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import gradio as gr
 import json
 import faiss
-import spaces
 import numpy as np
+import spaces
 # Ensure you have GPU support
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -32,7 +32,6 @@ llama_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(d
 summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=0 if device == 'cuda' else -1)
 
 # Define the function to find the most relevant document using FAISS
-@spaces.GPU(duration=120)
 def retrieve_relevant_doc(query):
     query_embedding = sentence_model.encode(query, convert_to_tensor=False)
     _, indices = index.search(np.array([query_embedding]), k=1)
@@ -40,14 +39,23 @@ def retrieve_relevant_doc(query):
     return df.iloc[best_match_idx]['Abstract']
 
 # Define the function to generate a response
-@spaces.GPU(duration=120)
 def generate_response(query):
     relevant_doc = retrieve_relevant_doc(query)
     if len(relevant_doc) > 512:  # Truncate long documents
         relevant_doc = summarizer(relevant_doc, max_length=150, min_length=50, do_sample=False)[0]['summary_text']
+
     input_text = f"Document: {relevant_doc}\n\nQuestion: {query}\n\nAnswer:"
     inputs = llama_tokenizer(input_text, return_tensors="pt").to(device)
-    outputs = llama_model.generate(inputs["input_ids"], max_length=150)
+
+    # Set pad_token_id to eos_token_id to avoid the warning
+    pad_token_id = llama_tokenizer.eos_token_id
+    outputs = llama_model.generate(
+        inputs["input_ids"],
+        attention_mask=inputs["attention_mask"],
+        max_length=150,
+        pad_token_id=pad_token_id
+    )
+
     response = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
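For context on the generate() change: GPT-2 ships without a pad token, so transformers warns on every call unless pad_token_id is set explicitly, and eos_token_id is the conventional substitute; passing attention_mask marks which positions are real input rather than padding. A minimal standalone sketch of the same fix, reusing the model name from app.py (the prompt string is a placeholder):

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# GPT-2 has no pad token; using eos_token_id silences the
# "Setting pad_token_id to eos_token_id" warning.
inputs = tokenizer("Document: ...\n\nQuestion: ...\n\nAnswer:", return_tensors="pt")
outputs = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],  # distinguishes real tokens from padding
    max_length=150,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))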
 
 
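Note that the diff drops the @spaces.GPU decorators while keeping import spaces. For reference, on ZeroGPU Spaces hardware that decorator is what routes a function's execution onto the GPU; a sketch of how it would be reattached (duration value taken from the removed lines), assuming the Space runs on ZeroGPU:

import spaces

@spaces.GPU(duration=120)  # request a GPU slot for up to 120 s per call (ZeroGPU only)
def generate_response(query):
    ...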