import streamlit as st

from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
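
# Note: the retriever below uses use_dummy_dataset=True, which loads a small toy
# index; the datasets and faiss packages need to be installed alongside
# transformers for it to work. Run the app with: streamlit run <this script>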


def generate_response_with_rag(txt):
    try:
        # Load the tokenizer, a retriever backed by the dummy "exact" index,
        # and the RAG sequence model. Passing the retriever to the model lets
        # generate() fetch supporting documents internally.
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        )
        model = RagSequenceForGeneration.from_pretrained(
            "facebook/rag-sequence-nq", retriever=retriever
        )

        # Tokenize the input text.
        inputs = tokenizer(txt, return_tensors="pt")

        # Generate a summary; retrieval and generation both happen inside generate().
        generated = model.generate(input_ids=inputs["input_ids"])

        # Decode the generated token ids back to a string.
        summary = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

        return summary

    except Exception as e:
        st.error(f"An error occurred during summarization: {str(e)}")
        return None
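

# Optional refinement (not wired into the app above): reloading the RAG model,
# tokenizer, and retriever on every submission is slow. A minimal sketch using
# Streamlit's st.cache_resource so the components are loaded once per session;
# the helper name load_rag_components is illustrative, not part of the original app.
@st.cache_resource
def load_rag_components():
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq", retriever=retriever
    )
    return tokenizer, retriever, model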


st.set_page_config(page_title='🦜🔗 RAG Text Summarization App')

st.title('🦜🔗 RAG Text Summarization App')

txt_input = st.text_area('Enter your text', '', height=200)

response = None

with st.form('summarize_form', clear_on_submit=True):
    submitted = st.form_submit_button('Submit')
    if submitted and txt_input:
        with st.spinner('Summarizing with RAG...'):
            response = generate_response_with_rag(txt_input)

if response:
    st.info(response)

st.subheader("Hugging Face RAG Summarization")

st.write("""
This app uses Hugging Face's RAG (Retrieval-Augmented Generation) model to generate summaries informed by relevant external context.
RAG retrieves passages from a document index and combines them with a generative model to produce more accurate summaries.
""")