diabolic6045 committed on
Commit 3611d45 · verified · 1 Parent(s): 92db476

Update app.py

Files changed (1):
  1. app.py +30 -28
app.py CHANGED
@@ -9,31 +9,33 @@ model = AutoModelForCausalLM.from_pretrained("diabolic6045/ELN-Llama-1B-base")
 def generate_response(message, temperature, max_length):
     # Tokenize input
     inputs = tokenizer(message, return_tensors="pt", truncation=True, max_length=512)
-
-    # Initialize the generated text with the input message
-    generated_text = message
+    input_ids = inputs["input_ids"]
+    current_text = message
 
     # Generate response token by token
-    with torch.no_grad():
-        generated_ids = model.generate(
-            inputs["input_ids"],
-            max_length=max_length,
-            temperature=temperature,
-            do_sample=True,
-            pad_token_id=tokenizer.eos_token_id,
-            num_return_sequences=1,
-            return_dict_in_generate=True,
-            output_scores=True,
-        )
-
-    # Get the generated token ids (excluding the input prompt)
-    new_tokens = generated_ids.sequences[0][inputs["input_ids"].shape[1]:]
-
-    # Decode and yield tokens one by one
-    for i in range(len(new_tokens)):
-        next_token = tokenizer.decode(new_tokens[:i+1], skip_special_tokens=True)
-        generated_text += next_token
-        yield generated_text
+    for _ in range(max_length - input_ids.shape[1]):
+        with torch.no_grad():
+            outputs = model(input_ids)
+        next_token_logits = outputs.logits[:, -1, :]
+
+        # Apply temperature
+        next_token_logits = next_token_logits / temperature
+
+        # Sample from the distribution
+        probs = torch.softmax(next_token_logits, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1)
+
+        # Stop if we generate an EOS token
+        if next_token.item() == tokenizer.eos_token_id:
+            break
+
+        # Append the new token to input_ids
+        input_ids = torch.cat([input_ids, next_token], dim=-1)
+
+        # Decode only the new token and add it to current text
+        new_token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
+        current_text += new_token_text
+        yield current_text
 
 # Create the Gradio interface
 demo = gr.Interface(
@@ -47,17 +49,17 @@ demo = gr.Interface(
     title="LLaMA Text Completion",
     description="Generate text completions using the ELN-Llama-1B model. Enter the start of a text, and the model will continue it.",
     examples=[
-        ["Once upon a time in a magical forest", 0.7, 200],
-        ["The recipe for making the perfect chocolate cake requires", 0.7, 200],
-        ["In the year 2150, humanity had finally achieved", 0.7, 200],
-        ["The most important principles of effective programming are", 0.8, 300],
+        ["Once upon a time in a magical forest", 0.7, 50],
+        ["The recipe for making the perfect chocolate cake requires", 0.7, 50],
+        ["In the year 2150, humanity had finally achieved", 0.7, 50],
+        ["The most important principles of effective programming are", 0.8, 50],
     ],
     article="""
     ## Tips for better completions:
     - Start with a clear and detailed prompt
     - Adjust temperature: Higher for creative writing, lower for factual completion
     - Adjust max length based on how much text you want to generate
-    """,
+    """
 )
 
 if __name__ == "__main__":
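After this change, generate_response is a Python generator: each iteration samples one token and yields the full text so far, which gr.Interface renders as a streaming textbox. A minimal sketch for exercising the generator outside Gradio, assuming app.py's module-level model and tokenizer are loaded:

```python
# Minimal smoke test for the streaming generator (not part of app.py).
# Each yielded value is the prompt plus everything sampled so far.
for partial in generate_response(
    "Once upon a time in a magical forest", temperature=0.7, max_length=50
):
    print(partial)
```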
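One cost of the hand-rolled loop is that every iteration re-runs the model over the entire sequence, so generating n tokens does roughly O(n²) work. Transformers causal LMs expose a key/value cache (use_cache / past_key_values) that lets each step process only the newest token. A sketch of the same temperature sampling with the cache threaded through; sample_with_cache is a hypothetical helper, not part of app.py:

```python
import torch

@torch.no_grad()
def sample_with_cache(model, tokenizer, input_ids, max_new_tokens, temperature):
    # Hypothetical helper: same sampling as app.py's loop, but reusing the
    # KV cache so each step runs the model on a single token.
    outputs = model(input_ids, use_cache=True)  # full pass over the prompt
    past = outputs.past_key_values
    generated = input_ids
    for _ in range(max_new_tokens):
        logits = outputs.logits[:, -1, :] / temperature
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
        generated = torch.cat([generated, next_token], dim=-1)
        # Feed only the new token; `past` covers everything before it.
        outputs = model(next_token, past_key_values=past, use_cache=True)
        past = outputs.past_key_values
    return generated
```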
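Alternatively, transformers ships TextIteratorStreamer, which keeps model.generate (and with it KV caching and the usual sampling options) while still yielding text incrementally; its buffered decoding also sidesteps the way per-token tokenizer.decode calls can drop leading spaces with SentencePiece tokenizers like Llama's. A sketch of the same generator built on it, again assuming app.py's model and tokenizer; generate_response_streamed is a hypothetical name:

```python
from threading import Thread
from transformers import TextIteratorStreamer

def generate_response_streamed(message, temperature, max_length):
    # Hypothetical variant of generate_response built on TextIteratorStreamer.
    inputs = tokenizer(message, return_tensors="pt", truncation=True, max_length=512)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
    # generate() runs in a background thread; the streamer yields decoded
    # text chunks on this thread as tokens are produced.
    Thread(target=model.generate,
           kwargs=dict(**inputs, max_length=max_length, temperature=temperature,
                       do_sample=True, pad_token_id=tokenizer.eos_token_id,
                       streamer=streamer)).start()
    text = message
    for chunk in streamer:
        text += chunk
        yield text
```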