Mishmosh committed on
Commit
d81e698
·
verified ·
1 Parent(s): 1e80c8e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +272 -0
app.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://www.kaggle.com/datasets/wcukierski/enron-email-dataset
2
+ from google.colab import drive
3
+ drive.mount('/content/drive')
4
+ # libraries
5
+ #!pip install transformers --upgrade
6
+ #!pip install gradio
7
+ #!pip install datasets
8
+ #!pip install huggingface-hub
9
+ #!pip install chromadb
10
+ #!pip install accelerate==0.21.0
11
+ #!pip install transformers[torch]
12
+ #!pip install git+https://github.com/huggingface/accelerate.git
13
+ import pandas as pd
14
+ import numpy as np
15
+ from transformers import AutoModel
16
+ from sklearn.model_selection import train_test_split
17
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
18
+ import gradio as gr
19
+ import chromadb
20
+ from datasets import Dataset
21
+ from transformers import Trainer, TrainingArguments
22
+ from transformers import AutoModelForMaskedLM, DataCollatorForLanguageModeling
23
+ from transformers import TextDataset, DataCollatorForLanguageModeling
24
+ #from transformers import TrainingArguments, Trainer
25
+ #from transformers import pipeline
26
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
27
+
28
+ file_path = '/content/drive/MyDrive/emails.csv'
29
+ df = pd.read_csv(file_path)
30
+ df_columns = df.columns
31
+ print(df.head(10))
32
+
33
+ messages_df = df['message'] #extract message column
34
+ print(messages_df.head())
35
+ print(type(messages_df))
36
+
37
+ # Hold out a tiny fraction of the ~500,000 emails as the working sample (test_size=0.000008, i.e. only a handful of emails, not 1% / 5,000). The test size was reduced repeatedly to stop Colab crashing.
38
+ emails_train, emails_test = train_test_split(messages_df, test_size=0.000008, random_state=42)
39
+ print(emails_test)
40
+ print(type(emails_test))
41
+
42
+ pd.set_option('display.max_colwidth', None) #check content
43
+ print(emails_test.head()) #first 5 rows
44
+ print(type(emails_test))
45
+
46
+ # Embeddings
47
+ import os
48
+ # Define maximum sequence length
49
+ max_seq_length = 512
50
+ # Truncate or pad sequences to the maximum length
51
+ truncated_emails_test = [email[:max_seq_length] for email in emails_test]
52
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
53
+ model = AutoModel.from_pretrained("bert-base-uncased")
54
+ embeddings_pipeline = pipeline('feature-extraction', model=model, tokenizer=tokenizer)
55
+ embeddings = embeddings_pipeline(truncated_emails_test)
56
+ print(type(embeddings))
57
+ #print(embeddings[:5]) #cannot see embeddings like this
58
+
59
+ # to see the embeddings
60
+ # Save each embedding to a separate file
61
+ for i, emb in enumerate(embeddings):
62
+ np.save(f"embedding_{i}.npy", emb)
63
+ # Load each embedding from its corresponding file
64
+ loaded_embeddings = []
65
+ for i in range(len(embeddings)):
66
+ emb = np.load(f"embedding_{i}.npy")
67
+ loaded_embeddings.append(emb)
68
+ for i, emb in enumerate(loaded_embeddings):
69
+ print(f"Embedding {i}:")
70
+ print(emb)
71
+
72
+
73
+
74
+ import chromadb
75
+ chroma_client = chromadb.Client()
76
+ collection = chroma_client.create_collection(name="michelletest")
77
+
78
+
79
+ # Extract the embeddings from the nested list
80
# The pipeline output is indexed [0][0] per email: first batch entry, first
# token's vector (presumably the [CLS] position for BERT - TODO confirm).
extracted_embeddings = [embedding[0][0] for embedding in embeddings]

