import streamlit as st
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
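# Streamlit demo of a "reverse dictionary": the user types a description of a
# word and a PEFT adapter on top of google/flan-t5-large suggests candidate
# words that match it.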

# Load the reverse-dictionary PEFT adapter on top of its flan-t5-large base model.
config = PeftConfig.from_pretrained("NursNurs/T5ForReverseDictionary")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
model = PeftModel.from_pretrained(model, "NursNurs/T5ForReverseDictionary")

# The tokenizer is the base model's tokenizer.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
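# Note: Streamlit re-runs this whole script on every interaction, so the model
# is reloaded each time. A common optimization (not applied here) would be to
# wrap the loading code in a helper decorated with @st.cache_resource so it
# only runs once per server process.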

def return_top_k(sentence, k=10):
  # Prompt template used with the adapter. The spelling "Descripton" is kept
  # as in the original prompt string, since it presumably matches the prompt
  # the adapter saw during fine-tuning.
  prompt = [f"Descripton : {sentence}. Word : "]

  inputs = tokenizer(
      prompt,
      padding=True,
      truncation=True,
      return_tensors="pt",
  )

  with torch.no_grad():
    # Move the tokenized inputs to the same device as the model.
    batch = {key: tensor.to(device) for key, tensor in inputs.items()}

    # Beam search (no sampling) with k beams; every beam is returned as a
    # candidate word.
    output_sequences = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_new_tokens=10,
        num_beams=k,
        num_return_sequences=k,
        output_scores=True,
        return_dict_in_generate=True,
    )

    # sequences_scores holds the final beam scores; a softmax over them gives
    # the relative probability of each candidate.
    scores = output_sequences["sequences_scores"].detach()
    decoded_probabilities = torch.softmax(scores, dim=0)

    # All candidate words, plus rounded probabilities (computed but not returned).
    predictions = [tokenizer.decode(tokens, skip_special_tokens=True)
                   for tokens in output_sequences["sequences"]]
    probabilities = [round(float(prob), 2) for prob in decoded_probabilities]

  return predictions
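# Illustrative call outside Streamlit (assumed example input):
#   return_top_k("a young dog", k=5)  # returns a list of 5 candidate words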


st.title("You name it!")

# Placeholder text shown in the text box before the user types anything.
default_value = "Type the description of the word you have in mind!"

sent = st.text_area("Text", default_value, height=275)

# Only query the model once the user has replaced the placeholder text.
if sent and sent != default_value:
  result = return_top_k(sent)
  st.write("Here are my guesses about your word:")
  st.write(result)
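# To launch the demo locally (assuming the script is saved as app.py):
#   streamlit run app.py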