Akash190104 committed
Commit 805c816 · 1 Parent(s): 8726172

Streaming enabled

Files changed (1)
  1. app.py +12 -12
app.py CHANGED
@@ -14,8 +14,6 @@ st.markdown(
     """
 )
 
-
-
 hf_token = st.sidebar.text_input("Enter your Hugging Face API Token", type="password")
 
 # -------------------------------
@@ -49,14 +47,12 @@ def get_device():
 @st.cache_resource(show_spinner=True)
 def load_model(hf_token):
     device = get_device()
-    # Load the gated base model with your token
     base_model = AutoModelForCausalLM.from_pretrained(
         "meta-llama/Llama-3.2-1B-Instruct",
         trust_remote_code=True,
         torch_dtype=torch.float16,
         use_auth_token=hf_token
     )
-    # Load the tokenizer from your adapter repository and set pad token if needed
     tokenizer = AutoTokenizer.from_pretrained(
         "Akash190104/space_turtle_101",
         use_fast=False,
@@ -65,7 +61,6 @@ def load_model(hf_token):
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    # Load the LoRA adapter using your token
     model = PeftModel.from_pretrained(
         base_model,
         "Akash190104/space_turtle_101",
@@ -83,23 +78,27 @@ else:
         st.success("Model loaded successfully!")
     except Exception as e:
         st.error(f"Model loading failed: {e}")
-        st.error("Ensure your token has access to meta-llama/Llama-3.2-1B-Instruct.")
         st.stop()
 
 # -------------------------------
-# Generation helper function
+# Streaming Generation Function
 # -------------------------------
-def generate_sample(prompt_text):
+def stream_generate(prompt_text):
     inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
+    result_placeholder = st.empty()  # Placeholder for streaming text
+    generated_text = ""
     with torch.inference_mode():
         outputs = model.generate(
             **inputs,
             max_new_tokens=150,
             do_sample=True,
             temperature=0.7,
-            top_p=0.9
+            top_p=0.9,
+            streamer=True  # Enables token-by-token streaming
         )
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    for token in outputs:
+        generated_text += tokenizer.decode([token], skip_special_tokens=True)
+        result_placeholder.text(generated_text)  # Update in real-time
     return generated_text
 
 # -------------------------------
@@ -133,10 +132,11 @@ if mode == "Interactive":
         st.error("Please provide both a bias category and a country/region.")
     else:
         prompt = f"```{bias_input} in {country_input}```\n"
-        generated = generate_sample(prompt)
+        generated = stream_generate(prompt)
         st.markdown("**Generated Output:**")
         st.text_area("", value=generated, height=200)
         st.download_button("Download Output", generated, file_name="output.txt")
+
 elif mode == "Random Generation (10 samples)":
     st.subheader("Random Generation Mode")
     if st.button("Generate 10 Random Samples"):
@@ -145,7 +145,7 @@ elif mode == "Random Generation (10 samples)":
             bias = random.choice(biases)
             country = random.choice(countries)
             prompt = f"```{bias} in {country}```\n"
-            generated = generate_sample(prompt)
+            generated = stream_generate(prompt)
             results.append({"prompt": prompt, "generated": generated})
         for i, res in enumerate(results):
             st.markdown(f"**Sample {i+1}:**")