# Add at most the first 5 emails to the ChromaDB collection. The count is
# derived from the data so that embeddings, documents, metadatas and ids
# always have the same length, even when the sample holds fewer than 5 emails.
num_to_add = min(5, len(extracted_embeddings))
collection.add(
    embeddings=extracted_embeddings[:num_to_add],
    documents=emails_test.tolist()[:num_to_add],
    metadatas=[{"source": "emails_test"} for _ in range(num_to_add)],
    ids=[f"id{i}" for i in range(num_to_add)],
)
89
+
90
+
91
+ collection.count() #check how many in the database
92
+
93
+ # Retrieve all entries from the ChromaDB collection to check that it worked properly
94
+ collection.get()
95
+
96
+ # Convert the Series to a DataFrame
97
+ emails_test_df = emails_test.to_frame()
98
+ # Print the column names of the DataFrame
99
+ print(emails_test_df.columns)
100
+
101
+ print(emails_test_df['message']) #checking content of messsages for fine tuning the model
102
+
103
+ print(emails_test_df['message'].head())
104
+
105
+ # Print the column names of the DataFrame
106
+ print(emails_test_df.columns)
107
+
108
+ num_entries = emails_test_df.shape[0]
109
+ print("Number of entries in emails_test_df:", num_entries)
110
+
111
+
112
+ # Extract 1% of the content as test set so that instead of 500,000 emails 5,000 are being used as a sample; 60 used in the end
113
+ emails_train, emails_test2 = train_test_split(messages_df, test_size=0.00001, random_state=42)
114
+ print(emails_test2)
115
+ print(type(emails_test2))
116
+ num_entries2=emails_test2.shape[0]
117
+ print("number of",num_entries2)
118
+
119
+ # Convert pandas Series to a list of strings
120
+ text_list = emails_test_df['message'].tolist()
121
+
122
+ # Verify the type and content
123
+ print(type(text_list))
124
+ print(text_list[:5]) # Print the first 5 entries as an example
125
+
126
+
127
+ print(text_list[:5])
128
+
129
+ print(text_list)
130
+
131
+
132
+ print(text_list[2]) #to see the content of an average mail to know what to clean up
133
+
134
def remove_sections(email):  # clean email of content that is not useful
    """Drop header and reply-marker lines from an email.

    Only the lines that contain a marker are removed; text that follows an
    "Original Message" marker is kept.

    Args:
        email: The email as a list of lines (the text split on newlines).

    Returns:
        A new list with every line that contains none of the markers below.
        Matching is a case-sensitive substring test anywhere in the line, so
        "From:" also removes "X-From:" lines and "To:" removes "X-To:" lines.
    """
    sections_to_remove = (
        "----- Original Message -----",
        "From:",
        "Sent:",
        "To:",
        "CC:",
        # "Cc:" and "Bcc:" are spelled this way in the email headers and are
        # not matched by the case-sensitive "CC:" / "X-cc:" / "X-bcc:" entries.
        "Cc:",
        "Bcc:",
        "Subject:",
        "Message-ID:",
        "Date:",
        "Mime-Version:",
        "Content-Type:",
        "Content-Transfer-Encoding:",
        "X-cc:",
        "X-bcc:",
        "X-Folder:",
        "X-Origin:",
        "X-FileName:",
        "-----Original Message-----",
    )

    # Single pass over the lines instead of re-filtering once per marker.
    return [
        line
        for line in email
        if not any(section in line for section in sections_to_remove)
    ]
160
+ # Remove sections from each email in the list
161
+ cleaned_text_list = [remove_sections(email.split("\n")) for email in text_list]
162
+
163
+ # Print out the cleaned emails to see if content looks ok
164
+ for cleaned_email in cleaned_text_list:
165
+ print("\n".join(cleaned_email))
166
+ print("=" * 50) # Separate each cleaned email for better readability
167
+
168
+
169
+ #fine tune language model
170
+
171
+ # Define the pre-trained model name (bart-base)
172
+ model_name = "facebook/bart-base"
173
+
174
+ # Load the tokenizer for bart-base
175
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
176
+
177
+
178
+
179
+ # Function to preprocess text_list for training
180
def prepare_data(text_list):
    """Preprocesses text data for training the BART model.

    Args:
        text_list: A list of strings containing the text data.

    Returns:
        A Dataset object with ``input_ids``, ``attention_mask`` and
        ``labels`` columns. The labels are the input IDs with every padding
        position replaced by -100.
    """
    # Tokenize with padding to the maximum length and truncation.
    inputs = tokenizer(text_list, padding="max_length", truncation=True)

    # Labels are the inputs themselves (the model learns to reproduce the
    # text). Padding positions are set to -100, the index the loss ignores,
    # so the model is not trained to predict pad tokens.
    pad_id = tokenizer.pad_token_id
    labels = [
        [-100 if token_id == pad_id else token_id for token_id in ids]
        for ids in inputs["input_ids"]
    ]

    # Keep the attention mask so padded positions are not attended to.
    return Dataset.from_dict(
        {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": labels,
        }
    )
