Antoniskaraolis committed on
Commit
ac888bf
·
verified ·
1 Parent(s): 0c78c53

Upload assessment3_antonis_karaolis.py

Files changed (1)
  1. assessment3_antonis_karaolis.py +98 -0
assessment3_antonis_karaolis.py ADDED
@@ -0,0 +1,98 @@
+ # -*- coding: utf-8 -*-
+ """Assessment3_Antonis_Karaolis.ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+ https://colab.research.google.com/drive/1Qd3aOoBB6q1uy2pHPeLudMlsYd9J30-C
+ """
+
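+ # Install the required libraries (sentence embeddings, transformers fine-tuning stack, vector store, dataset utilities, web UI)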
+ !pip install -U sentence-transformers
+ !pip install transformers
+ !pip install gradio
+ !pip install chromadb
+ !pip install datasets
+ !pip install accelerate -U
+ !pip install transformers[torch]
+
+ import pandas as pd
+ from sentence_transformers import SentenceTransformer
+ import chromadb
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
+ import gradio as gr
+ import torch
+ from accelerate import Accelerator
+ from transformers import Trainer, TrainingArguments
+ from datasets import Dataset
+ from torch.cuda.amp import autocast
+
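+ # Load the first 500 rows of the email CSV and strip surrounding whitespace from each message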
+ emails_df = pd.read_csv('/content/emails.csv', nrows=500, on_bad_lines='skip')
+ emails_df['message'] = emails_df['message'].apply(lambda x: x.strip() if type(x) == str else '')
+
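+ # Embed every message with the all-MiniLM-L6-v2 sentence-transformer model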
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+
+ emails_embeddings = model.encode(emails_df['message'].tolist(), show_progress_bar=True)
+
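+ # Store the embeddings, message texts, metadata and ids in an in-memory ChromaDB collection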
+ chroma_client = chromadb.Client()
+ collection = chroma_client.create_collection(name="enron_emails_subset")
+
+ collection.add(
+     embeddings=emails_embeddings.tolist(),
+     documents=emails_df['message'].tolist(),
+     metadatas=[{"email_id": idx} for idx in emails_df.index],
+     ids=[str(idx) for idx in emails_df.index]
+ )
+
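+ # distilgpt2 has no padding token, so reuse the end-of-sequence token for padding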
+ tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
+ tokenizer.pad_token = tokenizer.eos_token
+
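+ # Tokenize each message to a fixed length of 128 tokens and copy the input ids into labels for causal language-model training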
+ def tokenize_function(examples):
+     with autocast():
+         result = tokenizer(examples['message'], truncation=True, padding="max_length", max_length=128)
+         result["labels"] = result["input_ids"].copy()
+         return result
+
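+ # Re-read the emails, wrap the message column in a datasets.Dataset and tokenize it with 4 worker processes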
+ emails_df = pd.read_csv('/content/emails.csv', nrows=500, on_bad_lines='skip')
+ dataset = Dataset.from_pandas(emails_df[['message']])
+ dataset = dataset.map(tokenize_function, batched=True, num_proc=4)
+
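+ # Split off 10% of the examples and keep the remaining 90% for training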
+ train_dataset = dataset.train_test_split(test_size=0.1)['train']
+
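+ # Load distilgpt2 and resize its token embeddings to match the tokenizer vocabulary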
+ model = GPT2LMHeadModel.from_pretrained('distilgpt2')
+ model.resize_token_embeddings(len(tokenizer))
+
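+ # One training epoch with per-device batch size 8 and 2 gradient-accumulation steps (effective batch size 16); checkpoints every 250 steps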
+ training_args = TrainingArguments(
+     output_dir='/content/model_output',
+     num_train_epochs=1,
+     per_device_train_batch_size=8,
+     gradient_accumulation_steps=2,
+     save_steps=250,
+     logging_dir='/content/logs',
+     logging_strategy="steps",
+     logging_steps=50
+ )
+
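+ # The Trainer wraps the training loop, batching and checkpoint saving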
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+     tokenizer=tokenizer
+ )
+
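+ # Fine-tune the model on the email text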
+ trainer.train()
+
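+ # Save the fine-tuned model and tokenizer, then reload them for inference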
+ model.save_pretrained('/content/model_output')
+ tokenizer.save_pretrained('/content/model_output')
+
+ model = GPT2LMHeadModel.from_pretrained('/content/model_output')
+ tokenizer = GPT2Tokenizer.from_pretrained('/content/model_output')
+
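+ # Generate an answer by continuing the question with the fine-tuned model (up to 100 tokens in total, prompt included)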
+ def answer_question(question):
+     model.eval()
+     inputs = tokenizer.encode(question, return_tensors='pt')
+     outputs = model.generate(inputs, max_length=100, num_return_sequences=1)
+     return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ # Gradio interface
+ iface = gr.Interface(fn=answer_question, inputs="text", outputs="text")
+ iface.launch()