|
import pandas as pd |
|
import re |
|
import torch |
|
import threading |
|
from transformers import BertTokenizerFast, DistilBertTokenizer, DistilBertForSequenceClassification |
|
from tqdm import tqdm |
|
import tkinter as tk |
|
from tkinter import filedialog, messagebox, scrolledtext, ttk |
|
from tkinter.font import Font |
|
|
|
|
|
# Torch device used for model inference: Apple's MPS backend when this torch
# build reports it available, otherwise the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
|
|
# Abbreviations whose trailing period must not be treated as a sentence pause,
# paired with the placeholder token that protects them while splitting.
# Every placeholder must be unique so the substitution can be reversed exactly
# ("M.S." used to share "<MS>" with "Ms.", so "Ms." was restored as "M.S.").
_ABBREVIATION_PLACEHOLDERS = (
    ("Mr.", "<MR>"), ("Ms.", "<MS>"), ("Mrs.", "<MRS>"), ("Dr.", "<DR>"),
    ("Prof.", "<PROF>"), ("Rev.", "<REV>"), ("Gen.", "<GEN>"), ("Sen.", "<SEN>"),
    ("Rep.", "<REP>"), ("Gov.", "<GOV>"), ("Lt.", "<LT>"), ("Sgt.", "<SGT>"),
    ("Capt.", "<CAPT>"), ("Cmdr.", "<CMDR>"), ("Adm.", "<ADM>"), ("Maj.", "<MAJ>"),
    ("Col.", "<COL>"), ("St.", "<ST>"), ("Co.", "<CO>"), ("Inc.", "<INC>"),
    ("Corp.", "<CORP>"), ("Ltd.", "<LTD>"), ("Jr.", "<JR>"), ("Sr.", "<SR>"),
    ("Ph.D.", "<PHD>"), ("M.D.", "<MD>"), ("B.A.", "<BA>"), ("B.S.", "<BS>"),
    ("M.A.", "<MA>"), ("M.S.", "<MSC>"), ("LL.B.", "<LLB>"), ("LL.M.", "<LLM>"),
    ("J.D.", "<JD>"), ("Esq.", "<ESQ>"),
)

# Compiled once. The leading \b stops matches in the middle of a word, e.g. the
# "Ms." at the end of "ITEMs.", whose period really does end a sentence.
_ABBREVIATION_PATTERNS = tuple(
    (re.compile(r"\b" + re.escape(abbreviation)), placeholder)
    for abbreviation, placeholder in _ABBREVIATION_PLACEHOLDERS
)

# A maximal run of characters containing no pause punctuation.
_PAUSE_SEGMENT = re.compile(r"[^.!,;?:]+")

# Punctuation re-attached to the fragment it terminates.
# NOTE(review): ':' is a split point but is not re-attached (unchanged from the
# original behaviour); confirm whether that is intentional.
_REATTACHED_PUNCTUATION = ".!,;?"


def replace_titles_and_abbreviations(text):
    """Replace known titles/abbreviations with period-free placeholder tokens.

    Args:
        text: Input string.

    Returns:
        ``text`` with each abbreviation in ``_ABBREVIATION_PLACEHOLDERS`` that
        starts at a word boundary replaced by its placeholder (e.g. "Dr." ->
        "<DR>"), so its period is not mistaken for a pause.
        ``revert_titles_and_abbreviations`` undoes this.
    """
    for pattern, placeholder in _ABBREVIATION_PATTERNS:
        text = pattern.sub(placeholder, text)
    return text


def revert_titles_and_abbreviations(text):
    """Restore the abbreviations hidden by ``replace_titles_and_abbreviations``.

    Args:
        text: String possibly containing placeholder tokens such as "<DR>".

    Returns:
        ``text`` with every placeholder replaced by its original spelling.
    """
    for abbreviation, placeholder in _ABBREVIATION_PLACEHOLDERS:
        # Plain substring replacement: the placeholders are literal tokens, not regexes.
        text = text.replace(placeholder, abbreviation)
    return text