197
+
198
+
199
+ # Prepare your training data from the text list
200
# Train on the cleaned emails rather than the raw text, so the header lines
# stripped by remove_sections do not end up in the training data. Each cleaned
# email is a list of lines, so it is re-joined into one string first.
train_data = prepare_data(["\n".join(lines) for lines in cleaned_text_list])
201
+
202
+ # Define the fine-tuning model (BART for sequence-to-sequence tasks)
203
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
204
+
205
+ # Training hyperparameters (adjust as needed)
206
+ batch_size = 8
207
+ learning_rate = 2e-5
208
+ num_epochs = 3
209
+
210
+
211
+ from transformers import Trainer
212
+
213
+ # Define the Trainer object for training management
214
+ trainer = Trainer(
215
+ model=model,
216
+ args=TrainingArguments(
217
+ output_dir="./results", # Output directory for checkpoints etc.
218
+ overwrite_output_dir=True,
219
+ per_device_train_batch_size=batch_size,
220
+ learning_rate=learning_rate,
221
+ num_train_epochs=num_epochs,
222
+ ),
223
+ train_dataset=train_data,
224
+ )
225
+
226
+
227
+
228
+ # Start the fine-tuning process
229
+ trainer.train()
230
+
231
+ # Save the fine-tuned model and tokenizer
232
+ model.save_pretrained("./fine-tuned_bart")
233
+ tokenizer.save_pretrained("./fine-tuned_bart")
234
+
235
+ print("Fine-tuning completed! Model saved in ./fine-tuned_bart")
236
+
237
+
238
+ # Fine-tuning completed! Model saved in ./fine-tuned_bart
239
+ # i used a very small amount of input so that colab stopped crashing
240
+
241
+
242
+
243
+ import gradio as gr
244
+ from transformers import BartForQuestionAnswering, BartTokenizer
245
+
246
+ # Load the fine-tuned BART model
247
+ model = BartForQuestionAnswering.from_pretrained("./fine-tuned_bart")
248
+ tokenizer = BartTokenizer.from_pretrained("./fine-tuned_bart")
249
+
250
+ # Function to answer questions
251
def answer_question(question):
    """Return the text span the question-answering model selects.

    The question is the only input, so the answer span is cut from the
    question's own tokens.

    NOTE(review): the checkpoint in ./fine-tuned_bart is saved from an
    AutoModelForSeq2SeqLM but loaded as BartForQuestionAnswering; the QA head
    is presumably untrained - confirm before relying on the answers.

    Args:
        question: The question text entered by the user.

    Returns:
        The decoded tokens between the highest-scoring start position and the
        highest-scoring end position (inclusive). Empty if the predicted end
        lies before the predicted start.
    """
    # Imported here because no `import torch` is visible at the top of this file.
    import torch

    inputs = tokenizer.encode_plus(question, return_tensors="pt", max_length=512, truncation=True)
    input_ids = inputs["input_ids"].tolist()[0]

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Read the logits by name; unpacking the model output as a 2-tuple is not
    # reliable because the output object can carry additional fields.
    answer_start = int(torch.argmax(outputs.start_logits))
    answer_end = int(torch.argmax(outputs.end_logits)) + 1

    return tokenizer.decode(input_ids[answer_start:answer_end])
261
+
262
+ # Create Gradio interface
263
+ iface = gr.Interface(
264
+ fn=answer_question,
265
+ inputs="text",
266
+ outputs="text",
267
+ title="Question Answering Model",
268
+ description="Enter a question to get the answer."
269
+ )
270
+
271
+ # Launch the interface
272
+ iface.launch()