import streamlit as st
import random
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import login, whoami
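# Approximate runtime dependencies, inferred from the imports above (an assumption;
# pin exact versions in requirements.txt as needed):
#   streamlit, torch, transformers, peft, huggingface_hub, pandas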
st.title("Space Turtle 101 Demo")
st.markdown(
    """
    This demo generates adversarial prompts based on a bias category and a country/region.
    The base model (meta-llama/Llama-3.2-1B-Instruct) is gated, so a Hugging Face token
    with access to it is required.
    """
)
# -------------------------------
# Retrieve HF token from secrets or user input
# -------------------------------
if "HF_TOKEN" in st.secrets:
    hf_token = st.secrets["HF_TOKEN"]
    st.sidebar.info("Using token from secrets.")
else:
    hf_token = st.sidebar.text_input("Enter your Hugging Face API Token", type="password")
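# A minimal sketch of how the secret could be provided via .streamlit/secrets.toml
# (the token value below is a placeholder, not a real credential):
#
#   HF_TOKEN = "hf_xxxxxxxxxxxxxxxxxxxx"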
# -------------------------------
# Log in if a token is provided
# -------------------------------
if hf_token:
    try:
        login(token=hf_token)
        user_info = whoami()
        st.sidebar.success(f"Logged in as: {user_info['name']}")
    except Exception as e:
        st.sidebar.error(f"Login failed: {e}")
        hf_token = None
else:
    st.sidebar.warning("Please enter your Hugging Face API Token.")
# -------------------------------
# Device selection: CUDA > MPS > CPU
# -------------------------------
def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"
# -------------------------------
# Function: Load Model with Caching
# -------------------------------
@st.cache_resource  # cache the loaded model/tokenizer across Streamlit reruns
def load_model(hf_token):
    device = get_device()
    # Load the gated base model with the provided token
    # (use_auth_token is deprecated in newer transformers releases in favor of token=)
    base_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        use_auth_token=hf_token
    )
    # Load the tokenizer from the adapter repository and set a pad token if needed
    tokenizer = AutoTokenizer.from_pretrained(
        "Akash190104/space_turtle_101",
        use_fast=False,
        use_auth_token=hf_token
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Attach the LoRA adapter to the base model
    model = PeftModel.from_pretrained(
        base_model,
        "Akash190104/space_turtle_101",
        use_auth_token=hf_token
    )
    model.to(device)
    return model, tokenizer, device
if not hf_token:
    st.warning("Please enter your Hugging Face API Token to load the model.")
    st.stop()  # halt here so the UI below never references an unloaded model
else:
    with st.spinner("Loading model, please wait..."):
        try:
            model, tokenizer, device = load_model(hf_token)
            st.success("Model loaded successfully!")
        except Exception as e:
            st.error(f"Model loading failed: {e}")
            st.error("Ensure your token has access to meta-llama/Llama-3.2-1B-Instruct.")
            st.stop()
# -------------------------------
# Generation helper function
# -------------------------------
def generate_sample(prompt_text):
    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
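# Illustrative usage (hypothetical inputs), matching the prompt format used below:
#   sample = generate_sample("```Gender bias in India```\n")
# The decoded string includes the prompt followed by the model's continuation.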
# -------------------------------
# Define bias and country lists
# -------------------------------
biases = [
    "Socio-economic bias", "Gender bias", "Caste bias", "Religious bias", "Ethnic bias", "Age bias",
    "Regional stereotyping", "Language discrimination", "Political bias", "Educational bias",
    "Occupational bias", "Disability bias", "Appearance-based bias", "Colorism", "Nationality bias",
    "Urban-rural bias", "Immigration bias"
]
countries = [
    "China", "India", "Philippines", "Vietnam", "Southeast Asia", "Europe", "Nigeria", "United States",
    "Mexico", "Canada", "Germany", "France", "Brazil", "South Africa", "Russia", "Japan", "South Korea",
    "Australia", "Middle East", "Latin America", "Eastern Europe", "Bangladesh", "Pakistan", "Indonesia",
    "Turkey", "Egypt", "Kenya", "Argentina"
]
# -------------------------------
# Streamlit UI: Demo Modes
# -------------------------------
mode = st.radio("Select Mode", ("Interactive", "Random Generation (10 samples)"))

if mode == "Interactive":
    st.subheader("Interactive Mode")
    bias_input = st.text_input("Bias Category", "")
    country_input = st.text_input("Country/Region", "")
    if st.button("Generate Sample"):
        if bias_input.strip() == "" or country_input.strip() == "":
            st.error("Please provide both a bias category and a country/region.")
        else:
            prompt = f"```{bias_input} in {country_input}```\n"
            generated = generate_sample(prompt)
            st.markdown("**Generated Output:**")
            st.text_area("Generated Output", value=generated, height=200, label_visibility="collapsed")
            st.download_button("Download Output", generated, file_name="output.txt")
elif mode == "Random Generation (10 samples)":
    st.subheader("Random Generation Mode")
    if st.button("Generate 10 Random Samples"):
        results = []
        for _ in range(10):
            bias = random.choice(biases)
            country = random.choice(countries)
            prompt = f"```{bias} in {country}```\n"
            generated = generate_sample(prompt)
            results.append({"prompt": prompt, "generated": generated})
        for i, res in enumerate(results):
            st.markdown(f"**Sample {i+1}:**")
            # Unique keys prevent duplicate-widget errors if two samples draw the same prompt
            st.text_area("Prompt", value=res["prompt"], height=50, key=f"prompt_{i}")
            st.text_area("Output", value=res["generated"], height=150, key=f"output_{i}")
        df = pd.DataFrame(results)
        csv = df.to_csv(index=False).encode("utf-8")
        st.download_button("Download All Samples (CSV)", csv, file_name="samples.csv", mime="text/csv")
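# To run the demo locally (assuming this file is saved as app.py and the
# dependencies listed near the imports are installed):
#   streamlit run app.py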