def split_text_by_pauses(text):
    """Split text into fragments at pause punctuation (``. ! , ; ? :``).

    Known titles/abbreviations (see ``_ABBREVIATION_PLACEHOLDERS``) are
    protected so their periods do not cause a split. Each fragment is stripped
    of surrounding whitespace. When the character immediately after the
    stripped fragment is one of ``. ! , ; ?``, that character is appended.

    Args:
        text: Input string.

    Returns:
        List of non-empty fragments in document order.
    """
    protected = replace_titles_and_abbreviations(text)
    fragments = []
    for match in _PAUSE_SEGMENT.finditer(protected):
        raw = match.group()
        fragment = raw.strip()
        if not fragment:
            continue
        # Look at the character right after *this* fragment's own position.
        # The previous text.find(fragment) always used the first occurrence,
        # so repeated fragments received the wrong punctuation.
        end = match.start() + (len(raw) - len(raw.lstrip())) + len(fragment)
        if end < len(protected) and protected[end] in _REATTACHED_PUNCTUATION:
            fragment += protected[end]
        fragments.append(revert_titles_and_abbreviations(fragment))
    return fragments
|
|
|
def Process_txt_into_BERT_quotes_input_dataframe(filepath, context_tokens=200):
    """Build the quote-classifier input table for a UTF-8 text file.

    The file is split into pause-delimited sentences with
    ``split_text_by_pauses``. Each sentence is paired with a context window of
    up to ``context_tokens`` BERT (``bert-base-uncased``) tokens on either side.

    Args:
        filepath: Path of the text file to read (decoded as UTF-8).
        context_tokens: Number of tokens of context to keep before and after
            each sentence. Defaults to 200, the value previously hard-coded.

    Returns:
        A DataFrame with one row per sentence and the columns ``Text``,
        ``Context``, ``Text start char``, ``Text end char``,
        ``Context start char``, ``Context end char``, ``Is Quote`` and
        ``Speaker``. The last two are initialised to empty strings.
        Character positions index into the file's text; end positions are
        exclusive. A sentence is omitted when no token starts exactly at its
        first character or no token ends exactly at its last character.
    """
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

    with open(filepath, 'r', encoding='utf-8') as file:
        text = file.read()

    sentences = split_text_by_pauses(text)

    data = {
        'Text': [],
        'Context': [],
        'Text start char': [],
        'Text end char': [],
        'Context start char': [],
        'Context end char': [],
        'Is Quote': [],
        'Speaker': [],
    }

    # (start, end) character span of every token in `text`, without special tokens.
    encoded_text = tokenizer.encode_plus(text, add_special_tokens=False, return_offsets_mapping=True)
    offsets = encoded_text['offset_mapping']

    # First token starting / ending at each character position. Built once, so
    # each sentence lookup is O(1) instead of a scan over every token.
    first_token_starting_at = {}
    first_token_ending_at = {}
    for token_idx, (char_start, char_end) in enumerate(offsets):
        first_token_starting_at.setdefault(char_start, token_idx)
        first_token_ending_at.setdefault(char_end, token_idx)

    # split_text_by_pauses yields sentences in document order, so search forward
    # from the end of the previous match. Searching from 0 would place every
    # repeated sentence at its first occurrence.
    search_from = 0
    for sentence in sentences:
        start_idx = text.find(sentence, search_from)
        if start_idx == -1:
            continue
        end_idx = start_idx + len(sentence)
        search_from = end_idx

        start_token_idx = first_token_starting_at.get(start_idx)
        end_token_idx = first_token_ending_at.get(end_idx)
        if start_token_idx is None or end_token_idx is None:
            # No token boundary at the sentence's edge; skip it.
            continue

        context_start_token_idx = max(0, start_token_idx - context_tokens)
        context_end_token_idx = min(len(offsets) - 1, end_token_idx + context_tokens)

        context_start_char = offsets[context_start_token_idx][0]
        context_end_char = offsets[context_end_token_idx][1]

        data['Text'].append(sentence)
        data['Context'].append(text[context_start_char:context_end_char])
        data['Text start char'].append(start_idx)
        data['Text end char'].append(end_idx)
        data['Context start char'].append(context_start_char)
        data['Context end char'].append(context_end_char)
        data['Is Quote'].append('')
        data['Speaker'].append('')

    return pd.DataFrame(data)
|
|
|
# (model, tokenizer) pairs keyed by checkpoint path. predict_quote is called
# once per row, and previously reloaded the weights from disk on every call.
_QUOTE_MODEL_CACHE = {}


