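# NOTE: "Spaces: Sleeping" page-header residue from the Hugging Face Spaces listing was captured here; it is not part of the program.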
import html
import os
from typing import Optional, Dict

import streamlit as st
import torch
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from dotenv import load_dotenv
from torch.nn.functional import softmax
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

from functionbloom import (
    save_uploaded_file,
    get_pdf_path,
    extract_text_pymupdf,
    get_bloom_taxonomy_scores,
    generate_ai_response,
    normalize_bloom_weights,
    generate_pdf,
    process_pdf_and_generate_questions,
    get_bloom_taxonomy_details,
)
from functionbloom import predict_with_loaded_model, process_document, sendtogemini
# Load variables from a local .env file into the process environment
# (main() reads GEMINI_API_KEY from there).
load_dotenv()

# Bloom's taxonomy level name -> class index, and the inverse lookup.
mapping = {"Remembering": 0, "Understanding": 1, "Applying": 2, "Analyzing": 3, "Evaluating": 4, "Creating": 5}
reverse_mapping = {v: k for k, v in mapping.items()}


# Streamlit re-executes this whole script on every widget interaction, so the
# heavy model loads are wrapped in st.cache_resource to run once per server
# process instead of once per rerun. show_spinner=False keeps the loaders from
# emitting any UI element before main() calls st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_classifier(model_dir='./fine_tuned_distilbert'):
    """Load the fine-tuned DistilBERT classifier and tokenizer from *model_dir*.

    Returns:
        A ``(model, tokenizer, device)`` tuple; the model has already been
        moved to ``device`` (CUDA when available, otherwise CPU).
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classifier = DistilBertForSequenceClassification.from_pretrained(model_dir)
    classifier.to(target_device)
    bert_tokenizer = DistilBertTokenizer.from_pretrained(model_dir)
    return classifier, bert_tokenizer, target_device


@st.cache_resource(show_spinner=False)
def _load_ocr_predictor():
    """Build the pretrained docTR OCR predictor (db_resnet50 + crnn_vgg16_bn)."""
    return ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)


# Module-level names kept for backward compatibility with the original script.
model, tokenizer, device = _load_classifier()
modelocr = _load_ocr_predictor()
# Bloom's taxonomy levels, in the order the scorer displays them.
_BLOOM_CATEGORIES = ('Remembering', 'Understanding', 'Applying', 'Analyzing', 'Evaluating', 'Creating')

# Global page styling. The original stylesheet declared .info-button, .modal,
# .modal-content and .close-button twice; the duplicates are merged here so
# each selector appears once with the same effective (cascaded) properties.
_PAGE_CSS = """
<style>
.stTabs [data-baseweb="tab"] {
margin-bottom: 1rem;
flex: 1;
justify-content: center;
}
.stTabs [data-baseweb="tab-list"] button [data-testid="stMarkdownContainer"] p {
font-size: 2rem;
padding: 0 2rem;
font-weight: bold;
margin: 0;
}
/* Question Container Styling */
.question-container {
display: flex;
align-items: start;
gap: 10px;
margin-bottom: 10px;
}
/* Info Button Styling */
.info-button {
background-color: #f0f2f6;
border: 1px solid #4a6cf7;
border-radius: 50%;
width: 24px;
height: 24px;
display: inline-flex;
align-items: center;
justify-content: center;
cursor: pointer;
margin-left: 8px;
font-weight: bold;
color: #4a6cf7;
flex-shrink: 0;
font-size: 14px;
}
.info-button:hover {
background-color: #4a6cf7;
color: white;
}
/* Modal Styling */
.modal {
display: none;
position: fixed;
z-index: 9999;
left: 0;
top: 0;
width: 100%;
height: 100%;
overflow: auto;
background-color: rgba(0,0,0,0.4);
}
.modal-content {
background-color: #fefefe;
margin: 15% auto;
padding: 20px;
border: 1px solid #888;
width: 80%;
max-width: 500px;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
position: relative;
}
.close-button {
position: absolute;
right: 10px;
top: 5px;
color: #aaa;
font-size: 28px;
font-weight: bold;
cursor: pointer;
}
.close-button:hover,
.close-button:focus {
color: black;
text-decoration: none;
cursor: pointer;
}
</style>
"""

# Styling used only by the "Paper Scorer" tab.
_SCORER_CSS = """
<style>
.upload-container {
background-color: #f0f2f6;
border-radius: 10px;
padding: 20px;
border: 2px dashed #4a6cf7;
text-align: center;
}
.score-breakdown {
background-color: #f8f9fa;
border-radius: 8px;
padding: 15px;
margin-bottom: 15px;
}
.score-header {
font-weight: bold;
color: #4a6cf7;
margin-bottom: 10px;
}
</style>
"""

_ABOUT_MARKDOWN = """
### About this Tool
- Generate academic paper questions using Bloom's Taxonomy
- Customize question generation weights
- Select and refine generated questions
- Support for PDF via URL or local upload
"""


def _init_session_state():
    """Seed every session-state key the app reads, without overwriting existing values."""
    # Built inside the function so each session gets its own fresh list/dict objects.
    defaults = {
        'totalscore': None,
        'show_details': False,
        'question_scores': {},
        'pdf_source_type': "URL",
        'pdf_url': "",
        'uploaded_file': None,
        'questions': [],
        'accepted_questions': [],
        'question_typed': "",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _score_color(score, orange_cutoff):
    """Map a score to a CSS color: green above 0.7, orange above *orange_cutoff*, else red."""
    if score > .7:
        return 'green'
    if score > orange_cutoff:
        return 'orange'
    return 'red'


def _generate_questions(api_key, bloom_taxonomy_weights, num_questions, question_length):
    """Validate the generator form inputs and store the generated questions in session state."""
    if not api_key:
        # The key is only ever read from the environment, so tell the user where to set it.
        st.error("Gemini API key not found. Set the GEMINI_API_KEY environment variable.")
        return
    if not st.session_state.pdf_url and not st.session_state.uploaded_file:
        st.error("Please enter a PDF URL or upload a PDF file.")
        return

    normalized_bloom_weights = normalize_bloom_weights(bloom_taxonomy_weights)
    st.info("Normalized Bloom's Taxonomy Weights:")
    st.json(normalized_bloom_weights)

    # Role and instructions for the AI
    role_description = "You are a question-generating AI agent, given context and instruction, you need to generate questions from the context."
    response_instructions = "Please generate questions that are clear and relevant to the content of the paper. Generate questions which are separated by new lines, without any numbering or additional context."

    with st.spinner('Generating questions...'):
        st.session_state.questions = process_pdf_and_generate_questions(
            pdf_source=st.session_state.pdf_url or None,
            uploaded_file=st.session_state.uploaded_file or None,
            api_key=api_key,
            role_description=role_description,
            response_instructions=response_instructions,
            bloom_taxonomy_weights=normalized_bloom_weights,
            num_questions=num_questions,
            question_length=question_length,
            include_numericals=st.session_state.include_numericals,
            user_input=st.session_state.user_input,
        )


def _render_question_review():
    """Show the generated questions with Accept/Discard controls and the PDF export."""
    st.header("Generated Questions")
    accepted = st.session_state.accepted_questions

    # A form batches the radio changes so the page does not rerun on every click.
    with st.form(key='questions_form'):
        for idx, question in enumerate(st.session_state.questions, 1):
            cols = st.columns([4, 1])
            with cols[0]:
                st.write(f"Q{idx}: {question}")
                with st.expander("Show Bloom's Taxonomy Details"):
                    taxonomy_details = get_bloom_taxonomy_details(st.session_state.question_scores.get(question))
                    st.text(taxonomy_details)
            with cols[1]:
                selected_option = st.radio(
                    f"Select an option for Q{idx}",
                    ["Accept", "Discard"],
                    key=f"radio_{idx}",
                    index=1,
                )
            # Keep the accepted list in sync with the current radio selection.
            if selected_option == "Accept":
                if question not in accepted:
                    accepted.append(question)
            elif question in accepted:
                accepted.remove(question)
        st.form_submit_button("Update Accepted Questions")

    if not accepted:
        st.info("No questions selected yet.")
        return

    st.header("Accepted Questions")
    for accepted_question in accepted:
        st.write(accepted_question)
    if st.button("Download Accepted Questions as PDF"):
        filename = generate_pdf(accepted, filename="accepted_questions.pdf")
        if filename:
            with open(filename, "rb") as pdf_file:
                st.download_button(
                    label="Click to Download PDF",
                    data=pdf_file,
                    file_name="accepted_questions.pdf",
                    mime="application/pdf",
                )
            st.success("PDF generated successfully!")


def _render_question_generator_tab():
    """Render the "Question Generator" tab: input form, generation, review, and footer."""
    # NOTE(review): the "π" below looks like a mis-decoded emoji; original glyph unknown, left unchanged.
    st.markdown("<h1 style='font-size: 28px;'>π Academic Paper Question Generator</h1>", unsafe_allow_html=True)
    st.markdown("Generate insightful questions from academic papers using Bloom's Taxonomy")

    api_key = os.getenv('GEMINI_API_KEY')

    with st.form(key='pdf_generation_form'):
        st.subheader("PDF Source")
        st.session_state.pdf_url = st.text_input(
            "Enter the URL of the PDF",
            value=st.session_state.pdf_url,
            key="pdf_url_input",
        )
        st.markdown("<h4 style='text-align: center;'>OR</h4>", unsafe_allow_html=True)
        st.session_state.uploaded_file = st.file_uploader(
            "Upload a PDF file",
            type=['pdf'],
            key="pdf_file_upload",
        )
        st.session_state.user_input = st.text_area("Enter your query here", key="input", height=100)

        question_length = st.select_slider(
            "Select Question Length",
            options=["Short", "Medium", "Long"],
            value="Medium",
            help="Short: 10-15 words, Medium: 20-25 words, Long: 30-40 words",
        )
        st.session_state.include_numericals = st.checkbox("Include Numericals", key="include_numericals_checkbox")

        st.subheader("Adjust Bloom's Taxonomy Weights")
        col1, col2, col3 = st.columns(3)
        with col1:
            knowledge = st.slider("Knowledge: Remembering", 0, 100, 20, key='knowledge_slider')
            application = st.slider("Applying: Using abstractions in concrete situations", 0, 100, 20, key='application_slider')
        with col2:
            comprehension = st.slider("Understanding: Explaining the meaning of information", 0, 100, 20, key='comprehension_slider')
            analysis = st.slider("Analyzing: Breaking down a whole into component parts", 0, 100, 20, key='analysis_slider')
        with col3:
            synthesis = st.slider("Creating: Putting parts together to form a new and integrated whole", 0, 100, 10, key='synthesis_slider')
            evaluation = st.slider("Evaluation: Making and defending judgments based on internal evidence or external criteria", 0, 100, 10, key='evaluation_slider')

        bloom_taxonomy_weights = {
            "Knowledge": knowledge,
            "Comprehension": comprehension,
            "Application": application,
            "Analysis": analysis,
            "Synthesis": synthesis,
            "Evaluation": evaluation,
        }
        num_questions = st.slider("How many questions would you like to generate?", min_value=1, max_value=20, value=5, key='num_questions_slider')
        submitted = st.form_submit_button(label='Generate Questions')

    if submitted:
        _generate_questions(api_key, bloom_taxonomy_weights, num_questions, question_length)

    if st.session_state.questions:
        _render_question_review()

    st.markdown("---")
    st.markdown(_ABOUT_MARKDOWN)


def _render_score_summary(results):
    """Render one tile per Bloom category showing the mean score across *results*."""
    question_count = len(results)
    for col, category in zip(st.columns(len(_BLOOM_CATEGORIES)), _BLOOM_CATEGORIES):
        total = sum(item['score'][category] for item in results)
        score = round(total / question_count, ndigits=3)
        color = _score_color(score, .4)
        with col:
            # The value is a per-question mean on the same 0-1 scale as the detail
            # view, so the denominator is 1 (the original printed the question count).
            st.markdown(
                '<div class="score-breakdown">'
                f'<div class="score-header" style="color: {color}">{category}</div>'
                f'<div style="font-size: 24px; color: {color};">{score}/1</div>'
                '</div>',
                unsafe_allow_html=True,
            )


def _render_detailed_scores(results):
    """Render the per-question score grid inside an expander."""
    with st.expander("Show Detailed Scores", expanded=True):
        for idx, item in enumerate(results, 1):
            # The question text comes from user/model input, so escape it before
            # embedding it in raw HTML.
            question_text = html.escape(str(item["question"]))
            st.markdown(f'<div class="score-header">Question {idx}: {question_text}</div>', unsafe_allow_html=True)

            for col, category in zip(st.columns(len(_BLOOM_CATEGORIES)), _BLOOM_CATEGORIES):
                score = round(item['score'][category], ndigits=3)
                # NOTE(review): the orange cutoff here (0.3) differs from the summary's 0.4; kept as in the original.
                color = _score_color(score, .3)
                with col:
                    st.markdown(
                        '<div style="text-align: center; background-color: #f1f1f1; '
                        'border-radius: 5px; padding: 5px; margin-bottom: 5px;">'
                        f'<div style="font-weight: bold; color: {color};">{category}</div>'
                        f'<div style="font-size: 18px; color: {color};">{score}/1</div>'
                        '</div>',
                        unsafe_allow_html=True,
                    )

            # Separator between questions, but not after the last one.
            if idx < len(results):
                st.markdown('---')


def _render_paper_scorer_tab():
    """Render the "Paper Scorer" tab: upload/paste form and the resulting score views."""
    # NOTE(review): the "π" below looks like a mis-decoded emoji; original glyph unknown, left unchanged.
    st.markdown("<h1 style='font-size: 28px;'>π Academic Paper Scorer</h1>", unsafe_allow_html=True)
    st.markdown("Evaluate the Quality of Your Academic Paper")
    st.markdown(_SCORER_CSS, unsafe_allow_html=True)

    with st.form(key='paper_scorer_form'):
        st.header("Upload Your Academic Paper")
        uploaded_file = st.file_uploader(
            "Choose a PDF file",
            type=['pdf', 'jpg', 'png', 'jpeg'],
            label_visibility="collapsed",
        )
        st.markdown("<div style='text-align: center; margin-top: 20px;'><strong>OR</strong></div>", unsafe_allow_html=True)
        # The widget is bound to session state through its key alone; passing
        # value= as well makes Streamlit warn about a conflicting default.
        st.text_area("Paste your question here", key="question_typed")
        submitted = st.form_submit_button(
            "Score Paper",
            use_container_width=True,
            type="primary",
        )

    if not submitted:
        return

    question_typed = st.session_state.question_typed
    if uploaded_file is None and not (question_typed or "").strip():
        st.error("Please upload a file or paste a question to score.")
        return

    pdf_path = save_uploaded_file(uploaded_file)
    results = sendtogemini(inputpath=pdf_path, question=question_typed)
    if not results:
        # Avoids dividing by zero (or iterating None) when nothing was scored.
        st.warning("No questions could be scored from the provided input.")
        return

    _render_score_summary(results)
    _render_detailed_scores(results)


def main():
    """Configure the page and render the Question Generator and Paper Scorer tabs."""
    # NOTE(review): page_icon "π" looks like a mis-decoded emoji; original glyph unknown, left unchanged.
    st.set_page_config(page_title="Academic Paper Tool", page_icon="π", layout="wide")
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)

    tab1, tab2 = st.tabs(["Question Generator", "Paper Scorer"])
    _init_session_state()

    with tab1:
        _render_question_generator_tab()
    with tab2:
        _render_paper_scorer_tab()
# Run the Streamlit app when this file is executed as a script.
if __name__ == "__main__":
    main()