# app.py by Izza-shahzad-13 (Hugging Face Space), last commit 622dc3f
# "Update app.py", 2.34 kB.
import streamlit as st
import pandas as pd
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import T5ForConditionalGeneration, T5Tokenizer
# Load the Sentence Transformer and T5 models.
# Streamlit re-runs this whole script on every widget interaction, so the
# models are cached to avoid reloading them each time. ``st.cache`` is
# deprecated (Streamlit 1.18+) in favour of ``st.cache_resource``, which is
# intended for unhashable shared objects such as ML models; fall back to the
# legacy decorator on Streamlit releases that predate ``st.cache_resource``.
if hasattr(st, "cache_resource"):
    _cache_models = st.cache_resource
else:
    _cache_models = st.cache(allow_output_mutation=True)


@_cache_models
def load_models():
    """Load the retrieval and generation models (cached across reruns).

    Returns:
        A ``(embedding_model, qa_model, tokenizer)`` tuple: the
        ``all-MiniLM-L6-v2`` SentenceTransformer used to embed documents and
        questions, the ``google/flan-t5-small`` seq2seq model used for
        generation, and the matching T5 tokenizer.
    """
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
    return embedding_model, qa_model, tokenizer


embedding_model, qa_model, tokenizer = load_models()
# Upload a CSV, index one of its text columns with FAISS, and answer questions
# over it with retrieval-augmented generation (RAG).

# Columns tried, in order, as the default document-text column.
DEFAULT_TEXT_COLUMNS = ("text", "Country Name")
# Number of documents retrieved as context for each question.
TOP_K = 3


def generate_summary(context, question=None, max_length=150, min_length=50):
    """Generate text from *context* with the Flan-T5 model.

    Called with *context* only, this summarises it using a ``"summarize: "``
    prompt, as before. When *question* is given, the model is instead asked
    to answer that question from *context*, which is what the RAG flow needs.

    Args:
        context: Text to summarise, or to answer the question from.
        question: Optional user question to answer using *context*.
        max_length: Maximum length of the generated sequence, in tokens.
        min_length: Minimum length of the generated sequence, in tokens.

    Returns:
        The decoded model output with special tokens removed.
    """
    if question:
        # Put the question first so that truncation to 512 tokens cuts the
        # retrieved context, never the question itself.
        prompt = (
            "Answer the question using the context.\n"
            f"Question: {question}\nContext: {context}"
        )
    else:
        prompt = "summarize: " + context
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    # **inputs passes the attention mask along with the input ids.
    # NOTE: length_penalty only affects beam search (num_beams > 1).
    outputs = qa_model.generate(
        **inputs,
        max_length=max_length,
        min_length=min_length,
        length_penalty=2.0,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


st.title("Economics & Population Advisor")
uploaded_file = st.file_uploader("Upload your CSV file with economic documents", type=["csv"])

if uploaded_file is None:
    # Everything below needs the uploaded corpus; end this run here instead of
    # failing later on an undefined index.
    st.info("Upload a CSV file to get started.")
    st.stop()

try:
    # on_bad_lines="skip" replaces error_bad_lines=False (removed in pandas 2.0).
    df = pd.read_csv(uploaded_file, on_bad_lines="skip", engine="python")
except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
    st.error(f"Could not read the CSV file: {err}")
    st.stop()
st.write("Dataset Preview:", df.head())

# Let the user choose which column holds the document text, defaulting to the
# first of DEFAULT_TEXT_COLUMNS present in the file (else the first column).
columns = list(df.columns)
default_column = next((c for c in DEFAULT_TEXT_COLUMNS if c in columns), columns[0])
text_column = st.selectbox(
    "Specify the text column:",
    columns,
    index=columns.index(default_column),
)
documents = df[text_column].dropna().astype(str).tolist()
if not documents:
    st.warning(f"Column '{text_column}' contains no text to index.")
    st.stop()

# Create embeddings for FAISS indexing.
# NOTE(review): Streamlit reruns this script on every interaction, so the
# corpus is re-embedded each time; consider caching for large CSVs.
st.write("Indexing documents...")
# FAISS expects a C-contiguous float32 matrix with one row per document.
embeddings = np.ascontiguousarray(embedding_model.encode(documents), dtype=np.float32)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
st.write("Indexing complete.")

# RAG: embed the question, retrieve the nearest documents, generate an answer.
st.subheader("Ask a Question about Economic Data")
question = st.text_input("Enter your question:")
if st.button("Get Answer") and question:
    question_embedding = np.ascontiguousarray(
        embedding_model.encode([question]), dtype=np.float32
    )
    # Never request more neighbours than indexed documents: FAISS pads missing
    # results with id -1, which would silently select documents[-1].
    k = min(TOP_K, index.ntotal)
    _, neighbor_ids = index.search(question_embedding, k)
    retrieved_docs = [documents[i] for i in neighbor_ids[0] if i >= 0]
    context = " ".join(retrieved_docs)
    # A direct answer should not be padded to the 50-token summary minimum.
    answer = generate_summary(context, question, min_length=1)
    st.write("Answer:", answer)