Update app.py
app.py (CHANGED)
@@ -16,8 +16,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)
 
 # Load generation model
-tokenizer = AutoTokenizer.from_pretrained("
-model = AutoModelForSeq2SeqLM.from_pretrained("
+tokenizer = AutoTokenizer.from_pretrained("MahmoudH/t5-v1_1-base-abs_qa")
+model = AutoModelForSeq2SeqLM.from_pretrained("MahmoudH/t5-v1_1-base-abs_qa", from_tf=True).to(device)
 
 
 def scrape(urls: List[str]) -> Dataset:
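The swapped-in checkpoint is loaded with from_tf=True, which tells transformers to convert TensorFlow weights into the PyTorch model class before moving it to the selected device (so TensorFlow must be installed in the Space). As a minimal, self-contained sketch of this loading step, where only the model id and from_tf flag come from the diff and the imports and device setup mirror the unchanged lines above:

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Retriever used to embed questions and retrieved passages
retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)

# Generation model: from_tf=True converts the TensorFlow weights of the
# checkpoint into the PyTorch model class at load time.
tokenizer = AutoTokenizer.from_pretrained("MahmoudH/t5-v1_1-base-abs_qa")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "MahmoudH/t5-v1_1-base-abs_qa", from_tf=True
).to(device)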
@@ -64,7 +64,7 @@ def search_web(query: str) -> List[str]:
 
     # Extract the title and URL of the top search results
     urls = set()
-    for result in search_results[:
+    for result in search_results[:5]:
         url = result.find("a")["href"]
         if url.startswith("http"):
             urls.add(url)
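This hunk only caps the loop at the first five result blocks; the code that produces search_results is not shown, but the result.find("a")["href"] access suggests a list of BeautifulSoup tags. A hypothetical helper sketching how the capped loop behaves (the function name and the "div.result" selector are assumptions, not part of the commit):

from typing import List
from bs4 import BeautifulSoup

def extract_result_urls(html: str, limit: int = 5) -> List[str]:
    # Hypothetical helper mirroring the hunk above: collect absolute links
    # from the first `limit` result blocks of a search results page.
    soup = BeautifulSoup(html, "html.parser")
    search_results = soup.find_all("div", class_="result")  # selector is an assumption
    urls = set()
    for result in search_results[:limit]:
        link = result.find("a")
        if link is None or not link.has_attr("href"):
            continue
        url = link["href"]
        if url.startswith("http"):
            urls.add(url)
    return list(urls)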
@@ -83,12 +83,10 @@ def generate_answer(question_doc: str) -> str:
     model_output = model.generate(
         input_ids=q_ids,
         attention_mask=q_mask,
-        min_new_tokens=32,
         max_new_tokens=256,
-        no_repeat_ngram_size=3,
-        num_beams=2,
-        do_sample=True,
         length_penalty=1.5,
+        do_sample=True,
+        num_beams=4
     )
     answer = tokenizer.batch_decode(model_output, skip_special_tokens=True)[0]
     return answer.strip()
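The decoding change drops min_new_tokens and no_repeat_ngram_size and replaces the old 2-beam setup with 4 beams combined with sampling, i.e. transformers' beam-sample decoding, still capped at 256 new tokens and with length_penalty=1.5 biasing beam scores toward longer answers. A sketch of how the function could read after this commit; only the generate() arguments are taken from the diff, while the tokenization lines are an assumption about the unchanged part of generate_answer and reuse the tokenizer, model, and device defined above:

def generate_answer(question_doc: str) -> str:
    # Illustrative tokenization of the question plus retrieved context;
    # the truncation settings here are assumptions, not from the diff.
    inputs = tokenizer(
        question_doc, return_tensors="pt", truncation=True, max_length=1024
    ).to(device)
    q_ids, q_mask = inputs.input_ids, inputs.attention_mask

    # Decoding settings from this commit: beam-sample decoding
    # (do_sample=True with num_beams=4), at most 256 new tokens,
    # length_penalty=1.5 favouring longer answers.
    model_output = model.generate(
        input_ids=q_ids,
        attention_mask=q_mask,
        max_new_tokens=256,
        length_penalty=1.5,
        do_sample=True,
        num_beams=4,
    )
    answer = tokenizer.batch_decode(model_output, skip_special_tokens=True)[0]
    return answer.strip()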