MahmoudH committed on
Commit
332312f
·
1 Parent(s): 633e625

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -16,8 +16,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
16
  retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)
17
 
18
  # Load generation model
19
- tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
20
- model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to(device)
21
 
22
 
23
  def scrape(urls: List[str]) -> Dataset:
@@ -64,7 +64,7 @@ def search_web(query: str) -> List[str]:
64
 
65
  # Extract the title and URL of the top search results
66
  urls = set()
67
- for result in search_results[:10]:
68
  url = result.find("a")["href"]
69
  if url.startswith("http"):
70
  urls.add(url)
@@ -83,12 +83,10 @@ def generate_answer(question_doc: str) -> str:
83
  model_output = model.generate(
84
  input_ids=q_ids,
85
  attention_mask=q_mask,
86
- min_new_tokens=32,
87
  max_new_tokens=256,
88
- no_repeat_ngram_size=3,
89
- num_beams=2,
90
- do_sample=True,
91
  length_penalty=1.5,
 
 
92
  )
93
  answer = tokenizer.batch_decode(model_output, skip_special_tokens=True)[0]
94
  return answer.strip()
 
16
  retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)
17
 
18
  # Load generation model
19
+ tokenizer = AutoTokenizer.from_pretrained("MahmoudH/t5-v1_1-base-abs_qa")
20
+ model = AutoModelForSeq2SeqLM.from_pretrained("MahmoudH/t5-v1_1-base-abs_qa", from_tf=True).to(device)
21
 
22
 
23
  def scrape(urls: List[str]) -> Dataset:
 
64
 
65
  # Extract the title and URL of the top search results
66
  urls = set()
67
+ for result in search_results[:5]:
68
  url = result.find("a")["href"]
69
  if url.startswith("http"):
70
  urls.add(url)
 
83
  model_output = model.generate(
84
  input_ids=q_ids,
85
  attention_mask=q_mask,
 
86
  max_new_tokens=256,
 
 
 
87
  length_penalty=1.5,
88
+ do_sample=True,
89
+ num_beams=4
90
  )
91
  answer = tokenizer.batch_decode(model_output, skip_special_tokens=True)[0]
92
  return answer.strip()