Izza-shahzad-13 committed on
Commit
1122e93
·
verified ·
1 Parent(s): 4161b83

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import faiss
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
7
+
8
+ # Load the Sentence Transformer and T5 model
9
# ``st.cache`` is the deprecated legacy caching API; ``st.cache_resource`` is
# its replacement for unserializable global resources such as ML models.
# Fall back to the legacy decorator only on Streamlit versions without it.
_cache_models = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_models
def load_models():
    """Load the embedding model, the FLAN-T5 generator and its tokenizer.

    The result is cached by Streamlit so the models are not reloaded on
    every script rerun.

    Returns:
        tuple: ``(embedding_model, qa_model, tokenizer)`` — a
        ``SentenceTransformer`` ('all-MiniLM-L6-v2'), a
        ``T5ForConditionalGeneration`` and a ``T5Tokenizer`` (both
        "google/flan-t5-small").
    """
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
    return embedding_model, qa_model, tokenizer
15
+
16
embedding_model, qa_model, tokenizer = load_models()


def generate_summary(context):
    """Summarize ``context`` with the loaded FLAN-T5 model.

    Args:
        context: Text to summarize. The tokenized prompt is truncated to
            512 tokens.

    Returns:
        str: The generated text, decoded with special tokens removed.
    """
    # NOTE(review): the user's question is never part of the prompt, so the
    # displayed "answer" is a summary of the retrieved documents — confirm
    # whether a question-conditioned prompt was intended.
    inputs = tokenizer(
        "summarize: " + context,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    outputs = qa_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=150,
        min_length=50,
        length_penalty=2.0,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Upload and load the CSV file
st.title("Economics & Population Advisor")
uploaded_file = st.file_uploader("Upload your CSV file with economic documents", type=["csv"])

if uploaded_file is not None:
    # ``error_bad_lines`` was removed in pandas 2.0; ``on_bad_lines="skip"``
    # is the supported way to drop malformed rows.
    df = pd.read_csv(uploaded_file, on_bad_lines="skip", engine="python")
    st.write("Dataset Preview:", df.head())

    # Use the 'text' column when present; otherwise ask which column holds
    # the document text and wait until a valid name has been entered.
    if "text" in df.columns:
        text_column = "text"
    else:
        text_column = st.text_input("Specify the text column name:")
    if text_column not in df.columns:
        st.warning("Enter the name of the column that contains the document text.")
        st.stop()

    # Drop missing cells and force strings so the encoder only sees text.
    documents = df[text_column].dropna().astype(str).tolist()
    if not documents:
        st.error("The selected column contains no documents to index.")
        st.stop()

    # Create embeddings for FAISS indexing (FAISS expects float32 vectors)
    st.write("Indexing documents...")
    embeddings = np.asarray(embedding_model.encode(documents), dtype="float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    st.write("Indexing complete.")

    # RAG functionality: ask a question, retrieve documents, generate an answer
    st.subheader("Ask a Question about Economic Data")
    question = st.text_input("Enter your question:")

    if st.button("Get Answer") and question:
        question_embedding = np.asarray(
            embedding_model.encode([question]), dtype="float32"
        )
        # Never request more neighbours than there are documents: FAISS pads
        # missing results with -1, which would silently index the last item.
        top_k = min(3, len(documents))
        _, neighbor_ids = index.search(question_embedding, k=top_k)
        retrieved_docs = [documents[i] for i in neighbor_ids[0] if i >= 0]
        context = " ".join(retrieved_docs)
        answer = generate_summary(context)

        st.write("Answer:", answer)