jcarbonnell commited on
Commit
1cda371
·
1 Parent(s): 7a37089

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -7,6 +7,10 @@ model=GPT2LMHeadModel.from_pretrained("DemocracyStudio/generate_nft_content")
7
  tokenizer=GPT2Tokenizer.from_pretrained("DemocracyStudio/generate_nft_content")
8
  summarize=Summarizer()
9
 
 
 
 
 
10
  st.title("Text generation for the marketing content of NFTs")
11
  st.subheader("Course project 'NLP with transformers' at opencampus.sh, Spring 2022")
12
 
@@ -21,6 +25,8 @@ if choice == 'NFT':
21
  if st.button("Generate"):
22
  prompt = "<|startoftext|>"
23
  generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
 
 
24
  sample_outputs = model.generate(
25
  generated,
26
  do_sample=True,
@@ -31,11 +37,11 @@ if choice == 'NFT':
31
  )
32
  for i, sample_output in enumerate(sample_outputs):
33
  generated_text = tokenizer.decode(sample_output, skip_special_tokens=True)
 
34
 
35
  #st.text("Keywords: {}\n".format(keywords))
36
  #st.text("Length in number of words: {}\n".format(length))
37
  st.text("This is your tailored blog article {generated_text}")
38
- summary = summarize(generated_text, num_sentences=1)
39
  st.text("This is a tweet-sized summary of your article {summary}")
40
  else:
41
  st.write("Topic not available yet")
 
7
  tokenizer=GPT2Tokenizer.from_pretrained("DemocracyStudio/generate_nft_content")
8
  summarize=Summarizer()
9
 
10
+ device = torch.device("cuda")
11
+ model.cuda()
12
+ model.to(device)
13
+
14
  st.title("Text generation for the marketing content of NFTs")
15
  st.subheader("Course project 'NLP with transformers' at opencampus.sh, Spring 2022")
16
 
 
25
  if st.button("Generate"):
26
  prompt = "<|startoftext|>"
27
  generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
28
+ generated = generated.to(device)
29
+
30
  sample_outputs = model.generate(
31
  generated,
32
  do_sample=True,
 
37
  )
38
  for i, sample_output in enumerate(sample_outputs):
39
  generated_text = tokenizer.decode(sample_output, skip_special_tokens=True)
40
+ summary = summarize(generated_text, num_sentences=1)
41
 
42
  #st.text("Keywords: {}\n".format(keywords))
43
  #st.text("Length in number of words: {}\n".format(length))
44
  st.text("This is your tailored blog article {generated_text}")
 
45
  st.text("This is a tweet-sized summary of your article {summary}")
46
  else:
47
  st.write("Topic not available yet")