Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import fitz | |
import os | |
@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer():
    """Load the fine-tuned citation classifier and its tokenizer once per server process.

    Streamlit re-executes this script from the top on every widget interaction;
    without caching, both checkpoints would be re-loaded on each click.
    ``show_spinner=False`` keeps this call from rendering any element, because it
    runs before ``st.set_page_config``.

    NOTE(review): the tokenizer comes from the base Longformer checkpoint, not from
    the fine-tuned repo. Confirm it matches the tokenizer used during fine-tuning.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained("Reem333/Citaion-Classifier")
    loaded_tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    return loaded_model, loaded_tokenizer


# Module-level names kept so the rest of the script can use them unchanged.
model, tokenizer = _load_model_and_tokenizer()
def extract_text_from_pdf(file_path):
    """Return the text of every page of the PDF at *file_path*, concatenated in page order.

    Args:
        file_path: Path to a PDF file readable by PyMuPDF (``fitz``).

    Returns:
        The pages' ``get_text()`` output joined with no separator; ``""`` for a PDF
        with no extractable text.
    """
    with fitz.open(file_path) as pdf_document:
        # Iterating a Document yields its pages in order. Joining once avoids the
        # quadratic cost of repeated string `+=`.
        return "".join(page.get_text() for page in pdf_document)
def predict_class(text, max_length=4096):
    """Classify *text* with the citation model and return the predicted class index.

    Args:
        text: The paper text to classify, as a single string.
        max_length: Maximum number of *tokens* passed to the model. The default 4096
            matches the window implied by the ``longformer-base-4096`` tokenizer name.

    Returns:
        The index of the highest logit (``0`` -> "Level 1", ...), or ``None`` if
        tokenization or inference raised. In that case the error is shown in the UI.
    """
    try:
        # Truncate by tokens, not characters. The old `text[:4096]` character slice
        # kept only a fraction of the token window and threw away most of the paper.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        with torch.no_grad():
            logits = model(**inputs).logits
        return torch.argmax(logits, dim=-1).item()
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        st.error(f"Error during prediction: {e}")
        return None
# Directory where uploaded PDFs are saved before text extraction; created up front.
uploaded_files_dir = "uploaded_files"
os.makedirs(uploaded_files_dir, exist_ok=True)

# Highlight colour for each predicted class index: 0 -> Level 1, 1 -> Level 2,
# 2 -> Level 3, 3 -> Level 4.
class_colors = dict(enumerate(("#d62728", "#ff7f0e", "#2ca02c", "#1f77b4")))
# Page metadata; Streamlit expects this before other page elements are rendered.
st.set_page_config(page_title="Paper Citation Classifier", page_icon="logo.png")

# Sidebar: logo, app description, and the meaning of each class level.
with st.sidebar:
    st.image("logo.png", width=70)
    st.markdown('<div style="position: absolute; left: 5px;"></div>', unsafe_allow_html=True)
    for header_line in ("# Paper Citation Classifier", "---", "## About"):
        st.markdown(header_line)
    st.markdown("""
This is a tool to classify paper citations into different levels based on their number of citations.
Powered by Fine-Tuned [Longformer model](https://huggingface.co/REEM-ALRASHIDI/LongFormer-Paper-Citaion-Classifier) with custom data.
""")
    level_lines = (
        "### Class Levels:",
        "- Level 1: Highly cited papers",
        "- Level 2: Average cited papers",
        "- Level 3: More cited papers",
        "- Level 4: Low cited papers",
        "---",
        'Tabuk University',
    )
    for line in level_lines:
        st.markdown(line)
def _render_prediction(predicted_class):
    """Show all class levels, highlighting *predicted_class* as a coloured badge.

    Args:
        predicted_class: Integer index returned by ``predict_class``. It indexes the
            label list below and the module-level ``class_colors`` mapping.
    """
    class_labels = ["Level 1", "Level 2", "Level 3", "Level 4"]
    # st.markdown (not st.text) so the bold markup renders instead of printing "**".
    st.markdown("**Predicted Class:**")
    for i, label in enumerate(class_labels):
        if i == predicted_class:
            # Label and colour both come from fixed in-file tables, so this HTML
            # contains no user input.
            st.markdown(
                f'<div style="background-color: {class_colors[predicted_class]}; padding: 10px; border-radius: 5px; color: white; font-weight: bold;">{label}</div>',
                unsafe_allow_html=True,
            )
        else:
            st.text(label)


st.title("Check Your Paper Now!")
option = st.radio("Select input type:", ("Text", "PDF"))
if option == "Text":
    title_input = st.text_area("Enter Title:")
    abstract_input = st.text_area("Enter Abstract:")
    full_text_input = st.text_area("Enter Full Text:")
    affiliations_input = st.text_area("Enter Affiliations:")
    keywords_input = st.text_area("Enter Keywords:")
    options = ["Nursing", "Physics", "Maths", "Chemical", "Nuclear", "Engineering", "Other"]
    # index=None leaves the box unselected, in which case selectbox returns None.
    selected_category = st.selectbox("Select WoS categories:", options, index=None)
    if selected_category == "Other":
        custom_category = st.text_input("Enter custom category:")
        selected_category = custom_category if custom_category else "Other"
    # Use "" rather than letting the literal string "None" reach the model.
    category_text = selected_category or ""
    # Fields are separated by a literal " [SEP] " marker. The full text is a single
    # field; it was previously split into characters, each joined with " [SEP] ".
    combined_text = " [SEP] ".join(
        [
            title_input,
            keywords_input,
            abstract_input,
            category_text,
            affiliations_input,
            full_text_input,
        ]
    )
    if st.button("Predict"):
        paper_fields = [title_input, abstract_input, keywords_input, full_text_input, affiliations_input]
        # Whitespace-only input counts as empty.
        if not any(field.strip() for field in paper_fields):
            st.warning("Please enter paper text.")
        else:
            with st.spinner("Predicting..."):
                predicted_class = predict_class(combined_text)
            if predicted_class is not None:
                _render_prediction(predicted_class)
elif option == "PDF":
    uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
    if uploaded_file is not None:
        with st.spinner("Processing PDF..."):
            # The file name is supplied by the client. Keep only its last path
            # component so it cannot escape uploaded_files_dir (e.g. "../x.pdf").
            safe_name = os.path.basename(uploaded_file.name)
            file_path = os.path.join(uploaded_files_dir, safe_name)
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            st.success("File uploaded successfully.")
            st.text(f"File Path: {file_path}")
            file_text = extract_text_from_pdf(file_path)
            st.text("Extracted Text:")
            st.text(file_text)
        if st.button("Predict from PDF Text"):
            if not file_text.strip():
                st.warning("Please upload a PDF with text content.")
            else:
                with st.spinner("Predicting..."):
                    predicted_class = predict_class(file_text)
                if predicted_class is not None:
                    _render_prediction(predicted_class)