Study_Assistant / app.py
Khd-B's picture
Update app.py
9f1a422 verified
raw
history blame
2.4 kB
import streamlit as st
import PyPDF2
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import faiss
from gtts import gTTS
import os
# Initialize the model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2"


@st.cache_resource
def _load_encoder(name):
    """Load the tokenizer and model for *name* once per server process.

    Streamlit re-executes this script on every widget interaction; caching
    the loaded objects avoids re-reading the weights on each rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return AutoTokenizer.from_pretrained(name), AutoModel.from_pretrained(name)


tokenizer, model = _load_encoder(model_name)
# Function to get embeddings
def get_embedding(text):
    """Return mean-pooled token embeddings for *text* as a NumPy array.

    Args:
        text: A string (or a list of strings) to embed with the module-level
            ``tokenizer`` and ``model``.

    Returns:
        A 2-D NumPy array with one row per input text, suitable for passing
        to a FAISS index.
    """
    # Truncate to the tokenizer's configured maximum length: a whole PDF's
    # text would otherwise exceed what the model accepts and fail in the
    # forward pass. Padding only has an effect when a batch is passed.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        token_states = model(**inputs).last_hidden_state
    # Average over real tokens only; for a single unpadded input the mask is
    # all ones, so this equals a plain mean over the sequence dimension.
    mask = inputs['attention_mask'].unsqueeze(-1).to(token_states.dtype)
    pooled = (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    return pooled.numpy()
# Initialize FAISS index.
# NOTE: Streamlit reruns this script on every interaction, so the index is
# rebuilt from the currently uploaded file on each run.
embeddings_dimension = 384  # for MiniLM
index = faiss.IndexFlatL2(embeddings_dimension)

# Title of the app
st.title("Study Assistant for Grade 9")

# File uploader widget
uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])

if uploaded_file is not None:
    # Read the uploaded PDF file
    pdf_reader = PyPDF2.PdfReader(uploaded_file)

    # Extract text from each page. extract_text() is called once per page;
    # pages that yield nothing contribute an empty string.
    text = "".join(page.extract_text() or "" for page in pdf_reader.pages)

    st.subheader("Extracted Text:")
    st.write(text)

    if text.strip():
        # Generate embedding for the extracted text and store it in the index
        embeddings = get_embedding(text)
        index.add(embeddings)
        st.success("Text extracted and embeddings generated!")
    else:
        # Nothing to embed (e.g. a PDF with no extractable text).
        st.warning("No extractable text was found in this PDF.")

# Subject selection and query input
subject = st.selectbox("Select Subject", ["Math", "Science", "English"])
query = st.text_input("Type your query")

if st.button("Submit"):
    if query:
        if index.ntotal == 0:
            # Without this guard a query submitted before any upload would
            # search an empty index and then fail.
            st.warning("Please upload a PDF with extractable text before submitting a query.")
        else:
            # Get embedding for the query
            query_embedding = get_embedding(query)

            # Never ask for more neighbours than the index holds; FAISS pads
            # missing results with label -1.
            top_k = min(5, index.ntotal)
            D, I = index.search(query_embedding, top_k)

            st.subheader("Top Matches:")
            # D and I are aligned by rank, so pair them positionally rather
            # than indexing distances by the returned vector id.
            for idx, distance in zip(I[0], D[0]):
                if idx < 0:  # padding label, not a real match
                    continue
                st.write(f"Match Index: {idx}, Distance: {distance}")

            # Convert response to speech
            response_text = f"You asked about '{query}' in {subject}. Here are your top matches."
            tts = gTTS(text=response_text, lang='en')
            # NOTE(review): this fixed filename is shared by every session
            # served from the same working directory.
            tts.save("response.mp3")

            # Render an audio player for the generated file
            st.audio("response.mp3")
            st.success("Response generated!")