# -*- coding: utf-8 -*-
"""
Created on Thu May 19 13:22:32 2022

@author: UTKARSH
"""
import glob
import os

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
    AutoModelForSeq2SeqLM,
    AutoTokenizer
)

from src.clean import clean_license_text
from src.read_data import read_license_summary_data

MODEL_PATH = "models/"
MODEL_FILENAME = "t5-base.model"
MODEL_NAME = "t5-base"
TOKENIZER = None
TEXT_MAX_TOKEN_LEN = 512
SUMMARY_MAX_TOKEN_LEN = 128
N_EPOCHS = 1
BATCH_SIZE = 1

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class LicenseSummaryDataset(Dataset):
    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: T5Tokenizer,
        text_max_token_len: int = 512,
        summary_max_token_len: int = 128
    ):
        self.tokenizer = tokenizer
        self.data = data
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        data_row = self.data.iloc[index]
        text = data_row["text"]
        text_encoding = self.tokenizer(
            text,
            max_length=self.text_max_token_len,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        summary_encoding = self.tokenizer(
            data_row["summary"],
            max_length=self.summary_max_token_len,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        labels = summary_encoding["input_ids"]
        # T5's pad token id is 0; mask padded label positions with -100 so the
        # cross-entropy loss ignores them.
        labels[labels == 0] = -100
        return dict(
            text=text,
            summary=data_row["summary"],
            text_input_ids=text_encoding["input_ids"].flatten(),
            text_attention_mask=text_encoding["attention_mask"].flatten(),
            labels=labels.flatten(),
            labels_attention_mask=summary_encoding["attention_mask"].flatten()
        )
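

# Usage sketch (illustrative; "df" is a hypothetical DataFrame with "text"
# and "summary" columns, as produced by read_license_summary_data):
#
#     tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
#     dataset = LicenseSummaryDataset(df, tokenizer)
#     item = dataset[0]
#     item["text_input_ids"].shape   # torch.Size([512])
#     item["labels"].shape           # torch.Size([128])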


def prepare_dataloaders():
    """
    Helper method to load the data and create batched DataLoaders.

    Returns
    -------
    train_dataloader : DataLoader
        Train DataLoader.
    dev_dataloader : DataLoader
        Validation DataLoader.
    """
    global TOKENIZER
    license_summary_data = pd.DataFrame(read_license_summary_data())
    train_df, dev_df = train_test_split(license_summary_data, test_size=0.1)
    TOKENIZER = T5Tokenizer.from_pretrained(MODEL_NAME)
    train_dataset = LicenseSummaryDataset(
        train_df,
        TOKENIZER,
        TEXT_MAX_TOKEN_LEN,
        SUMMARY_MAX_TOKEN_LEN
    )
    dev_dataset = LicenseSummaryDataset(
        dev_df,
        TOKENIZER,
        TEXT_MAX_TOKEN_LEN,
        SUMMARY_MAX_TOKEN_LEN
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0
    )
    dev_dataloader = DataLoader(
        dev_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0
    )
    return train_dataloader, dev_dataloader
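

# Usage sketch (illustrative):
#
#     train_dataloader, dev_dataloader = prepare_dataloaders()
#     batch = next(iter(train_dataloader))
#     batch["text_input_ids"].shape   # torch.Size([BATCH_SIZE, 512])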


def train(epoch, model, dataloader, optimizer, batch_size):
    """
    Trains the given model on the given data for one epoch.

    Parameters
    ----------
    epoch : int
        The epoch number for which the model is being trained.
    model : T5ForConditionalGeneration
        The summarizer model to train.
    dataloader : torch.utils.data.DataLoader
        The dataloader on which the model is to be trained.
    optimizer : torch.optim.Optimizer
        The optimizer used to update the weights during training.
    batch_size : int
        The size of each batch as set in the dataloader (not used inside
        this function).
    """
    model.train()
    total_train_loss = 0
    for batch in tqdm(dataloader):
        model.zero_grad()
        input_ids = batch["text_input_ids"].to(device, dtype=torch.long)
        attention_mask = batch["text_attention_mask"].to(device, dtype=torch.long)
        labels = batch["labels"].to(device, dtype=torch.long)
        labels_attention_mask = batch["labels_attention_mask"].to(device, dtype=torch.long)
        model_output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels
        )
        loss = model_output.loss
        total_train_loss += loss.item()
        loss.backward()
        optimizer.step()
    avg_train_loss = total_train_loss / len(dataloader)
    print(f"Epoch {epoch}: Training loss: {avg_train_loss}")


def train_and_save_model(train_dataloader, PATH):
    """
    Trains a summarizer model on the given DataLoader and saves it at the
    given path.

    Parameters
    ----------
    train_dataloader : DataLoader
        Batched training DataLoader.
    PATH : str
        Path where the trained model is to be saved.

    Returns
    -------
    model : T5ForConditionalGeneration
        Trained model.
    """
    model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    for epoch in range(1, N_EPOCHS + 1):
        train(epoch, model, train_dataloader, optimizer, BATCH_SIZE)
    torch.save(model.state_dict(), PATH)
    print("Model Saved!")
    return model


def summarize_text_with_model(text, model, tokenizer):
    """
    Summarizes License text using the given trained T5 model.

    Parameters
    ----------
    text : str
        The License text to be summarized.
    model : T5ForConditionalGeneration / torch.nn.Module
        The trained model used to summarize the text.
    tokenizer : T5Tokenizer
        The tokenizer matching the given model.

    Returns
    -------
    summary : str
        Summary of the License text from the given model.
    definitions
        Definitions extracted from the text by clean_license_text.
    """
    text, definitions = clean_license_text(text)
    text_encoding = tokenizer(
        text,
        max_length=TEXT_MAX_TOKEN_LEN,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt"
    )
    generated_ids = model.generate(
        input_ids=text_encoding["input_ids"].to(device, dtype=torch.long),
        attention_mask=text_encoding["attention_mask"].to(device, dtype=torch.long),
        max_length=SUMMARY_MAX_TOKEN_LEN,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True
    )
    preds = [
        tokenizer.decode(
            gen_id,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        ) for gen_id in generated_ids
    ]
    return "".join(preds), definitions


def summarize(text, load_from_huggingface=True):
    """
    Summarizes the given License text.

    Parameters
    ----------
    text : str
        Preprocessed License text.
    load_from_huggingface : bool, optional
        If True (default), load the published model from HuggingFace;
        otherwise load the locally saved model, training one first if no
        saved model exists.

    Returns
    -------
    summary : str
        Summary of the License text.
    definitions
        Definitions extracted from the text by clean_license_text.
    """
    global TOKENIZER
    if load_from_huggingface:
        print("Loading Model from HuggingFace...")
        CUSTOM_MODEL_NAME = "utkarshsaboo45/ClearlyDefinedLicenseSummarizer"
        model = AutoModelForSeq2SeqLM.from_pretrained(CUSTOM_MODEL_NAME).to(device)
        tokenizer = AutoTokenizer.from_pretrained(CUSTOM_MODEL_NAME)
    else:
        if os.path.exists(MODEL_PATH + MODEL_FILENAME):
            print("Loading Model...")
            model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True).to(device)
            TOKENIZER = T5Tokenizer.from_pretrained(MODEL_NAME)
            model.load_state_dict(torch.load(MODEL_PATH + MODEL_FILENAME, map_location=device))
        else:
            print("Training model...")
            if not os.path.exists(MODEL_PATH):
                os.makedirs(MODEL_PATH)
            train_dataloader, _ = prepare_dataloaders()
            model = train_and_save_model(train_dataloader, MODEL_PATH + MODEL_FILENAME)
        model.eval()
        tokenizer = TOKENIZER
    return summarize_text_with_model(text, model, tokenizer)


def summarize_license_files(path):
    """
    Summarizes the License .txt files in the given directory and saves each
    summary next to its source file as "<name>__summary.txt".

    Parameters
    ----------
    path : str
        Directory containing the License files.
    """
    paths = glob.glob(os.path.join(path, "*.txt"))
    for license_path in paths:
        with open(license_path, "r", encoding="utf-8") as f:
            summary, _ = summarize(f.read())
        with open(license_path.replace(".txt", "__summary.txt"), "w", encoding="utf-8") as f:
            f.write(summary)
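

# Example entry point (illustrative sketch; the sample text and the
# "licenses/" directory below are hypothetical, not part of the original
# module):
if __name__ == "__main__":
    sample_license = (
        "Permission is hereby granted, free of charge, to any person "
        "obtaining a copy of this software and associated documentation "
        "files, to deal in the Software without restriction."
    )
    summary, definitions = summarize(sample_license)
    print(summary)

    # Batch mode: summarize every .txt file in a directory.
    # summarize_license_files("licenses/")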