Stefan Dumitrescu committed on
Commit
f4a3863
·
1 Parent(s): c44f938
Files changed (1) hide show
  1. app.py +15 -18
app.py CHANGED
@@ -1,14 +1,7 @@
1
  import transformers
2
  import streamlit as st
3
 
4
- from transformers import AutoTokenizer, AutoModelWithLMHead
5
-
6
- ###################
7
- # global variables
8
-
9
-
10
- ###################
11
- # page configs and functions
12
 
13
  st.set_page_config(
14
  page_title="Romanian Text Generator",
@@ -16,29 +9,33 @@ st.set_page_config(
16
  layout="wide"
17
  )
18
 
19
- model_list = ["dumitrescustefan/gpt-neo-romanian-780m"]
20
- st.sidebar.header("Select Model")
21
- model_checkpoint = st.sidebar.radio("", model_list)
22
- text_element = st.text_input('Text:', 'Acesta este un exemplu,')
23
-
24
-
25
 
 
 
 
 
 
 
26
 
 
 
27
 
28
- st.sidebar.header("Select type of PERSON detection")
29
  max_length = st.sidebar.slider("Max Length", value=20, min_value=10, max_value=200)
30
  temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05)
31
  top_k = st.sidebar.slider("Top-k", min_value=0, max_value=15, step=1, value=0)
32
  top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
33
 
 
34
 
35
  @st.cache(allow_output_mutation=True)
36
  def setModel(model_checkpoint):
37
- model = AutoModelWithLMHead.from_pretrained(model_checkpoint)
38
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
39
  return model, tokenizer
40
 
41
- def infer(model, tokenizer, text, input_ids, max_length, temperature, top_k, top_p):
42
  encoded_prompt = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")
43
  output_sequences = model.generate(
44
  input_ids=encoded_prompt.input_ids,
@@ -53,7 +50,7 @@ def infer(model, tokenizer, text, input_ids, max_length, temperature, top_k, top
53
  return output_sequences
54
 
55
  model, tokenizer = setModel(model_checkpoint)
56
- output_sequences = infer(model, tokenizer, text_element, input_ids, max_length, temperature, top_k, top_p)
57
 
58
  for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
59
  print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
 
1
  import transformers
2
  import streamlit as st
3
 
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
 
 
5
 
6
  st.set_page_config(
7
  page_title="Romanian Text Generator",
 
9
  layout="wide"
10
  )
11
 
12
+ st.write("Type your text here and press Ctrl+Enter to generate the next sequence:")
 
 
 
 
 
13
 
14
+ model_list = [
15
+ "dumitrescustefan/gpt-neo-romanian-780m",
16
+ "readerbench/RoGPT2-base",
17
+ "readerbench/RoGPT2-medium",
18
+ "readerbench/RoGPT2-large"
19
+ ]
20
 
21
+ st.sidebar.header("Select model")
22
+ model_checkpoint = st.sidebar.radio("", model_list)
23
 
24
+ st.sidebar.header("Select generation parameters")
25
  max_length = st.sidebar.slider("Max Length", value=20, min_value=10, max_value=200)
26
  temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05)
27
  top_k = st.sidebar.slider("Top-k", min_value=0, max_value=15, step=1, value=0)
28
  top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
29
 
30
+ text_element = st.text_input('Text:', 'Acesta este un exemplu,')
31
 
32
  @st.cache(allow_output_mutation=True)
33
  def setModel(model_checkpoint):
34
+ model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
35
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
36
  return model, tokenizer
37
 
38
+ def infer(model, tokenizer, text, max_length, temperature, top_k, top_p):
39
  encoded_prompt = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")
40
  output_sequences = model.generate(
41
  input_ids=encoded_prompt.input_ids,
 
50
  return output_sequences
51
 
52
  model, tokenizer = setModel(model_checkpoint)
53
+ output_sequences = infer(model, tokenizer, text_element, max_length, temperature, top_k, top_p)
54
 
55
  for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
56
  print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")