Commit 360a42b
1 Parent(s): d58d5d0

first commit

Files changed:
- app.py: +112 -0
- requirements.txt: +5 -0
app.py ADDED
@@ -0,0 +1,112 @@
+import streamlit as st
+import random
+import pandas as pd
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
+
+# -------------------------------
+# Cache the model so it loads only once
+# -------------------------------
+@st.cache_resource(show_spinner=True)
+def load_model():
+    # Determine the device
+    if torch.cuda.is_available():
+        device = "cuda"
+    elif torch.backends.mps.is_available():
+        device = "mps"
+    else:
+        device = "cpu"
+
+    # Load the base model without offloading to avoid meta-parameter issues
+    base_model = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Llama-3.2-1B-Instruct",
+        trust_remote_code=True,
+        torch_dtype=torch.float16
+    )
+
+    # Load the tokenizer from the adapter repo and set a pad token if needed
+    tokenizer = AutoTokenizer.from_pretrained("Akash190104/space_turtle_101", use_fast=False)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Load the LoRA adapter
+    model = PeftModel.from_pretrained(base_model, "Akash190104/space_turtle_101")
+    model.to(device)
+
+    return model, tokenizer, device
+
+with st.spinner("Loading model, please wait..."):
+    model, tokenizer, device = load_model()
+
+# -------------------------------
+# Define helper function for generation
+# -------------------------------
+def generate_sample(prompt_text):
+    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
+    with torch.inference_mode():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=150,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9
+        )
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return generated_text
+
+# -------------------------------
+# Define bias and country lists
+# -------------------------------
+biases = [
+    "Socio-economic bias", "Gender bias", "Caste bias", "Religious bias", "Ethnic bias", "Age bias",
+    "Regional stereotyping", "Language discrimination", "Political bias", "Educational bias",
+    "Occupational bias", "Disability bias", "Appearance-based bias", "Colorism", "Nationality bias",
+    "Urban-rural bias", "Immigration bias"
+]
+
+countries = [
+    "China", "India", "Philippines", "Vietnam", "Southeast Asia", "Europe", "Nigeria", "United States",
+    "Mexico", "Canada", "Germany", "France", "Brazil", "South Africa", "Russia", "Japan", "South Korea",
+    "Australia", "Middle East", "Latin America", "Eastern Europe", "Bangladesh", "Pakistan", "Indonesia",
+    "Turkey", "Egypt", "Kenya", "Argentina"
+]
+
+# -------------------------------
+# Streamlit UI
+# -------------------------------
+st.title("Space Turtle 101 Demo")
+st.markdown("Generate adversarial prompts based on bias category and country/region.")
+
+mode = st.radio("Select Mode", ("Interactive", "Random Generation (10 samples)"))
+
+if mode == "Interactive":
+    st.subheader("Interactive Mode")
+    bias_input = st.text_input("Bias Category", "")
+    country_input = st.text_input("Country/Region", "")
+    if st.button("Generate Sample"):
+        if bias_input.strip() == "" or country_input.strip() == "":
+            st.error("Please provide both a bias category and a country/region.")
+        else:
+            prompt = f"```{bias_input} in {country_input}```\n"
+            generated = generate_sample(prompt)
+            st.markdown("**Generated Output:**")
+            st.text_area("Generated Output", value=generated, height=200, label_visibility="collapsed")  # non-empty label (hidden) avoids Streamlit's empty-label warning
+            st.download_button("Download Output", generated, file_name="output.txt")
+elif mode == "Random Generation (10 samples)":
+    st.subheader("Random Generation Mode")
+    if st.button("Generate 10 Random Samples"):
+        results = []
+        for _ in range(10):
+            bias = random.choice(biases)
+            country = random.choice(countries)
+            prompt = f"```{bias} in {country}```\n"
+            generated = generate_sample(prompt)
+            results.append({"prompt": prompt, "generated": generated})
+        for i, res in enumerate(results):
+            st.markdown(f"**Sample {i+1}:**")
+            st.text_area("Prompt", value=res["prompt"], height=68, key=f"prompt_{i}")  # unique key avoids duplicate-widget errors; 68px is Streamlit's minimum text-area height
+            st.text_area("Output", value=res["generated"], height=150, key=f"output_{i}")
+        df = pd.DataFrame(results)
+        csv = df.to_csv(index=False).encode("utf-8")
+        st.download_button("Download All Samples (CSV)", csv, file_name="samples.csv", mime="text/csv")
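
For reference, a minimal sketch of exercising the same adapter outside Streamlit. It assumes access to the gated meta-llama/Llama-3.2-1B-Instruct base model and the Akash190104/space_turtle_101 adapter repo; the example prompt simply fills the app's backtick template with one bias/country pair from its own lists, and float32 is used here so the sketch also runs on CPU (app.py uses float16):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Base model + LoRA adapter, mirroring load_model() in app.py
# (device placement omitted; this runs on CPU by default).
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.float32
)
tok = AutoTokenizer.from_pretrained("Akash190104/space_turtle_101", use_fast=False)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model = PeftModel.from_pretrained(base, "Akash190104/space_turtle_101")

# Same backtick-delimited prompt template the app builds.
prompt = "```Gender bias in India```\n"
inputs = tok(prompt, return_tensors="pt", padding=True)
with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=150, do_sample=True,
                         temperature=0.7, top_p=0.9)
print(tok.decode(out[0], skip_special_tokens=True))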
requirements.txt ADDED
@@ -0,0 +1,5 @@
+streamlit
+torch
+transformers
+peft
+pandas
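
To try this commit locally, the standard Streamlit workflow should apply: pip install -r requirements.txt, then streamlit run app.py. Note that the meta-llama base model is gated on the Hugging Face Hub, so downloading it requires a token from an account with approved access.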