g3casey committed on
Commit
36c9b26
·
1 Parent(s): 8497bb0

Switch to pasted-in text as the input, since the Wikipedia API doesn't work.

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/aws.xml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="accountSettings">
4
+ <option name="activeRegion" value="us-east-1" />
5
+ <option name="recentlyUsedRegions">
6
+ <list>
7
+ <option value="us-east-1" />
8
+ </list>
9
+ </option>
10
+ </component>
11
+ </project>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
5
+ <option name="ignoredErrors">
6
+ <list>
7
+ <option value="N806" />
8
+ <option value="N803" />
9
+ <option value="N802" />
10
+ </list>
11
+ </option>
12
+ </inspection_tool>
13
+ </profile>
14
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/summaraize.iml" filepath="$PROJECT_DIR$/.idea/summaraize.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/other.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="PySciProjectComponent">
4
+ <option name="PY_SCI_VIEW" value="true" />
5
+ <option name="PY_SCI_VIEW_SUGGESTED" value="true" />
6
+ </component>
7
+ </project>
.idea/summaraize.iml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="renderExternalDocumentation" value="true" />
10
+ </component>
11
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
+ </component>
6
+ </project>
app.py CHANGED
@@ -36,7 +36,7 @@ def get_wiki(search_term):
36
  orig_text_len = len(text)
37
  text = summarize(text)
38
  sum_length = len(text)
39
- return [text,orig_text_len,sum_length]
40
 
41
 
42
  # def inference(file):
@@ -48,10 +48,10 @@ out_orig_test_len = gr.Number(label='Original Text Length')
48
  out_sum_text_len = gr.Number(label='Summarized Text Length')
49
 
50
  iface = gr.Interface(fn=get_wiki,
51
- inputs=gr.Textbox(lines=2, placeholder="Wikipedia search term here...", label='Search Term'),
52
  outputs=[out_sum_text,out_orig_test_len,out_sum_text_len],
53
- title='Wikipedia Article Summary',
54
- description='Enter a search term to get a wikipedia article associated with it. Then we will summarize the article found. ',
55
  sample_inputs='guardians of the galaxy'
56
  )
57
  iface.launch() # To create a public link, set `share=True` in `launch()`.
 
36
  orig_text_len = len(text)
37
  text = summarize(text)
38
  sum_length = len(text)
39
+ return [text, orig_text_len, sum_length]
40
 
41
 
42
  # def inference(file):
 
48
  out_sum_text_len = gr.Number(label='Summarized Text Length')
49
 
50
# Gradio UI: one large text box in, the summary plus the two length counters out.
# The description asks the user to paste an article, so the input box is
# labelled and prompted for article text rather than a search term.
# NOTE(review): `sample_inputs` still holds a search-term style example;
# confirm whether it should become a sample article.
iface = gr.Interface(fn=get_wiki,
                     inputs=gr.Textbox(lines=50, placeholder="Paste article text here...", label='Article Text'),
                     outputs=[out_sum_text,out_orig_test_len,out_sum_text_len],
                     title='Article Summary',
                     description='Paste in an article and it will be summarized',
                     sample_inputs='guardians of the galaxy'
                     )
iface.launch()  # To create a public link, set `share=True` in `launch()`.
inference.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from transformers import AutoModelForSeq2SeqLM
2
+
3
# Identifier of the pretrained seq2seq checkpoint to load.
checkpoint = "sgugger/my-awesome-model"
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
summarize_train.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ from datasets import load_dataset, load_metric
3
+ import datasets
4
+ import random
5
+ import pandas as pd
6
+ from IPython.display import display, HTML
7
+ from transformers import AutoTokenizer
8
+ from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
9
+
10
+
11
# Name of the pretrained checkpoint; the tokenizer and model below are both
# loaded from it.
model_checkpoint = "t5-small"