def _get_quote_model(model_checkpoint_path):
    """Return the cached (model, tokenizer) for a checkpoint, loading it on first use."""
    cached = _QUOTE_MODEL_CACHE.get(model_checkpoint_path)
    if cached is None:
        model = DistilBertForSequenceClassification.from_pretrained(model_checkpoint_path).to(device)
        model.eval()  # inference only: make sure dropout is disabled
        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        cached = (model, tokenizer)
        _QUOTE_MODEL_CACHE[model_checkpoint_path] = cached
    return cached


def predict_quote(context, text, model_checkpoint_path="./quotation_identifer_model/checkpoint-1000"):
    """Classify whether ``text`` is a quotation, given its surrounding ``context``.

    Args:
        context: Text surrounding the candidate sentence.
        text: The candidate sentence.
        model_checkpoint_path: Directory of the fine-tuned DistilBERT
            sequence-classification checkpoint. It is loaded once per path and
            then reused.

    Returns:
        True if the model's highest-scoring label is 1 ("Quote"), False for
        label 0 ("Not a Quote").
    """
    model, tokenizer = _get_quote_model(model_checkpoint_path)

    formatted_input = f"{context} : Is Sentence Quote : {text}"

    # NOTE(review): truncation trims the end of the input, so with a long
    # context the candidate sentence (placed last) can be cut off. Confirm this
    # matches how the checkpoint was trained.
    tokenized_input = tokenizer(formatted_input, padding="max_length", truncation=True, max_length=512, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**tokenized_input)

    predicted_label = torch.argmax(outputs.logits).item()

    label_encoder = {0: "Not a Quote", 1: "Quote"}

    return label_encoder[predicted_label] == "Quote"
|
|
|
def fill_is_quote_column(df, model_checkpoint_path="./quotation_identifer_model/checkpoint-1000"):
    """Run ``predict_quote`` on every row and store the result in ``df['Is Quote']``.

    Args:
        df: DataFrame with ``Context`` and ``Text`` columns. It is modified in
            place; an ``Is Quote`` column is created if missing.
        model_checkpoint_path: Classifier checkpoint passed to ``predict_quote``.

    Returns:
        The same DataFrame, with ``Is Quote`` holding one boolean per row.
    """
    if 'Is Quote' not in df.columns:
        df['Is Quote'] = None

    # Iterate only the two needed columns (with their index labels) instead of
    # materialising a full row Series per row via iterrows().
    rows = zip(df.index, df['Context'], df['Text'])
    for index, context, text in tqdm(rows, total=len(df), desc="Processing rows", unit="row"):
        df.at[index, 'Is Quote'] = predict_quote(context, text, model_checkpoint_path)

    return df
|
|
|
def transfer_quotes(complete_df, incomplete_df):
    """Copy every non-missing ``'Is Quote'`` value from *complete_df* into *incomplete_df*.

    Rows are matched by index label. *incomplete_df* is modified in place and
    also returned. Missing values (None/NaN) in *complete_df* leave the target
    cell untouched.
    """
    for label, record in complete_df.iterrows():
        value = record['Is Quote']
        if pd.isna(value):
            continue
        incomplete_df.at[label, 'Is Quote'] = value

    return incomplete_df
