import glob
import math
import os
import re
import string

import torch
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2', bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>')

# ----------------------------- Cleaning process 1/2 -----------------------------

# Month names are removed as whole words only, so e.g. "Mayor" is left intact.
_MONTHS_RE = re.compile(
    r'\b(?:January|February|March|April|May|June|July|August|September'
    r'|October|November|December)\b'
)
# Typographic quote characters to delete: the opening/closing single and double
# quotes, including the closing double quote U+201D and low double quote U+201E.
_QUOTES_TABLE = str.maketrans('', '', "‟“’❝❞‚‘‛❛❜❟’" + '\u201d\u201e')
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)


def sanitize(line):
    """Reduce a raw text line to plain lowercase-ish words for length checks.

    Removes, in order: bracketed spans ``[...]``, month names, every
    all-uppercase word, digits, typographic quotes and ASCII punctuation, then
    collapses runs of whitespace to single spaces.

    Args:
        line: one raw line of text (may include its trailing newline).

    Returns:
        The cleaned, stripped string (possibly empty).
    """
    # Lazy match so "a [1] b [2] c" keeps " b "; a greedy match would drop
    # everything between the first "[" and the last "]" on the line.
    cleaned = re.sub(r'\[.*?\]', '', line)
    cleaned = _MONTHS_RE.sub('', cleaned)
    # Drop every all-uppercase word, including one-letter ones such as "I" and "A".
    cleaned = re.sub(r'\b[A-Z]+\b', '', cleaned.strip())
    cleaned = re.sub(r'\d', '', cleaned)
    cleaned = cleaned.translate(_QUOTES_TABLE)
    cleaned = cleaned.translate(_PUNCT_TABLE)
    return re.sub(r'\s+', ' ', cleaned).strip()

def remove_footnotes_and_clean(sents):
    """Join a paragraph's raw lines into one whitespace-normalised string.

    Despite the name, this function does no footnote detection itself; it only
    cleans and joins the lines it is given.

    From each line it removes ASCII apostrophes, ``*``, the sequence ``’®``
    and right single quotes (``’``), then strips it.
    - Blank lines are skipped.
    - A line ending in ``-`` is joined to the next with no space (the hyphen is
      kept, so ``"foo-" + "bar"`` gives ``"foo-bar"``).
    - Every other line is followed by a single space.
    - Whitespace runs are finally collapsed to one space.
    - The result keeps the trailing space after the last line.

    Args:
        sents: raw text lines (newlines allowed).

    Returns:
        The joined string, or ``""`` if there are no non-blank lines.
    """
    pieces = []
    for raw in sents:
        line = (
            raw.replace("'", '')
            .replace('*', '')
            .replace('’®', '')
            .replace('’', '')
            .strip()
        )
        if not line:
            # Blank line: nothing to join. Skip it instead of failing on line[-1].
            continue
        pieces.append(line if line.endswith('-') else line + ' ')
    return re.sub(r'\s+', ' ', ''.join(pieces))

path = 'text_files/'
ml = sorted(glob.glob(path+'*.txt'))
show = False  # debug: print every line's verdict and pause after each one

# Compiled once; applied to every line of every file below.
# A stripped line of the form "<non-word char> <word...>" (e.g. "* Letter from ...")
# marks the start of a footnote block.
FOOTNOTE_START_RE = re.compile(r'^\W\s\w')
# Running page header/footer lines; seeing one means a new page, which ends
# any footnote block.
PAGE_HEADER_RE = re.compile(r'^VOL.*\d\d\d\d.*\d$')
PAGE_TITLE = 'THE COLLECTED WORKS OF MAHATMA GANDHI'


def is_page_break(stripped):
    """Return True if the stripped line is a running page header/footer."""
    return PAGE_HEADER_RE.search(stripped) is not None or PAGE_TITLE in stripped


def is_substantive(cleaned):
    """Return True if a sanitize()d line has enough text to count as body text."""
    return len(cleaned) > 5 and len(cleaned.split()) > 4


def stripped_ws(line):
    """Return the count of leading plus trailing whitespace chars (incl. the newline).

    Used as the line's indentation measure.
    """
    return len(line) - len(line.strip())


