# Scraped page header (Hugging Face "Spaces: Sleeping" status banner); not part of the program.
import streamlit as st | |
import os | |
import json | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel, BertTokenizer, BertModel,T5Tokenizer, T5ForConditionalGeneration,AutoTokenizer, AutoModelForSeq2SeqLM | |
import torch | |
from sklearn.metrics.pairwise import cosine_similarity | |
import numpy as np | |
import nltk | |
from nltk.tokenize import sent_tokenize | |
from nltk.corpus import stopwords | |
def is_new_file_upload(uploaded_file):
    """Tell whether ``uploaded_file`` differs from the last upload seen this session.

    Uploads are identified only by their ``(name, size)`` pair, which is kept in
    ``st.session_state.last_uploaded_file``. That record is (re)written whenever
    a new upload is detected and left alone otherwise.

    Returns:
        True for the first upload of the session or when the name or size
        changed; False when the same name and size were already recorded.
    """
    state = st.session_state
    if 'last_uploaded_file' in state:
        previous = state.last_uploaded_file
        unchanged = (uploaded_file.name == previous['name'] and
                     uploaded_file.size == previous['size'])
        if unchanged:
            # Same file re-uploaded (or a plain rerun): nothing new to process.
            return False
    # First upload, or a different file: remember it and report it as new.
    state.last_uploaded_file = {'name': uploaded_file.name, 'size': uploaded_file.size}
    return True
def combined_similarity(similarity, sentence, query, stop_words=None):
    """Boost an embedding similarity score by keyword overlap with the query.

    Both texts are split on whitespace. Each word has its surrounding
    punctuation stripped and is lower-cased, and stop words are discarded.
    The fraction of the query's remaining words that also appear in the
    sentence (a value in [0, 1]) is added to ``similarity``. For a cosine
    similarity in [-1, 1] the result therefore lies in [-1, 2].

    Args:
        similarity: Base similarity score (e.g. cosine similarity of embeddings).
        sentence: Candidate sentence text.
        query: User query text.
        stop_words: Lower-case words to ignore. Defaults to
            ``st.session_state.stop_words`` when not given.

    Returns:
        ``similarity`` plus the normalised word-overlap bonus.
    """
    import string  # local import keeps this helper self-contained

    if stop_words is None:
        stop_words = st.session_state.stop_words

    def _content_words(text):
        # Normalise case and edge punctuation so that "Death" in a sentence
        # matches "death?" in the query; drop stop words and empty tokens.
        words = (word.strip(string.punctuation).lower() for word in text.split())
        return {word for word in words if word and word not in stop_words}

    sentence_words = _content_words(sentence)
    query_words = _content_words(query)
    common_words = len(sentence_words & query_words)
    # max(..., 1) avoids dividing by zero when the query is empty or all stop words.
    return similarity + common_words / max(len(query_words), 1)
# Centred page title rendered as raw HTML.
big_text = """
<div style='text-align: center;'>
<h1 style='font-size: 30px;'>Knowledge Extraction A</h1>
</div>
"""
# Display the styled text
st.markdown(big_text, unsafe_allow_html=True)

# The uploaded file is later parsed as JSON and, when it is a list, each element's
# 'text' field is sentence-tokenised and embedded.
uploaded_json_file = st.file_uploader("Upload a pre-processed file",
                                      type=['json'])
# Link to a sample file the user can download and then upload above.
st.markdown(
    '<a href="https://ikmtechnology.github.io/ikmtechnology/untethered_extracted_paragraphs.json" target="_blank">Sample 1 download and then upload to above</a>',
    unsafe_allow_html=True)
st.markdown("sample queries for above file: <br/> What is death? What is a lucid dream? What is the seat of consciousness?", unsafe_allow_html=True)
if uploaded_json_file is not None:
    if is_new_file_upload(uploaded_json_file):
        print("is new file uploaded")
        # Anything derived from a previously uploaded file is stale now; drop it so
        # the embedding step runs again for this file instead of reusing old results.
        for stale_key in ('paragraph_sentence_encodings', 'list_count', 'restored_paragraphs'):
            if stale_key in st.session_state:
                del st.session_state[stale_key]

        save_path = './uploaded_files'
        os.makedirs(save_path, exist_ok=True)
        # The file name comes from the client: keep only its last path component so
        # a crafted name cannot write outside save_path.
        safe_name = os.path.basename(uploaded_json_file.name) or 'uploaded.json'
        st.session_state.uploaded_path = os.path.join(save_path, safe_name)
        content = uploaded_json_file.getvalue()  # whole payload, independent of read position
        with open(st.session_state.uploaded_path, "wb") as f:
            f.write(content)  # Write the file to the specified location
        st.success(f'Saved file {safe_name} in {save_path}')

        try:
            restored = json.loads(content)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # json.loads on bytes raises UnicodeDecodeError for non-UTF encodings.
            st.write('Invalid JSON file.')
        else:
            # Downstream code iterates the top level as a list of paragraph records.
            if isinstance(restored, list):
                st.session_state.restored_paragraphs = restored
                st.session_state.list_count = len(restored)
                st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
            else:
                st.write('The JSON content is not a list.')
        st.rerun()
