Josh Nguyen committed
Commit
422252e
1 Parent(s): 74a3063

Update app

Files changed (1)
  1. app.py +11 -10
app.py CHANGED
@@ -30,19 +30,19 @@ model = AutoModelForCausalLM.from_pretrained(
 
 
 def generate_text(prompt: str,
-                  max_length: int = 1024,
+                  max_new_tokens: int = 512,
                   temperature: float = 0.5,
                   top_p: float = 0.95,
                   top_k: int = 50) -> str:
 
     # Encode the prompt
     inputs = tokenizer([prompt],
-                       return_tensors="pt",
-                       add_special_tokens=False,
-                       return_token_type_ids=False)
-    inputs = inputs.to(DEVICE)
+                       return_tensors='pt',
+                       add_special_tokens=False).to(DEVICE)
 
     # Prepare arguments for generation
+    input_length = inputs["input_ids"].shape[-1]
+    max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)
     if temperature >= 1.0:
         temperature = 0.99
     elif temperature <= 0.0:
@@ -58,7 +58,7 @@ def generate_text(prompt: str,
     generation_kwargs = dict(
         inputs=inputs,
         streamer=inputs,
-        max_length=max_length,
+        max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
@@ -70,9 +70,10 @@ def generate_text(prompt: str,
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
-    # outputs = []
-    for text in streamer:
-        return text
+    generated_text = ""
+    for new_text in streamer:
+        generated_text += new_text
+    return generated_text
 
 
 demo = gr.Interface(
@@ -88,7 +89,7 @@ demo = gr.Interface(
             scale=10,
         ),
         gr.Slider(
-            label="Maximum length of the output",
+            label="Maximum new tokens",
             minimum=1,
             maximum=4096,
             step=1,
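
Below is a hedged, self-contained Python sketch of the streaming pattern the updated generate_text implements: encode the prompt, cap max_new_tokens against the context window, then run model.generate on a worker thread and drain a TextIteratorStreamer. MODEL_ID, WINDOW_SIZE, and DEVICE here are illustrative stand-ins, not values from this repo, and the sketch passes the streamer object to generate (the streamer=inputs context line in the diff would not stream as written).

# A minimal sketch of the pattern in this commit, not the app's exact code.
# MODEL_ID, WINDOW_SIZE, and DEVICE are assumptions for illustration.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "gpt2"                                        # hypothetical stand-in model
WINDOW_SIZE = 1024                                       # assumed context-window size
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)


def generate_text(prompt: str,
                  max_new_tokens: int = 512,
                  temperature: float = 0.5,
                  top_p: float = 0.95,
                  top_k: int = 50) -> str:
    # Encode the prompt and move it to the model's device
    inputs = tokenizer([prompt], return_tensors="pt",
                       add_special_tokens=False).to(DEVICE)

    # Cap the token budget so prompt + completion fits the context window
    input_length = inputs["input_ids"].shape[-1]
    max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)

    # Clamp temperature into (0, 1), mirroring the app's bounds checks
    temperature = min(max(temperature, 0.01), 0.99)

    # The streamer yields decoded text chunks as generation proceeds
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )

    # Run generate() on a worker thread and drain the stream here,
    # accumulating every chunk (the old code returned only the first one)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
    thread.join()
    return generated_text

Returning the accumulated string once, as the new loop does, suits a non-streaming gr.Interface; yielding generated_text inside the loop instead would stream partial output to the UI.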
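
For completeness, a hypothetical sketch of how the renamed slider feeds the function; only the slider's label and range appear in this diff, so the other components and defaults below are guesses.

# Hypothetical wiring for the slider rename; only the label and range
# are taken from the diff, the rest is illustrative.
import gradio as gr

demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", scale=10),
        gr.Slider(label="Maximum new tokens",  # renamed from "Maximum length of the output"
                  minimum=1, maximum=4096, step=1, value=512),
    ],
    outputs=gr.Textbox(label="Generated text"),
)

if __name__ == "__main__":
    demo.launch()

Gradio passes one argument per input component, so temperature, top_p, and top_k fall back to the function's defaults in this wiring.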