import pinecone import streamlit as st from sentence_transformers import SentenceTransformer from transformers import BartTokenizer, BartForConditionalGeneration class BartGenerator: def __init__(self, model_name): self.tokenizer = BartTokenizer.from_pretrained(model_name) self.generator = BartForConditionalGeneration.from_pretrained(model_name) def tokenize(self, query, max_length=1024): inputs = self.tokenizer([query], max_length=max_length, return_tensors="pt") return inputs def generate(self, query, min_length=20, max_length=40): inputs = self.tokenize(query) ids = self.generator.generate(inputs["input_ids"], num_beams=1, min_length=int(min_length), max_length=int(max_length), temperature=int(temperature)) answer = self.tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] return answer @st.experimental_singleton def init_models(): retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base") #("multi-qa-mpnet-base-cos-v1") ("flax-sentence-embeddings/all_datasets_v3_mpnet-base") generator = BartGenerator("vblagoje/bart_lfqa") return retriever, generator PINECONE_KEY = st.secrets["PINECONE_KEY"] @st.experimental_singleton def init_pinecone(): pinecone.init(api_key=PINECONE_KEY, environment="us-west1-gcp") return pinecone.Index("history-qa") retriever, generator = init_models() index = init_pinecone() def display_answer(answer): return st.markdown(f"""
{m['metadata']['passage_text']}" for m in context] context = " ".join(context) query = f"question: {query} context: {context}" return query # set parameters top_k = 5 min_length = 1 max_length = 150 temperature = 3.5 st.sidebar.write(""" ## Here are some questions you can try out: ### Copy and paste to test who was the first person on the moon?\n Which was the first radio station at Auburn University\n where is Damastown located\n What is the Lohanipur Torso \n when was The Coliseum Theatre opened\n Who invented the tatoo machine\n whats th erecipe for Corn chowder\n when was the Tamil Methodist Church built\n when was the first electric power system built?\n How was the first wireless message sent?\n what was the war of currents?\n what was NASAs most expensive project?\n What brands of smokoing paper are manufactured by Miguel y Costas\n what influenced the naming Holy Forty Martyrs Church\n When was the world first power system built\n which is the largest island within the Halifax Harbour\n Who was Joseph Monier\n who were the Karadjordjevic dynasty\n how many royal tombs were excavated at Tillia Tepe\n What did the HEICO company manufacture\n tell me about The Battle of Antietam\n Which was the smallest microbrewery in the United States\n when did queen marie recieve the bran castle\n Whe was York Township founded\n When did the United Nations Security Council reform the security sector\n When was Magandang Umaga Po first aired\n when was Mae Lan District formed\n what is Voice over Internet Protocol\n When was InfluxDB developed\n When was the Semanário Económico newspaper started\n who owned Kasteln Castle\n when was The Steinbach Haus built\n when was the Guerrero ship in Africa\n tell me about the Guerrero ship\n When was the Companhia Paulista de Trens Metropolitanos rilway built\n When was the lincoln mall demolished\n where is Damastown located\n when was solo diving first practiced\n when was Consumers Credit Union History Consumers Credit Union was founded\n Who built the castle of Daroynk\n What is the prime meridian\n Which was the first radio station at Auburn University\n What are the origins of feminist music\n What were the earliest insecticides to be used\n who were the Drevlians\n Who were the founders of A.F.C. Euro Kickers\n when was the camera-on-a-chip developed\n """) st.write("If you encounter an error, search again.") query = st.text_input("Search!", "") if query != "": with st.spinner(text="Wait a sec 🚀🚀🚀"): xq = retriever.encode([query]).tolist() xc = index.query(xq, top_k=int(top_k), include_metadata=True) query = format_query(query, xc["matches"]) with st.spinner(text="Just a minute ✍️✍️✍️"): answer = generator.generate(query, min_length=min_length, max_length=max_length) st.write("#### System generated response:") display_answer(answer) st.write("#### Here are some resources you might find relevant:") for m in xc["matches"]: title = m["metadata"]["article_title"] url = "https://en.wikipedia.org/wiki/" + title.replace(" ", "_") context = m["metadata"]["passage_text"] display_context(title, context, url)