|
import streamlit as st |
|
import pandas as pd |
|
import faiss |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
|
|
|
|
@st.cache_resource
def load_models():
    """Load and cache the heavyweight ML models once per server process.

    Uses ``st.cache_resource``, the documented successor to the removed
    ``st.cache(allow_output_mutation=True)`` API, so the models are created
    a single time and shared across Streamlit reruns/sessions.

    Returns:
        tuple: (SentenceTransformer embedder,
                T5ForConditionalGeneration QA/summarization model,
                T5Tokenizer for that model).
    """
    # Small, CPU-friendly checkpoints; downloaded from the hub on first run.
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
    return embedding_model, qa_model, tokenizer
|
|
|
embedding_model, qa_model, tokenizer = load_models()


st.title("Economics & Population Advisor")
uploaded_file = st.file_uploader("Upload your CSV file with economic documents", type=["csv"])

if uploaded_file is not None:
    # Tolerate malformed rows in user-supplied CSVs instead of aborting.
    df = pd.read_csv(uploaded_file, on_bad_lines='skip', engine='python')
    st.write("Dataset Preview:", df.head())

    text_column = st.text_input("Specify the column containing the document text:", value="Country Name")

    if text_column not in df.columns:
        st.error(f"The column '{text_column}' was not found in the dataset.")
    else:
        # Drop missing values and coerce to str: a numeric or NaN-bearing
        # column would otherwise crash SentenceTransformer.encode, which
        # expects a list of strings.
        documents = df[text_column].dropna().astype(str).tolist()

        if not documents:
            st.error(f"The column '{text_column}' contains no usable text.")
        else:
            st.write("Indexing documents...")
            embeddings = embedding_model.encode(documents, convert_to_numpy=True)
            dimension = embeddings.shape[1]

            # Exact (brute-force) L2 index — fine for the modest corpus
            # sizes of a single CSV upload.
            index = faiss.IndexFlatL2(dimension)
            index.add(np.asarray(embeddings, dtype=np.float32))
            st.write("Indexing complete.")

            def generate_summary(context):
                """Summarize `context` with FLAN-T5 (input truncated to 512 tokens)."""
                inputs = tokenizer("summarize: " + context, return_tensors="pt", max_length=512, truncation=True)
                outputs = qa_model.generate(inputs["input_ids"], max_length=150, min_length=50, length_penalty=2.0)
                return tokenizer.decode(outputs[0], skip_special_tokens=True)

            st.subheader("Ask a Question about Economic Data")
            question = st.text_input("Enter your question:")

            if st.button("Get Answer") and question:
                question_embedding = embedding_model.encode([question], convert_to_numpy=True)

                # Never request more neighbors than documents exist: FAISS
                # pads missing result slots with index -1, which would
                # silently resolve to documents[-1] (the wrong document).
                top_k = min(3, len(documents))
                D, I = index.search(np.asarray(question_embedding, dtype=np.float32), k=top_k)
                retrieved_docs = [documents[i] for i in I[0] if i >= 0]

                # Cap the prompt length to keep generation fast and well
                # inside the model's 512-token input limit.
                context = " ".join(retrieved_docs)
                if len(context) > 1000:
                    context = context[:1000]

                answer = generate_summary(context)
                st.write("Answer:", answer)
|
|