path = 'clean_text_files/'
for m in tqdm(ml, ncols=100):
    with open(m, 'r') as file:
        content = file.readlines()
    out_name = os.path.join(path, os.path.basename(m))

    if show:
        print(m)

    # Pass A: body-text indentation is the floored mean of stripped_ws()
    # over substantive lines outside footnote blocks.
    body_ws = []
    footnote_found = False
    for line in content:
        stripped = line.strip()
        if FOOTNOTE_START_RE.search(stripped):
            footnote_found = True
        if is_page_break(stripped):
            footnote_found = False  # new page
        if not footnote_found and is_substantive(sanitize(line)):
            li_spaces = stripped_ws(line)
            if show:
                print(line.rstrip(), end='')
                print(li_spaces)
            body_ws.append(li_spaces)

    if not body_ws:
        # Nothing looked like body text. Avoid dividing by zero, but still write
        # an (empty) output file for this input.
        tqdm.write('no body text found, writing empty output: ' + m)
        with open(out_name, 'w') as file:
            file.write('')
        continue
    mean_spaces = math.floor(sum(body_ws) / len(body_ws))
    if show:
        print('ms', mean_spaces)
        print(' '*mean_spaces+'^')

    # Pass B: walk the lines again, grouping body lines into paragraphs.
    paras = []
    sents = []  # raw lines of the paragraph currently being accumulated
    footnote_found = False
    # Indentation of the last substantive body line kept; -1 after a rejected line.
    last_spaces = -1
    # Stops before the final line: a line indented deeper than the body is judged
    # together with the line after it, so the last line is never examined.
    for i in range(len(content) - 1):
        line = content[i]
        stripped = line.strip()
        li_spaces = stripped_ws(line)
        if FOOTNOTE_START_RE.search(stripped):
            footnote_found = True
        if is_page_break(stripped):
            footnote_found = False  # new page; header line itself is skipped
            continue
        if footnote_found:
            last_spaces = -1
            if show:
                print('--', line.rstrip())
        elif li_spaces <= mean_spaces:
            # At or shallower than body indentation: keep substantive lines, and
            # short lines at the same indentation as the last kept line.
            if is_substantive(sanitize(line)):
                if show:
                    print('++', line.rstrip())
                sents.append(line)
                last_spaces = li_spaces
            elif last_spaces == li_spaces:
                if show:
                    print('++', line.rstrip())
                sents.append(line)
            else:
                last_spaces = -1
                if show:
                    print('--', line.rstrip())
        else:
            # Deeper than body indentation: treated as a paragraph opening when
            # it and the next line are substantive and the next line is at or
            # shallower than body indentation.
            next_line = content[i + 1]
            if (stripped_ws(next_line) <= mean_spaces
                    and is_substantive(sanitize(line))
                    and is_substantive(sanitize(next_line))):
                paras.append(remove_footnotes_and_clean(sents))
                if show:
                    print('++', line.rstrip(), '<------NEW PARA')
                sents = [line]
            else:
                last_spaces = -1
                if show:
                    print('--', line.rstrip())
        if show:
            input('wait')

    # Paragraphs are otherwise only flushed when the next one opens, so the
    # final paragraph must be flushed explicitly.
    if sents:
        paras.append(remove_footnotes_and_clean(sents))

    # paras[0] is whatever was collected before the first paragraph opening
    # (presumably title/heading lines - TODO confirm) and is dropped.
    with open(out_name, 'w') as file:
        file.write('\n'.join(paras[1:]))

	

# ----------------------------- Cleaning process 2/2 -----------------------------
path = 'clean_text_files/'
ml = sorted(glob.glob(path+'*.txt'))  # NOTE(review): unused below; files are opened by number

# Concatenate the cleaned files, one paragraph per line, in numeric order.
# The range 1..98 is hard-coded; presumably one file per volume — confirm it
# matches the files pass 1 produced.
text = []
for m in tqdm(range(1, 99)):
    with open(path + str(m) + '.txt', 'r') as file:
        for line in file:
            # Pass 1 writes its files without a trailing newline. Without this,
            # the last paragraph of one file and the first of the next would
            # share a line in all_paras.txt.
            text.append(line if line.endswith('\n') else line + '\n')

with open('all_paras.txt', 'w') as file:
    file.write(''.join(text))

# Split every paragraph into sentences (all_sents.txt) and into chunks that fit
# a token budget (all_tc_sents_<MAX_TOKENS>.txt).
MAX_TOKENS = 200  # per-chunk budget, counted including the start/end markers


def token_count(chunk):
    """Return the number of tokenizer ids for *chunk* wrapped in start/end markers."""
    return len(tokenizer('<|startoftext|>' + chunk + '<|endoftext|>')['input_ids'])


sents = []
tcsents = []  # transformer compatible sents
para_stack = []
for para in tqdm(text):
    para = para.strip()
    if not para:
        # Blank line: no sentences, and it must not become an empty chunk.
        continue
    sents += sent_tokenize(para)
    # Halve over-long chunks at a sentence boundary until each piece fits.
    # The end of para_stack is the top; pushing the second half before the
    # first keeps the emitted chunks in document order.
    para_stack = [para]
    while para_stack:
        top_para = para_stack.pop()
        if token_count(top_para) <= MAX_TOKENS:
            tcsents.append(top_para)
            continue
        ts = sent_tokenize(top_para)
        if len(ts) > 1:
            half = len(ts) // 2
            para_stack.append(' '.join(ts[half:]))  # second half
            para_stack.append(' '.join(ts[:half]))  # first half, popped next
        else:
            # A single sentence that is still too long is kept whole (over budget).
            tcsents.append(top_para)

with open('all_sents.txt', 'w') as file:
    file.write('\n'.join(sents))

with open(f'all_tc_sents_{MAX_TOKENS}.txt', 'w') as file:
    file.write('\n'.join(tcsents))