File size: 2,200 Bytes
5e2e02c
9f29f51
5e2e02c
9f29f51
 
5e2e02c
9f29f51
 
 
 
5e2e02c
9f29f51
 
5e2e02c
9f29f51
 
5e2e02c
9f29f51
 
267005e
9f29f51
 
 
 
5e2e02c
 
 
 
267005e
9f29f51
 
5e2e02c
267005e
5e2e02c
 
267005e
5e2e02c
 
 
9f29f51
 
 
5e2e02c
267005e
5e2e02c
 
 
9f29f51
 
5e2e02c
9f29f51
 
5e2e02c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, pipeline

# Function to generate response using RAG (Retrieval-Augmented Generation)
@st.cache_resource(show_spinner=False)
def _load_rag_components():
    """Load the RAG tokenizer and retriever-backed model once and reuse them.

    Streamlit re-executes the script on every interaction, so without
    caching the checkpoint would be reloaded on each submit.

    Returns:
        A ``(tokenizer, model)`` tuple. The retriever is attached to the
        model, so ``model.generate`` performs document retrieval itself.
    """
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
    )
    # NOTE(review): the checkpoint is named "rag-token" while the class is the
    # sequence variant; kept as in the original -- confirm the pairing is intended.
    model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-token-nq", retriever=retriever
    )
    return tokenizer, model


# Function to generate response using RAG (Retrieval-Augmented Generation)
def generate_response_with_rag(txt):
    """Generate text for ``txt`` with the RAG model.

    Args:
        txt: The user-supplied input text.

    Returns:
        The decoded generated string, or ``None`` when ``txt`` is empty or
        whitespace-only, or when loading/generation raises (in which case
        the error is also shown in the UI via ``st.error``).
    """
    if not txt or not txt.strip():
        return None

    try:
        tokenizer, model = _load_rag_components()

        # Truncate so over-long input stays within the tokenizer's maximum length.
        inputs = tokenizer(txt, return_tensors="pt", truncation=True)

        # The model owns the retriever, so generate() encodes the question,
        # retrieves documents and builds the context inputs internally.
        # Calling RagRetriever.retrieve() directly would need question hidden
        # states and n_docs, and it returns a tuple rather than a dict.
        generated = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

        # Decode the first generated sequence.
        return tokenizer.decode(generated[0], skip_special_tokens=True).strip()
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        st.error(f"An error occurred during summarization: {e}")
        return None

# Page title and layout
st.set_page_config(page_title='🦜🔗 RAG Text Summarization App')
st.title('🦜🔗 RAG Text Summarization App')

# Form to accept the user's text input for summarization.
# The text area lives inside the form so that its value is only sent when the
# user presses Submit, and so that clear_on_submit actually resets it.
response = None
with st.form('summarize_form', clear_on_submit=True):
    txt_input = st.text_area('Enter your text', '', height=200)
    submitted = st.form_submit_button('Submit')

if submitted:
    if txt_input.strip():
        with st.spinner('Summarizing with RAG...'):
            response = generate_response_with_rag(txt_input)
    else:
        # Give feedback instead of silently ignoring a blank submission.
        st.warning('Please enter some text to summarize.')

# Display the response if available
if response:
    st.info(response)

# Instructions for getting started with Hugging Face models
st.subheader("Hugging Face RAG Summarization")
st.write("""
This app uses Hugging Face's RAG model (Retrieval-Augmented Generation) to generate summaries with relevant external context.
RAG retrieves information from a set of documents and combines that with a generative model to produce more accurate summaries.
""")