namannn committed
Commit 769c112 · verified · 1 Parent(s): 50cd6ca

Update app.py

Files changed (1):
  app.py +66 -37
app.py CHANGED
@@ -1,66 +1,92 @@
 import streamlit as st
-from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM

 @st.cache_resource
 def load_model_and_tokenizer():
     """
     Load model and tokenizer with Streamlit's caching to prevent reloading.
-    @st.cache_resource ensures the model is loaded only once per session.
     """
-    tokenizer = AutoTokenizer.from_pretrained("namannn/llama2-13b-hyperbolic-cluster-pruned")
-    model = AutoModelForCausalLM.from_pretrained(
-        "namannn/llama2-13b-hyperbolic-cluster-pruned",
-        # Optional: specify device and precision to optimize loading
-        device_map="auto",          # Automatically distribute model across available GPUs/CPU
-        torch_dtype=torch.float16,  # Use half precision to reduce memory usage
-        low_cpu_mem_usage=True      # Optimize memory usage during model loading
-    )
-    return tokenizer, model
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(
+            "namannn/llama2-13b-hyperbolic-cluster-pruned",
+            use_fast=True,          # Use fast tokenizer if available
+            trust_remote_code=True  # Trust remote code for custom tokenizers
+        )
+
+        # Ensure pad_token is set
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        model = AutoModelForCausalLM.from_pretrained(
+            "namannn/llama2-13b-hyperbolic-cluster-pruned",
+            device_map="auto",
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True  # Trust remote code for custom models
+        )
+
+        return tokenizer, model
+    except Exception as e:
+        st.error(f"Error loading model: {e}")
+        raise

 def generate_text(prompt, tokenizer, model, max_length):
     """
-    Generate text using the loaded model and tokenizer.
+    Generate text using the loaded model and tokenizer with detailed error handling.
     """
-    # Encode the prompt text
-    inputs = tokenizer(prompt, return_tensors="pt")
-
-    # Generate text with the model
-    outputs = model.generate(
-        inputs["input_ids"],
-        max_length=max_length,
-        num_return_sequences=1,
-        no_repeat_ngram_size=2,
-        do_sample=True,
-        top_k=50,
-        top_p=0.95,
-        temperature=0.7
-    )
-
-    # Decode and return generated text
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return generated_text
+    try:
+        # Ensure input is on the correct device
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+        # Generate text with more explicit parameters
+        with torch.no_grad():  # Disable gradient calculation
+            outputs = model.generate(
+                input_ids=inputs["input_ids"],
+                attention_mask=inputs.get("attention_mask"),
+                max_length=max_length + len(inputs["input_ids"][0]),
+                num_return_sequences=1,
+                no_repeat_ngram_size=2,
+                do_sample=True,
+                top_k=50,
+                top_p=0.95,
+                temperature=0.7,
+                pad_token_id=tokenizer.eos_token_id
+            )
+
+        # Decode the generated text
+        generated_text = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+
+        return generated_text.strip()
+    except Exception as e:
+        st.error(f"Error generating text: {e}")
+        return None

 def main():
-    # Set page title and icon
+    # Set page configuration
     st.set_page_config(page_title="LLaMa2 Text Generation", page_icon="✍️")

     # Page title and description
     st.title("Text Generation with LLaMa2-13b Hyperbolic Model")
     st.write("Enter a prompt below and the model will generate text.")

-    # Load model and tokenizer (only once)
+    # Load model and tokenizer
     try:
         tokenizer, model = load_model_and_tokenizer()
     except Exception as e:
-        st.error(f"Error loading model: {e}")
+        st.error(f"Failed to load model: {e}")
         return

+    # System information
+    st.sidebar.header("System Information")
+    st.sidebar.write(f"Device: {model.device}")
+    st.sidebar.write(f"Model Dtype: {model.dtype}")
+
     # User input for prompt
     prompt = st.text_area("Input Prompt", "Once upon a time, in a land far away")

     # Slider for controlling the length of the output
-    max_length = st.slider("Max Length of Generated Text", min_value=50, max_value=200, value=100)
+    max_length = st.slider("Max Length of Generated Text", min_value=50, max_value=500, value=150)

     # Button to trigger text generation
     if st.button("Generate Text"):
@@ -70,10 +96,13 @@ def main():
             generated_text = generate_text(prompt, tokenizer, model, max_length)

             # Display generated text
-            st.subheader("Generated Text:")
-            st.write(generated_text)
+            if generated_text:
+                st.subheader("Generated Text:")
+                st.write(generated_text)
+            else:
+                st.warning("No text was generated. Please check the input and try again.")
         except Exception as e:
-            st.error(f"Error generating text: {e}")
+            st.error(f"Unexpected error during text generation: {e}")
     else:
         st.warning("Please enter a prompt to generate text.")
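
Note on the length handling: the updated generate_text adds the prompt length to max_length so the slider controls only new tokens, then slices the prompt back off before decoding. A minimal sketch of the same call written with max_new_tokens, the standard transformers generation parameter that bounds newly generated tokens directly and avoids the arithmetic (variable names as in the function above):

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_length,  # cap applies to newly generated tokens only
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id
    )

The returned sequence still begins with the prompt tokens, so slicing them off before decoding, as the commit does, remains necessary.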
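For a quick sanity check of the same loading and generation settings outside the Streamlit app, a minimal standalone sketch (not part of the commit; assumes the model weights are accessible and there is enough GPU memory for a 13B model in float16):

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    model_id = "namannn/llama2-13b-hyperbolic-cluster-pruned"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # mirror the app's pad_token fallback
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    inputs = tokenizer("Once upon a time", return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95,
                             pad_token_id=tokenizer.eos_token_id)
    # Decode only the continuation, mirroring the slicing in generate_text
    print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))

Run this with plain python; the app itself is launched with streamlit run app.py.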