# space_turtle/app.py
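"""Streamlit demo for the Space Turtle 101 LoRA adapter.

Generates adversarial prompts conditioned on a bias category and a
country/region, either one at a time or as a batch of random samples.
"""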
import streamlit as st
import random
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# -------------------------------
# Cache the model so it loads only once
# -------------------------------
@st.cache_resource(show_spinner=True)
def load_model():
    # Pick the best available device: CUDA GPU, Apple Silicon (MPS), or CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # fp16 keeps memory usage down on GPU/MPS; fall back to fp32 on CPU,
    # where half-precision kernels are poorly supported.
    dtype = torch.float16 if device != "cpu" else torch.float32

    # Load the base model without offloading to avoid meta-parameter issues.
    base_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B-Instruct",
        trust_remote_code=True,
        torch_dtype=dtype,
    )

    # Load the tokenizer from the adapter repo and set a pad token if needed.
    tokenizer = AutoTokenizer.from_pretrained("Akash190104/space_turtle_101", use_fast=False)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Attach the LoRA adapter on top of the base model.
    model = PeftModel.from_pretrained(base_model, "Akash190104/space_turtle_101")
    model.to(device)
    model.eval()
    return model, tokenizer, device
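
# Load once at startup; st.cache_resource keeps the model in memory across
# Streamlit reruns, so later interactions reuse it instead of reloading.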
with st.spinner("Loading model, please wait..."):
    model, tokenizer, device = load_model()
# -------------------------------
# Define helper function for generation
# -------------------------------
def generate_sample(prompt_text):
    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,     # sample rather than greedy-decode
            temperature=0.7,    # moderate randomness
            top_p=0.9,          # nucleus sampling
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the newly generated tokens so the prompt is not echoed
    # back in the output.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
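
# Both UI modes wrap "<bias> in <country>" in triple backticks before sending
# it to the model; this appears to be the format the adapter was trained on
# (an assumption inferred from the original prompts, not confirmed upstream).
# Centralizing it here keeps the two modes in sync.
def build_prompt(bias: str, country: str) -> str:
    return f"```{bias} in {country}```\n"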
# -------------------------------
# Define bias and country lists
# -------------------------------
biases = [
"Socio-economic bias", "Gender bias", "Caste bias", "Religious bias", "Ethnic bias", "Age bias",
"Regional stereotyping", "Language discrimination", "Political bias", "Educational bias",
"Occupational bias", "Disability bias", "Appearance-based bias", "Colorism", "Nationality bias",
"Urban-rural bias", "Immigration bias"
]
countries = [
"China", "India", "Philippines", "Vietnam", "Southeast Asia", "Europe", "Nigeria", "United States",
"Mexico", "Canada", "Germany", "France", "Brazil", "South Africa", "Russia", "Japan", "South Korea",
"Australia", "Middle East", "Latin America", "Eastern Europe", "Bangladesh", "Pakistan", "Indonesia",
"Turkey", "Egypt", "Kenya", "Argentina"
]
# -------------------------------
# Streamlit UI
# -------------------------------
st.title("Space Turtle 101 Demo")
st.markdown("Generate adversarial prompts based on bias category and country/region.")
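
# Two modes: interactive (user supplies the bias and region) and batch
# (ten random samples drawn from the lists above).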
mode = st.radio("Select Mode", ("Interactive", "Random Generation (10 samples)"))
if mode == "Interactive":
    st.subheader("Interactive Mode")
    bias_input = st.text_input("Bias Category", "")
    country_input = st.text_input("Country/Region", "")
    if st.button("Generate Sample"):
        if bias_input.strip() == "" or country_input.strip() == "":
            st.error("Please provide both a bias category and a country/region.")
        else:
            prompt = build_prompt(bias_input, country_input)
            generated = generate_sample(prompt)
            st.markdown("**Generated Output:**")
            # A non-empty (hidden) label avoids Streamlit's empty-label warning.
            st.text_area("Generated Output", value=generated, height=200,
                         label_visibility="collapsed")
            st.download_button("Download Output", generated, file_name="output.txt")
elif mode == "Random Generation (10 samples)":
    st.subheader("Random Generation Mode")
    if st.button("Generate 10 Random Samples"):
        results = []
        for _ in range(10):
            bias = random.choice(biases)
            country = random.choice(countries)
            prompt = build_prompt(bias, country)
            generated = generate_sample(prompt)
            results.append({"prompt": prompt, "generated": generated})
        for i, res in enumerate(results):
            st.markdown(f"**Sample {i+1}:**")
            # Unique keys prevent DuplicateWidgetID errors when two samples
            # happen to draw the same bias/country pair; newer Streamlit
            # versions also enforce a 68px minimum text_area height.
            st.text_area("Prompt", value=res["prompt"], height=68, key=f"prompt_{i}")
            st.text_area("Output", value=res["generated"], height=150, key=f"output_{i}")
        df = pd.DataFrame(results)
        csv = df.to_csv(index=False).encode("utf-8")
        st.download_button("Download All Samples (CSV)", csv, file_name="samples.csv", mime="text/csv")
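
# To try this locally (assuming the required packages are installed):
#   pip install streamlit torch transformers peft pandas
#   streamlit run app.py
# Note: meta-llama/Llama-3.2-1B-Instruct is gated on the Hugging Face Hub,
# so an authenticated token with access granted may be required.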