|
|
|
def visualize_quotes(df, is_dark_mode=False):
    """Open a window listing each row's ``Text`` and highlighting rows flagged as quotes.

    Args:
        df: DataFrame with at least ``Text`` and ``Is Quote`` columns. Rows
            whose ``Is Quote`` is truthy are drawn with a blue background.
        is_dark_mode: Use the dark colour scheme instead of the light one.

    Returns once the window has been closed. The ``TFrame`` and ``TLabel`` ttk
    styles configured here are named styles, so they also apply to other ttk
    widgets using those styles in the same Tk interpreter.
    """
    window = tk.Toplevel()
    window.title("Text Visualization")
    window.geometry("800x600")

    style = ttk.Style(window)
    style.theme_use('clam')

    main_frame = ttk.Frame(window, padding="20")
    main_frame.pack(fill=tk.BOTH, expand=True)

    title_font = Font(family="Helvetica", size=24, weight="bold")
    title_label = ttk.Label(main_frame, text="Quote Visualization (Identified quotes are highlighted in blue)", font=title_font)
    title_label.pack(pady=(0, 20))

    text_box = scrolledtext.ScrolledText(main_frame, width=80, height=30, wrap=tk.WORD, font=("Helvetica", 12))
    text_box.pack(fill=tk.BOTH, expand=True)

    def set_color_scheme(is_dark):
        """Apply the dark or light palette to the styles, text box and window."""
        if is_dark:
            style.configure("TFrame", background="#2c2c2c")
            style.configure("TLabel", background="#2c2c2c", foreground="white")
            text_box.config(bg="#2c2c2c", fg="white", insertbackground="white")
            text_box.tag_configure('quote', background='#4a86e8', foreground='white')
            window.configure(bg="#2c2c2c")
        else:
            style.configure("TFrame", background="#f0f0f0")
            style.configure("TLabel", background="#f0f0f0", foreground="black")
            text_box.config(bg="white", fg="black", insertbackground="black")
            text_box.tag_configure('quote', background='#4a86e8', foreground='black')
            window.configure(bg="#f0f0f0")

    def highlight_text():
        """Write every row's text on its own line, tagging quotes for highlighting."""
        text_box.delete('1.0', tk.END)
        for _, row in df.iterrows():
            text = row['Text']
            is_quote = row['Is Quote']
            if is_quote:
                text_box.insert(tk.END, text + "\n", 'quote')
            else:
                text_box.insert(tk.END, text + "\n")

    set_color_scheme(is_dark_mode)
    highlight_text()

    # Wait for this window only, instead of starting a second mainloop(). When
    # called from the app's after() callback, a nested mainloop() does not
    # return until the whole application quits, so each run nested one level
    # deeper. wait_window() keeps processing events but returns as soon as this
    # Toplevel is destroyed.
    window.wait_window()
