Josh Nguyen committed
Commit 74a3063
Parent: 6797eb5

Change max_new_tokens to max_length

Files changed (1)
  1. app.py +3 -5
app.py CHANGED
@@ -30,7 +30,7 @@ model = AutoModelForCausalLM.from_pretrained(
 
 
 def generate_text(prompt: str,
-                  max_new_tokens: int = 512,
+                  max_length: int = 1024,
                   temperature: float = 0.5,
                   top_p: float = 0.95,
                   top_k: int = 50) -> str:
@@ -43,8 +43,6 @@ def generate_text(prompt: str,
     inputs = inputs.to(DEVICE)
 
     # Prepare arguments for generation
-    input_length = inputs["input_ids"].shape[-1]
-    max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)
     if temperature >= 1.0:
         temperature = 0.99
     elif temperature <= 0.0:
@@ -60,7 +58,7 @@ def generate_text(prompt: str,
     generation_kwargs = dict(
         inputs=inputs,
         streamer=inputs,
-        max_new_tokens=max_new_tokens,
+        max_length=max_length,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
@@ -90,7 +88,7 @@ demo = gr.Interface(
             scale=10,
         ),
         gr.Slider(
-            label="Maximum new tokens",
+            label="Maximum length of the output",
            minimum=1,
            maximum=4096,
            step=1,
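
For context: in Hugging Face transformers, max_new_tokens bounds only the freshly generated tokens, while max_length bounds the prompt plus the generated tokens, which is why the manual clamp against WINDOW_SIZE - input_length could be dropped in this commit. A minimal sketch of the difference, using a small placeholder checkpoint ("gpt2") and a placeholder prompt rather than this Space's actual code:

# Sketch only (hypothetical model and prompt, not the Space's own code),
# contrasting the two length limits accepted by transformers' generate().
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, world", return_tensors="pt")
prompt_len = inputs["input_ids"].shape[-1]

# max_new_tokens counts generated tokens only:
# total output length == prompt_len + up to 512 new tokens
out_new = model.generate(**inputs, max_new_tokens=512)

# max_length counts the prompt as well:
# total output length is capped at 1024 tokens, prompt included
out_total = model.generate(**inputs, max_length=1024)

assert out_total.shape[-1] <= 1024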