Commit 805c816 · Parent: 8726172
Streaming enabled
app.py CHANGED
@@ -14,8 +14,6 @@ st.markdown(
     """
 )
 
-
-
 hf_token = st.sidebar.text_input("Enter your Hugging Face API Token", type="password")
 
 # -------------------------------
@@ -49,14 +47,12 @@ def get_device():
 @st.cache_resource(show_spinner=True)
 def load_model(hf_token):
     device = get_device()
-    # Load the gated base model with your token
     base_model = AutoModelForCausalLM.from_pretrained(
         "meta-llama/Llama-3.2-1B-Instruct",
         trust_remote_code=True,
         torch_dtype=torch.float16,
         use_auth_token=hf_token
     )
-    # Load the tokenizer from your adapter repository and set pad token if needed
     tokenizer = AutoTokenizer.from_pretrained(
         "Akash190104/space_turtle_101",
         use_fast=False,
@@ -65,7 +61,6 @@ def load_model(hf_token):
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    # Load the LoRA adapter using your token
     model = PeftModel.from_pretrained(
         base_model,
         "Akash190104/space_turtle_101",
@@ -83,23 +78,27 @@ else:
         st.success("Model loaded successfully!")
     except Exception as e:
         st.error(f"Model loading failed: {e}")
-        st.error("Ensure your token has access to meta-llama/Llama-3.2-1B-Instruct.")
         st.stop()
 
 # -------------------------------
-# Generation
+# Streaming Generation Function
 # -------------------------------
-def
+def stream_generate(prompt_text):
     inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
+    result_placeholder = st.empty()  # Placeholder for streaming text
+    generated_text = ""
     with torch.inference_mode():
         outputs = model.generate(
             **inputs,
             max_new_tokens=150,
             do_sample=True,
             temperature=0.7,
-            top_p=0.9
+            top_p=0.9,
+            streamer=True  # Enables token-by-token streaming
         )
-
+    for token in outputs:
+        generated_text += tokenizer.decode([token], skip_special_tokens=True)
+        result_placeholder.text(generated_text)  # Update in real-time
     return generated_text
 
 # -------------------------------
@@ -133,10 +132,11 @@ if mode == "Interactive":
             st.error("Please provide both a bias category and a country/region.")
         else:
             prompt = f"```{bias_input} in {country_input}```\n"
-            generated =
+            generated = stream_generate(prompt)
             st.markdown("**Generated Output:**")
             st.text_area("", value=generated, height=200)
             st.download_button("Download Output", generated, file_name="output.txt")
+
 elif mode == "Random Generation (10 samples)":
     st.subheader("Random Generation Mode")
     if st.button("Generate 10 Random Samples"):
@@ -145,7 +145,7 @@ elif mode == "Random Generation (10 samples)":
             bias = random.choice(biases)
            country = random.choice(countries)
            prompt = f"```{bias} in {country}```\n"
-            generated =
+            generated = stream_generate(prompt)
            results.append({"prompt": prompt, "generated": generated})
        for i, res in enumerate(results):
            st.markdown(f"**Sample {i+1}:**")
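
A note on the streaming call added above: in the Hugging Face transformers API, generate() takes a streamer object rather than streamer=True, and it returns a tensor of token ids only after generation completes, so a loop over its output cannot update the placeholder token by token. The usual pattern for live updates in Streamlit is a TextIteratorStreamer consumed while generate() runs in a background thread. The sketch below shows that pattern; it is not the code from this commit, it reuses the stream_generate name purely for illustration, and it assumes model, tokenizer, and device are the objects returned by load_model().

# Sketch: token-by-token streaming with TextIteratorStreamer (hypothetical helper,
# not the commit's implementation). Assumes `model`, `tokenizer`, and `device`
# come from load_model() above.
from threading import Thread

import streamlit as st
from transformers import TextIteratorStreamer

def stream_generate(prompt_text):
    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
    # The streamer decodes tokens as they are produced and exposes them as an iterator.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        streamer=streamer,
    )
    # generate() blocks until it finishes, so run it in a background thread
    # and consume decoded text chunks in the main thread as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    result_placeholder = st.empty()  # placeholder updated in real time
    generated_text = ""
    for chunk in streamer:
        generated_text += chunk
        result_placeholder.text(generated_text)
    thread.join()
    return generated_text

Running generation in a thread keeps the Streamlit script responsive while tokens are emitted; the placeholder is overwritten on each chunk, matching the result_placeholder.text() update shown in the diff.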