namannn committed on
Commit a992249 · verified · 1 Parent(s): 82a2d2f

Update app.py

Files changed (1)
  1. app.py +77 -38
app.py CHANGED
@@ -1,42 +1,81 @@
 import streamlit as st
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 
-# Load model and tokenizer
-tokenizer = AutoTokenizer.from_pretrained("namannn/llama2-13b-hyperbolic-cluster-pruned")
-model = AutoModelForCausalLM.from_pretrained("namannn/llama2-13b-hyperbolic-cluster-pruned")
-
-# Streamlit UI components
-st.title("Text Generation with LLaMa2-13b Hyperbolic Model")
-st.write("Enter a prompt below and the model will generate text.")
-
-# User input for prompt
-prompt = st.text_area("Input Prompt", "Once upon a time, in a land far away")
-
-# Slider for controlling the length of the output
-max_length = st.slider("Max Length of Generated Text", min_value=50, max_value=200, value=100)
-
-# Button to trigger text generation
-if st.button("Generate Text"):
-    if prompt:
-        # Encode the prompt text
-        inputs = tokenizer(prompt, return_tensors="pt")
-
-        # Generate text with the model
-        outputs = model.generate(
-            inputs["input_ids"],
-            max_length=max_length,
-            num_return_sequences=1,
-            no_repeat_ngram_size=2,  # You can tune this for diversity
-            do_sample=True,  # Use sampling for diverse generation
-            top_k=50,  # Top-k sampling for diversity
-            top_p=0.95,  # Top-p (nucleus) sampling
-            temperature=0.7  # Control randomness (lower = more deterministic)
-        )
-
-        # Decode and display generated text
-        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        st.subheader("Generated Text:")
-        st.write(generated_text)
-    else:
-        st.warning("Please enter a prompt to generate text.")
+@st.cache_resource
+def load_model_and_tokenizer():
+    """
+    Load model and tokenizer with Streamlit's caching to prevent reloading.
+    @st.cache_resource ensures the model is loaded only once per session.
+    """
+    tokenizer = AutoTokenizer.from_pretrained("namannn/llama2-13b-hyperbolic-cluster-pruned")
+    model = AutoModelForCausalLM.from_pretrained(
+        "namannn/llama2-13b-hyperbolic-cluster-pruned",
+        # Optional: specify device and precision to optimize loading
+        device_map="auto",  # Automatically distribute model across available GPUs/CPU
+        torch_dtype=torch.float16,  # Use half precision to reduce memory usage
+        low_cpu_mem_usage=True  # Optimize memory usage during model loading
+    )
+    return tokenizer, model
+
+def generate_text(prompt, tokenizer, model, max_length):
+    """
+    Generate text using the loaded model and tokenizer.
+    """
+    # Encode the prompt and move it to the model's device (device_map="auto" may place weights on a GPU)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    # Generate text with the model
+    outputs = model.generate(
+        inputs["input_ids"],
+        max_length=max_length,
+        num_return_sequences=1,
+        no_repeat_ngram_size=2,
+        do_sample=True,
+        top_k=50,
+        top_p=0.95,
+        temperature=0.7
+    )
+
+    # Decode and return generated text
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return generated_text
+
+def main():
+    # Set page title and icon
+    st.set_page_config(page_title="LLaMa2 Text Generation", page_icon="✍️")
+
+    # Page title and description
+    st.title("Text Generation with LLaMa2-13b Hyperbolic Model")
+    st.write("Enter a prompt below and the model will generate text.")
+
+    # Load model and tokenizer (only once)
+    try:
+        tokenizer, model = load_model_and_tokenizer()
+    except Exception as e:
+        st.error(f"Error loading model: {e}")
+        return
+
+    # User input for prompt
+    prompt = st.text_area("Input Prompt", "Once upon a time, in a land far away")
+
+    # Slider for controlling the length of the output
+    max_length = st.slider("Max Length of Generated Text", min_value=50, max_value=200, value=100)
+
+    # Button to trigger text generation
+    if st.button("Generate Text"):
+        if prompt:
+            try:
+                # Generate text
+                generated_text = generate_text(prompt, tokenizer, model, max_length)
+
+                # Display generated text
+                st.subheader("Generated Text:")
+                st.write(generated_text)
+            except Exception as e:
+                st.error(f"Error generating text: {e}")
+        else:
+            st.warning("Please enter a prompt to generate text.")
+
+if __name__ == "__main__":
+    main()
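
For anyone sanity-checking this commit, a minimal smoke test outside Streamlit might look like the sketch below. It is a hypothetical helper, not part of the commit: it assumes torch and transformers are installed, that the machine has enough memory for the model in half precision (on the order of 26 GB for 13B weights, less after pruning), and it mirrors the loading settings used in app.py.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "namannn/llama2-13b-hyperbolic-cluster-pruned"

# Load with the same settings as load_model_and_tokenizer() in app.py
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Run one generation outside the UI so loading/device issues surface early
inputs = tokenizer("Once upon a time, in a land far away",
                   return_tensors="pt").to(model.device)
outputs = model.generate(
    inputs["input_ids"],
    max_length=100,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=0.7,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The app itself is launched with streamlit run app.py.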