ryanpdwyer commited on
Commit
3618983
·
1 Parent(s): af6ac26

Add application file

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+
5
+ # Load models and tokenizers
6
+ @st.cache_resource
7
+ def load_model_and_tokenizer(model_name):
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ # Load the model in 8-bit quantization
10
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", load_in_8bit=True)
11
+ return model, tokenizer
12
+
13
+ model_8b, tokenizer_8b = load_model_and_tokenizer("huggyllama/llama-3.1-8b")
14
+ model_8b_instruct, tokenizer_8b_instruct = load_model_and_tokenizer("huggyllama/llama-3.1-8b-instruct")
15
+
16
+ def generate_text(model, tokenizer, prompt, max_length=100):
17
+ inputs = tokenizer(prompt, return_tensors="pt")
18
+ with torch.no_grad():
19
+ outputs = model.generate(**inputs, max_length=max_length, num_return_sequences=1)
20
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
21
+
22
+ st.title("LLaMA-3.1-8B vs LLaMA-3.1-8B-Instruct Comparison (CPU Version)")
23
+
24
+ prompt = st.text_area("Enter your prompt:", height=100)
25
+ max_length = st.slider("Max output length:", min_value=50, max_value=500, value=100)
26
+
27
+ if st.button("Generate"):
28
+ if prompt:
29
+ st.warning("Generation may take several minutes. Please be patient.")
30
+
31
+ col1, col2 = st.columns(2)
32
+
33
+ with col1:
34
+ st.subheader("LLaMA-3.1-8B Output")
35
+ output_8b = generate_text(model_8b, tokenizer_8b, prompt, max_length)
36
+ st.write(output_8b)
37
+
38
+ with col2:
39
+ st.subheader("LLaMA-3.1-8B-Instruct Output")
40
+ output_8b_instruct = generate_text(model_8b_instruct, tokenizer_8b_instruct, prompt, max_length)
41
+ st.write(output_8b_instruct)
42
+ else:
43
+ st.warning("Please enter a prompt.")