text_gen / app.py
saima730's picture
Update app.py
9f29f51 verified
raw
history blame
2.2 kB
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, pipeline
@st.cache_resource
def _load_rag_components(model_name="facebook/rag-token-nq"):
    """Load the RAG tokenizer and generator once and cache them.

    Streamlit re-executes this script on every interaction, so without
    ``st.cache_resource`` the checkpoint and its retriever would be reloaded
    for every request.

    Args:
        model_name: Hugging Face Hub id of the RAG checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple; the retriever is attached to the model.
    """
    tokenizer = RagTokenizer.from_pretrained(model_name)
    # NOTE(review): use_dummy_dataset=True requests the library's dummy
    # retrieval dataset, so the retrieved context is presumably very limited.
    retriever = RagRetriever.from_pretrained(
        model_name, index_name="exact", use_dummy_dataset=True
    )
    # Attaching the retriever lets model.generate() retrieve and encode the
    # supporting documents itself. RagRetriever.retrieve() returns
    # (doc_embeds, doc_ids, doc_dicts), not a dict of context token ids.
    model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
    return tokenizer, model


def generate_response_with_rag(txt):
    """Generate a response for ``txt`` with a RAG (Retrieval-Augmented Generation) model.

    Args:
        txt: Text entered by the user.

    Returns:
        The decoded generated string, or ``None`` when ``txt`` is empty or
        blank, or when generation fails (the failure is shown via ``st.error``).
    """
    if not txt or not txt.strip():
        return None
    try:
        tokenizer, model = _load_rag_components()
        # Tokenize with the question-encoder tokenizer wrapped by RagTokenizer.
        inputs = tokenizer(txt, return_tensors="pt")
        generated = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
        )
        return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    except Exception as e:  # UI boundary: surface any failure to the user
        st.error(f"An error occurred during summarization: {str(e)}")
        return None
# Page title and layout
st.set_page_config(page_title='🦜🔗 RAG Text Summarization App')
st.title('🦜🔗 RAG Text Summarization App')

response = None
# The text area lives inside the form so typing does not trigger a rerun.
# Input is only sent when Submit is pressed. The text is deliberately not
# cleared on submit, so the user can still see what was summarized.
with st.form('summarize_form'):
    txt_input = st.text_area('Enter your text', '', height=200)
    submitted = st.form_submit_button('Submit')

if submitted:
    if txt_input and txt_input.strip():
        with st.spinner('Summarizing with RAG...'):
            response = generate_response_with_rag(txt_input)
    else:
        st.warning('Please enter some text to summarize.')

# Display the response if available
if response:
    st.info(response)

# Short description of the model behind the app
st.subheader("Hugging Face RAG Summarization")
st.write("""
This app uses Hugging Face's RAG model (Retrieval-Augmented Generation) to generate summaries with relevant external context.
RAG retrieves information from a set of documents and combines that with a generative model to produce more accurate summaries.
""")