BillBojangeles2000 committed on
Commit
423c9eb
·
1 Parent(s): 079509f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -10
app.py CHANGED
@@ -1,6 +1,9 @@
1
  import pinecone
2
  from pprint import pprint
3
  import streamlit as st
 
 
 
4
  # connect to pinecone environment
5
  pinecone.init(
6
  api_key="e5d4972e-0045-43d5-a55e-efdeafe442dd",
@@ -23,9 +26,9 @@ index = pinecone.Index(index_name)
23
 
24
  from transformers import BartTokenizer, BartForConditionalGeneration
25
 
26
- # load bart tokenizer and model from huggingface
27
- tokenizer = BartTokenizer.from_pretrained('vblagoje/bart_lfqa')
28
- generator = BartForConditionalGeneration.from_pretrained('vblagoje/bart_lfqa').to('cpu')
29
 
30
  import torch
31
  from sentence_transformers import SentenceTransformer
@@ -51,13 +54,25 @@ def format_query(query, context):
51
  query = f"question: {query} context: {context}"
52
  return query
53
  def generate_answer(query):
54
- # tokenize the query to get input_ids
55
- inputs = tokenizer([query], trunication=True, max_length=1024, return_tensors="pt")
56
- # use generator to predict output ids
57
- ids = generator.generate(inputs["input_ids"], num_beams=2, min_length=20, max_length=64)
58
- # use tokenizer to decode the output ids
59
- answer = tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
60
- st.write(str(answer))
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  query = st.text_area('Enter Question:')
63
  b = st.button('Submit!')
 
1
  import pinecone
2
  from pprint import pprint
3
  import streamlit as st
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
6
+ model_name = "vblagoje/bart_lfqa"
7
  # connect to pinecone environment
8
  pinecone.init(
9
  api_key="e5d4972e-0045-43d5-a55e-efdeafe442dd",
 
26
 
27
  from transformers import BartTokenizer, BartForConditionalGeneration
28
 
29
# Load the LFQA tokenizer and seq2seq model once at startup.
# Pick the device dynamically: the previous unconditional .to('cuda') crashed
# on CPU-only hosts (e.g. free Streamlit/Spaces tiers), and the `device` name
# used later by generate_answer() was never defined (NameError) — define it here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
32
 
33
  import torch
34
  from sentence_transformers import SentenceTransformer
 
54
  query = f"question: {query} context: {context}"
55
  return query
56
def generate_answer(query):
    """Generate a long-form answer for `query` and render it in Streamlit.

    `query` is expected to be the output of format_query(), i.e. a string of
    the form "question: ... context: ...". Writes the decoded answer to the
    Streamlit page; returns None.
    """
    # Tokenize; truncation guards against inputs beyond the model's max length.
    model_input = tokenizer(query, truncation=True, padding=True, return_tensors="pt")
    # model.device replaces the previously undefined `device` name (NameError fix)
    # and guarantees the inputs land on the same device as the model weights.
    generated_ids = model.generate(
        input_ids=model_input["input_ids"].to(model.device),
        attention_mask=model_input["attention_mask"].to(model.device),
        min_length=64,
        max_length=256,
        do_sample=False,  # deterministic beam search, no sampling
        early_stopping=True,
        num_beams=8,
        temperature=1.0,
        top_k=None,
        top_p=None,
        eos_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,  # avoid repeated phrases in long answers
        num_return_sequences=1,
    )
    answers = tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    # batch_decode returns a list; index [0] so the page shows the answer text
    # rather than the repr of a one-element list (previous str(res) bug).
    st.write(answers[0])
76
 
77
  query = st.text_area('Enter Question:')
78
  b = st.button('Submit!')