if 'is_initialized' not in st.session_state:
    # One-time per-session setup: NLTK data, the English stop-word set and the
    # BERT tokenizer/encoder used for sentence and query embeddings.
    nltk.download('punkt')
    nltk.download('punkt_tab')  # newer NLTK releases load this for sent_tokenize
    nltk.download('stopwords')
    st.session_state.stop_words = set(stopwords.words('english'))
    # Fall back to CPU when no GPU is present instead of crashing on .to('cuda').
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    st.session_state.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)
    # Set the flag only after everything above succeeded, so a failed download or
    # model load is retried on the next rerun rather than silently skipped.
    st.session_state['is_initialized'] = True
if 'list_count' in st.session_state:
    st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
    if 'paragraph_sentence_encodings' not in st.session_state:
        print("start embedding paragarphs")
        tokenizer = st.session_state.bert_tokenizer
        model = st.session_state.bert_model
        device = model.device  # follow wherever the model was loaded (GPU or CPU)
        paragraphs = st.session_state.restored_paragraphs
        total = len(paragraphs)
        read_progress_bar = st.progress(0)

        # Build into a local list and publish it only when complete, so an error
        # mid-way does not leave a truncated index that later reruns would reuse.
        encodings = []
        for index, paragraph in enumerate(paragraphs):
            sentence_encodings = []
            for sentence in sent_tokenize(paragraph['text']):
                stripped = sentence.strip()
                # Questions and very short fragments are recorded as None and are
                # skipped when scoring queries.
                if stripped.endswith('?') or len(stripped) < 4:
                    sentence_encodings.append(None)
                    continue
                sentence_tokens = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to(device)
                with torch.no_grad():
                    # First-token ([CLS]) hidden state as the sentence embedding.
                    sentence_encoding = model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
                sentence_encodings.append([sentence, sentence_encoding])
            encodings.append([paragraph, sentence_encodings])
            # (index + 1) / total is in (0, 1] and never divides by zero, unlike
            # index / (total - 1) for a single-paragraph file.
            read_progress_bar.progress((index + 1) / total)
        st.session_state.paragraph_sentence_encodings = encodings
        st.rerun()
if 'paragraph_sentence_encodings' in st.session_state:
    query = st.text_input("Enter your query")
    if query:
        model = st.session_state.bert_model
        # Send the query to the same device as the model (GPU or CPU).
        query_tokens = st.session_state.bert_tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(model.device)
        with torch.no_grad():  # Disable gradient calculation for inference
            # First-token ([CLS]) hidden state, moved to CPU as a NumPy array.
            query_encoding = model(**query_tokens).last_hidden_state[:, 0, :].cpu().numpy()

        paragraph_scores = []
        sentence_scores = []
        total_count = len(st.session_state.paragraph_sentence_encodings)
        processing_progress_bar = st.progress(0)
        for index, (paragraph, sentence_encodings) in enumerate(st.session_state.paragraph_sentence_encodings):
            sentence_similarities = []
            for sentence_encoding in sentence_encodings:
                if sentence_encoding is None:  # questions / short fragments were not embedded
                    continue
                sentence, encoding = sentence_encoding
                similarity = cosine_similarity(query_encoding, encoding)[0][0]
                combined_score = combined_similarity(similarity, sentence, query)
                sentence_similarities.append(combined_score)
                sentence_scores.append((combined_score, sentence))
            # Paragraph score: mean of its (up to) three best sentence scores,
            # or 0 when none of its sentences were embedded.
            top_three = sorted(sentence_similarities, reverse=True)[:3]
            top_three_avg_similarity = np.mean(top_three) if top_three else 0
            paragraph_scores.append((top_three_avg_similarity, paragraph))
            # (index + 1) / total_count avoids dividing by zero for a single paragraph.
            processing_progress_bar.progress((index + 1) / total_count)

        sentence_scores.sort(key=lambda x: x[0], reverse=True)
        # Sort the paragraphs by their score, best first.
        paragraph_scores.sort(key=lambda x: x[0], reverse=True)
        st.write("Top scored paragraphs and their scores:")
        for score, paragraph in paragraph_scores[:5]:  # show the top 5
            st.write(f"Score: {score}, Paragraph: {paragraph['text']}")