|
|
|
class QuoteIdentifierApp:
    """Main Quote Identifier window.

    Lets the user pick a ``.txt`` file, then runs sentence splitting and quote
    classification on a background thread while updating a progress bar.
    Finally it opens ``visualize_quotes`` with the result.
    """

    def __init__(self, master):
        """Configure *master* (the root window) and build the UI in light mode."""
        self.master = master
        self.master.title("Quote Identifier")
        self.master.geometry("600x450")
        self.master.resizable(False, False)

        self.style = ttk.Style()
        self.style.theme_use('clam')

        self.is_dark_mode = False
        self.create_widgets()
        self.set_light_mode()

    def create_widgets(self):
        """Create the title, action buttons, dark-mode toggle, status label and progress bar."""
        self.main_frame = ttk.Frame(self.master, padding="20")
        self.main_frame.pack(fill=tk.BOTH, expand=True)

        title_font = Font(family="Helvetica", size=24, weight="bold")
        title_label = ttk.Label(self.main_frame, text="Quote Identifier", font=title_font)
        title_label.pack(pady=(0, 20))

        btn_frame = ttk.Frame(self.main_frame)
        btn_frame.pack(fill=tk.X, pady=10)

        self.open_file_btn = ttk.Button(btn_frame, text="Open Text File", command=self.open_file, style="AccentButton.TButton")
        self.open_file_btn.pack(side=tk.LEFT, padx=(0, 10))

        self.identify_quotes_btn = ttk.Button(btn_frame, text="Run Identify Quotes", command=self.identify_quotes, style="AccentButton.TButton")
        self.identify_quotes_btn.pack(side=tk.LEFT)

        self.dark_mode_btn = ttk.Button(self.main_frame, text="Toggle Dark Mode", command=self.toggle_dark_mode, style="TButton")
        self.dark_mode_btn.pack(pady=10)

        self.status_label = ttk.Label(self.main_frame, text="Ready", font=("Helvetica", 12))
        self.status_label.pack(pady=10)

        self.progress_bar = ttk.Progressbar(self.main_frame, orient=tk.HORIZONTAL, length=300, mode='determinate')
        self.progress_bar.pack(pady=10)

    def _apply_palette(self, background, button_background, foreground):
        """Apply one colour scheme to the ttk styles and the root window (shared by both modes)."""
        self.style.configure("TFrame", background=background)
        self.style.configure("TButton", background=button_background, foreground=foreground)
        self.style.configure("AccentButton.TButton", background="#4a86e8", foreground="white")
        self.style.configure("TLabel", background=background, foreground=foreground)
        self.master.configure(bg=background)

    def set_light_mode(self):
        """Switch to the light colour scheme."""
        self._apply_palette("#f0f0f0", "#e0e0e0", "black")
        self.is_dark_mode = False

    def set_dark_mode(self):
        """Switch to the dark colour scheme."""
        self._apply_palette("#2c2c2c", "#3c3c3c", "white")
        self.is_dark_mode = True

    def toggle_dark_mode(self):
        """Flip between the light and dark colour schemes."""
        if self.is_dark_mode:
            self.set_light_mode()
        else:
            self.set_dark_mode()

    def open_file(self):
        """Ask for a .txt file and remember its path in ``self.filepath``."""
        filepath = filedialog.askopenfilename(filetypes=[("Text files", "*.txt")])
        if filepath:
            self.status_label.config(text=f"File selected: {filepath}")
            self.filepath = filepath
        else:
            self.status_label.config(text="No file selected")

    def identify_quotes(self):
        """Start quote identification for the selected file on a daemon worker thread."""
        if not hasattr(self, 'filepath'):
            messagebox.showwarning("No File Selected", "Please select a text file first.")
            return

        self.status_label.config(text="Processing... Please wait.")
        self.progress_bar['value'] = 0
        # Prevent a second click from launching a concurrent run on the same UI.
        self.identify_quotes_btn.state(['disabled'])
        self.master.update()

        # Capture the path now so picking another file mid-run cannot affect this run.
        filepath = self.filepath

        def process_quotes():
            try:
                df = Process_txt_into_BERT_quotes_input_dataframe(filepath)
                df = self.fill_is_quote_column_with_progress(df)
            except Exception as exc:
                # Thread boundary: without this, a failure kills the worker
                # silently and leaves the UI stuck on "Processing...".
                # `err=exc` binds the value now; `exc` is cleared after the except block.
                self.master.after(0, lambda err=exc: self._report_failure(err))
            else:
                self.master.after(0, lambda: self.finish_processing(df))

        threading.Thread(target=process_quotes, daemon=True).start()

    def _report_failure(self, error):
        """Reset the UI after a failed run and show the error (scheduled via ``after``)."""
        self.identify_quotes_btn.state(['!disabled'])
        self.progress_bar['value'] = 0
        self.status_label.config(text="Quote identification failed.")
        messagebox.showerror("Quote Identification Failed", str(error))

    def fill_is_quote_column_with_progress(self, df):
        """Fill ``df['Is Quote']`` row by row with ``predict_quote``, reporting progress.

        Runs on the worker thread. Progress-bar updates are scheduled onto the
        Tk loop with ``after``. Returns the same (mutated) DataFrame.
        """
        if 'Is Quote' not in df.columns:
            df['Is Quote'] = None

        total_rows = len(df)
        rows = zip(df.index, df['Context'], df['Text'])
        for position, (label, context, text) in enumerate(tqdm(rows, total=total_rows, desc="Processing rows", unit="row")):
            # Write by index label, not loop position, so non-RangeIndex frames work.
            df.at[label, 'Is Quote'] = predict_quote(context, text)

            progress = (position + 1) / total_rows * 100
            self.master.after(0, lambda p=progress: self.update_progress(p))

        return df

    def update_progress(self, value):
        """Set the progress bar to *value* percent and redraw."""
        self.progress_bar['value'] = value
        self.master.update_idletasks()

    def finish_processing(self, df):
        """Mark the run complete, re-enable the Identify button, and show the results."""
        self.progress_bar['value'] = 100
        self.status_label.config(text="Quote identification complete!")
        self.identify_quotes_btn.state(['!disabled'])
        visualize_quotes(df, self.is_dark_mode)
|
|
def create_gui():
    """Create the Tk root window, mount the Quote Identifier UI on it, and run the event loop."""
    window = tk.Tk()
    # Held in a local for as long as mainloop() runs.
    _app = QuoteIdentifierApp(window)
    window.mainloop()


if __name__ == '__main__':
    # Launch the GUI only when this file is executed directly, not on import.
    create_gui()