# Study_Assistant / app.py
import streamlit as st
import PyPDF2
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import faiss
from gtts import gTTS
import os
# Initialize the model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
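model.eval()  # added: switch to inference mode, since the model is only used for embedding here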
# Function to get a sentence embedding by mean-pooling the model's token embeddings
def get_embedding(text):
    # Truncate long inputs to the model's maximum sequence length to avoid tokenizer/model errors
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    with torch.no_grad():
        embeddings = model(**inputs).last_hidden_state.mean(dim=1).numpy()
    return embeddings
# Initialize FAISS index
embeddings_dimension = 384 # for MiniLM
index = faiss.IndexFlatL2(embeddings_dimension)
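# IndexFlatL2 performs exact (brute-force) L2 distance search over every stored vector,
# which is fine for the small number of embeddings this app indexes.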
# Title of the app
st.title("Study Assistant for Grade 9")
# File uploader widget
uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
if uploaded_file is not None:
    # Read the uploaded PDF file
    pdf_reader = PyPDF2.PdfReader(uploaded_file)
    text = ""
    # Extract text from each page (extract_text() can return None for image-only pages)
    for page in pdf_reader.pages:
        page_text = page.extract_text()
        if page_text:
            text += page_text
    st.subheader("Extracted Text:")
    st.write(text)
    # Generate an embedding for the extracted text and add it to the FAISS index
    embeddings = get_embedding(text)
    index.add(embeddings)
    st.success("Text extracted and embeddings generated!")
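# Note: the whole document is embedded as a single vector, so the index holds one entry per upload;
# splitting the text into chunks before calling get_embedding() would make the top-5 search meaningful.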
# Subject selection and query input
subject = st.selectbox("Select Subject", ["Math", "Science", "English"])
query = st.text_input("Type your query")
if st.button("Submit"):
    if query:
        # Get the embedding for the query
        query_embedding = get_embedding(query)
        # Search for the nearest neighbors in the FAISS index
        D, I = index.search(query_embedding, k=5)  # Retrieve up to 5 matches
        st.subheader("Top Matches:")
        for rank, idx in enumerate(I[0]):
            if idx != -1:  # FAISS returns -1 when fewer than k vectors are indexed
                st.write(f"Match Index: {idx}, Distance: {D[0][rank]}")
        # Convert the response to speech
        response_text = f"You asked about '{query}' in {subject}. Here are your top matches."
        tts = gTTS(text=response_text, lang='en')
        tts.save("response.mp3")
        st.audio("response.mp3")  # Play the audio in the browser rather than via a shell command
        st.success("Response generated and spoken!")
# Note: To handle errors or improve this further, add appropriate try-except blocks and validations.
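# To run locally (a sketch, assuming these are the required packages):
#   pip install streamlit PyPDF2 transformers torch faiss-cpu gTTS
#   streamlit run app.py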