# Dataset and metric are loaded by name through the `datasets` library.
raw_datasets = load_dataset("xsum")
metric = load_metric("rouge")
15
+
16
+
17
+
18
def show_random_elements(dataset, num_examples=5):
    """Display `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.

    Columns whose feature type is `datasets.ClassLabel` are rendered with
    their label names instead of integer ids.

    Args:
        dataset: Object supporting `len()`, indexing with a list of row
            indices, and a `features` mapping of column name to feature type.
        num_examples: Number of distinct rows to show; must not exceed
            `len(dataset)`.

    Raises:
        AssertionError: If `num_examples` is larger than the dataset.
        ValueError: If `num_examples` is negative (raised by `random.sample`).
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # Draw all indices without replacement in one call. This replaces the
    # retry-until-unique loop, which could spin forever if the assert above
    # is stripped by `python -O`.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            # Map integer class ids to their human-readable names.
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
32
+
33
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
print(transformers.__version__)

# The listed T5 checkpoints get a "summarize: " prefix prepended to every
# input document (see preprocess_function); any other checkpoint gets none.
# "t5-large" was previously misspelled "t5-larg", so that checkpoint would
# silently have been used without the prefix.
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

# Truncation limits (in tokens) passed to the tokenizer for documents and
# summaries respectively.
max_input_length = 1024
max_target_length = 128
43
+
44
def preprocess_function(examples):
    """Tokenize a batch of documents and their summaries.

    Each entry of `examples["document"]` is prepended with the module-level
    `prefix` and tokenized with truncation at `max_input_length`. The entries
    of `examples["summary"]` are tokenized in the tokenizer's target mode with
    truncation at `max_target_length`, and their input ids are attached to the
    result under the "labels" key.

    Args:
        examples: Mapping with "document" and "summary" entries.

    Returns:
        The tokenizer output for the documents, with an added "labels" entry.
    """
    prefixed_documents = [prefix + document for document in examples["document"]]
    encoded = tokenizer(
        prefixed_documents,
        max_length=max_input_length,
        truncation=True,
    )

    # Summaries are encoded inside the tokenizer's target-side context.
    with tokenizer.as_target_tokenizer():
        encoded_summaries = tokenizer(
            examples["summary"],
            max_length=max_target_length,
            truncation=True,
        )

    encoded["labels"] = encoded_summaries["input_ids"]
    return encoded
54
+
55
+
56
# Load the pretrained seq2seq model for the chosen checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

batch_size = 16
# Keep only the part after the last "/" so the output directory name has no
# path separator in it.
model_name = model_checkpoint.rsplit("/", 1)[-1]
args = Seq2SeqTrainingArguments(
    output_dir=f"{model_name}-finetuned-xsum",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=True,
)
73
+
74
+ import nltk
75
+ import numpy as np
76
+
77
+
78
def compute_metrics(eval_pred):
    """Score generated summaries against reference summaries.

    Args:
        eval_pred: Pair `(predictions, labels)` of token-id arrays.

    Returns:
        Dict with each metric's mid F-measure scaled by 100, plus "gen_len"
        (mean count of non-pad tokens per prediction), all rounded to 4
        decimal places.
    """

    def _one_sentence_per_line(text):
        # Rouge expects a newline after each sentence.
        return "\n".join(nltk.sent_tokenize(text.strip()))

    predictions, labels = eval_pred
    pad_id = tokenizer.pad_token_id

    # -100 cannot be decoded, so swap it for the pad token id first.
    labels = np.where(labels != -100, labels, pad_id)
    pred_texts = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    label_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge = metric.compute(
        predictions=[_one_sentence_per_line(text) for text in pred_texts],
        references=[_one_sentence_per_line(text) for text in label_texts],
        use_stemmer=True,
    )
    scores = {name: aggregate.mid.fmeasure * 100 for name, aggregate in rouge.items()}

    # Mean generated length, counting only non-pad tokens.
    generated_lengths = [np.count_nonzero(pred != pad_id) for pred in predictions]
    scores["gen_len"] = np.mean(generated_lengths)

    return {name: round(score, 4) for name, score in scores.items()}
98
+
99
# `tokenized_datasets` and `data_collator` were referenced below without ever
# being defined, which raised NameError at import time. Build them here from
# the objects this file already creates.
# Apply preprocess_function to every split, one batch of examples at a time.
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
# Collator that batches the tokenized examples for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# NOTE(review): nothing in this file calls `trainer.train()`; confirm whether
# training is meant to be started elsewhere.
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
108
+
109
+
tester.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wikipedia
2
+
3
def search_wiki(text):
    """Search Wikipedia for `text` and return the page of the first hit.

    Args:
        text: Query string passed to `wikipedia.search`.

    Returns:
        The result of `wikipedia.page` for the first search result, or
        `None` when the search returns no results.
    """
    article_list = wikipedia.search(text)
    if not article_list:
        # No hits: avoid the IndexError from indexing an empty result list.
        return None
    # The page was previously fetched and then discarded; return it so that
    # callers assigning the result actually receive it.
    return wikipedia.page(article_list[0])
6
+
7
+
8
def get_wiki(search_term):
    """Return the result of `wikipedia.page` for `search_term`, with no search step."""
    page = wikipedia.page(search_term)
    return page
10
+
11
+
12
+
13
# Manual smoke test of the `wikipedia` helpers above.

# Direct page lookup by title.
get = get_wiki('spacex')
print(get)

# Plain-text summary lookup.
print(wikipedia.summary("Python Programming Language"))

# Search-based lookup; the result is kept but not printed.
x = search_wiki('spacex')

print('done')
21
+