Spaces:
Sleeping
Sleeping
import streamlit as st | |
import random | |
import pandas as pd | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from peft import PeftModel | |
# ------------------------------- | |
# Cache the model to load only once | |
# ------------------------------- | |
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the base model, tokenizer and LoRA adapter once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the first call does the
    expensive work and later reruns reuse the same objects.  The built-in
    cache spinner is disabled because the caller already shows its own.

    Returns:
        tuple: ``(model, tokenizer, device)`` where ``model`` is the
        adapter-wrapped causal LM already moved to ``device``, and
        ``device`` is one of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    # Pick the best available accelerator, falling back to CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Load the base model without offloading to avoid meta parameters issues.
    base_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )

    # The tokenizer is taken from the adapter repo; make sure a pad token
    # exists so that padded tokenization does not fail.
    tokenizer = AutoTokenizer.from_pretrained(
        "Akash190104/space_turtle_101", use_fast=False
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Attach the LoRA adapter on top of the base weights.
    model = PeftModel.from_pretrained(base_model, "Akash190104/space_turtle_101")
    model.to(device)
    return model, tokenizer, device
# Show a spinner while the model objects are being obtained, then bind
# them to the module-level names used by the rest of the script.
with st.spinner("Loading model, please wait..."):
    _loaded = load_model()
model, tokenizer, device = _loaded
# ------------------------------- | |
# Define helper function for generation | |
# ------------------------------- | |
def generate_sample(prompt_text, max_new_tokens=150, temperature=0.7, top_p=0.9):
    """Sample one completion for ``prompt_text`` from the loaded model.

    Uses the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        prompt_text: Prompt string fed to the tokenizer.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.

    Returns:
        str: The first returned sequence decoded with special tokens
        removed.  NOTE(review): for decoder-only models ``generate``
        normally returns the prompt tokens followed by the new tokens, so
        the prompt is expected to be part of this text - confirm if only
        the continuation is wanted.
    """
    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            # Pass the pad token explicitly instead of relying on the
            # library's implicit fallback to the EOS token.
            pad_token_id=tokenizer.pad_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# ------------------------------- | |
# Define bias and country lists | |
# ------------------------------- | |
# Bias categories sampled from in random-generation mode.
biases = [
    "Socio-economic bias",
    "Gender bias",
    "Caste bias",
    "Religious bias",
    "Ethnic bias",
    "Age bias",
    "Regional stereotyping",
    "Language discrimination",
    "Political bias",
    "Educational bias",
    "Occupational bias",
    "Disability bias",
    "Appearance-based bias",
    "Colorism",
    "Nationality bias",
    "Urban-rural bias",
    "Immigration bias",
]

# Countries and regions sampled from in random-generation mode.
countries = [
    "China",
    "India",
    "Philippines",
    "Vietnam",
    "Southeast Asia",
    "Europe",
    "Nigeria",
    "United States",
    "Mexico",
    "Canada",
    "Germany",
    "France",
    "Brazil",
    "South Africa",
    "Russia",
    "Japan",
    "South Korea",
    "Australia",
    "Middle East",
    "Latin America",
    "Eastern Europe",
    "Bangladesh",
    "Pakistan",
    "Indonesia",
    "Turkey",
    "Egypt",
    "Kenya",
    "Argentina",
]
# ------------------------------- | |
# Streamlit UI | |
# ------------------------------- | |
st.title("Space Turtle 101 Demo")
st.markdown("Generate adversarial prompts based on bias category and country/region.")

mode = st.radio("Select Mode", ("Interactive", "Random Generation (10 samples)"))

if mode == "Interactive":
    st.subheader("Interactive Mode")
    bias_input = st.text_input("Bias Category", "")
    country_input = st.text_input("Country/Region", "")

    if st.button("Generate Sample"):
        if not bias_input.strip() or not country_input.strip():
            st.error("Please provide both a bias category and a country/region.")
            # Drop any earlier result so it is not shown under the error.
            st.session_state.pop("interactive_output", None)
        else:
            prompt = f"```{bias_input} in {country_input}```\n"
            st.session_state["interactive_output"] = generate_sample(prompt)

    # Render from session state: st.button is only True on the run caused
    # by the click, so output drawn inside the branch above would vanish on
    # the rerun triggered by pressing the download button.
    generated = st.session_state.get("interactive_output")
    if generated is not None:
        st.markdown("**Generated Output:**")
        # A non-empty label hidden via label_visibility replaces the empty
        # label, which Streamlit discourages for accessibility reasons.
        st.text_area(
            "Generated Output",
            value=generated,
            height=200,
            label_visibility="collapsed",
        )
        st.download_button("Download Output", generated, file_name="output.txt")

elif mode == "Random Generation (10 samples)":
    st.subheader("Random Generation Mode")

    if st.button("Generate 10 Random Samples"):
        results = []
        for _ in range(10):
            bias = random.choice(biases)
            country = random.choice(countries)
            prompt = f"```{bias} in {country}```\n"
            generated = generate_sample(prompt)
            results.append({"prompt": prompt, "generated": generated})
        st.session_state["random_results"] = results
        # The batch counter goes into the widget keys so a new batch never
        # reuses the widget state of the previous one.
        st.session_state["random_batch"] = st.session_state.get("random_batch", 0) + 1

    results = st.session_state.get("random_results", [])
    batch = st.session_state.get("random_batch", 0)

    for i, res in enumerate(results, start=1):
        st.markdown(f"**Sample {i}:**")
        # Explicit keys: two samples may draw the same bias/country pair,
        # and un-keyed widgets with identical arguments collide on their
        # auto-generated element ID.
        # NOTE(review): height raised from 50 to 68 because recent Streamlit
        # releases appear to reject text_area heights below 68 px - confirm
        # against the pinned Streamlit version.
        st.text_area(
            "Prompt",
            value=res["prompt"],
            height=68,
            key=f"random_prompt_{batch}_{i}",
        )
        st.text_area(
            "Output",
            value=res["generated"],
            height=150,
            key=f"random_output_{batch}_{i}",
        )

    if results:
        # Build the CSV once, after all samples have been rendered.
        csv = pd.DataFrame(results).to_csv(index=False).encode("utf-8")
        st.download_button(
            "Download All Samples (CSV)",
            csv,
            file_name="samples.csv",
            mime="text/csv",
        )