Akash190104 committed
Commit 3054b3b · 1 Parent(s): aa0eeb0

test the login error

Files changed (1):
  app.py (+88 -73)
app.py CHANGED
@@ -6,19 +6,26 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
 from huggingface_hub import login, whoami
 
-# Title and description
 st.title("Space Turtle 101 Demo")
 st.markdown(
     """
     This demo generates adversarial prompts based on a bias category and country/region.
-    The model (meta-llama/Llama-3.2-1B-Instruct) is gated. Please enter your Hugging Face API token below to access it.
+    The base model (meta-llama/Llama-3.2-1B-Instruct) is gated.
     """
 )
 
 # -------------------------------
-# Sidebar: Hugging Face API Token Input & Auto Login
+# Retrieve HF Token from secrets or user input
+# -------------------------------
+if "HF_TOKEN" in st.secrets:
+    hf_token = st.secrets["HF_TOKEN"]
+    st.sidebar.info("Using token from secrets.")
+else:
+    hf_token = st.sidebar.text_input("Enter your Hugging Face API Token", type="password")
+
+# -------------------------------
+# Login if token is provided
 # -------------------------------
-hf_token = st.sidebar.text_input("Enter your Hugging Face API Token", type="password")
 if hf_token:
     try:
         login(token=hf_token)
@@ -27,6 +34,8 @@ if hf_token:
     except Exception as e:
         st.sidebar.error(f"Login failed: {e}")
         hf_token = None
+else:
+    st.sidebar.warning("Please enter your Hugging Face API Token.")
 
 # -------------------------------
 # Device Selection: CUDA > MPS > CPU
@@ -52,7 +61,7 @@ def load_model(hf_token):
         torch_dtype=torch.float16,
         use_auth_token=hf_token
     )
-    # Load the tokenizer from your adapter repo and set the pad token
+    # Load the tokenizer from your adapter repository and set pad token if needed
     tokenizer = AutoTokenizer.from_pretrained(
         "Akash190104/space_turtle_101",
         use_fast=False,
@@ -61,7 +70,7 @@ def load_model(hf_token):
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    # Load the LoRA adapter with the token
+    # Load the LoRA adapter using your token
     model = PeftModel.from_pretrained(
         base_model,
         "Akash190104/space_turtle_101",
@@ -71,76 +80,82 @@ def load_model(hf_token):
     return model, tokenizer, device
 
 if not hf_token:
-    st.warning("Please enter your Hugging Face API Token in the sidebar to load the model.")
+    st.warning("Please enter your Hugging Face API Token to load the model.")
 else:
     with st.spinner("Loading model, please wait..."):
-        model, tokenizer, device = load_model(hf_token)
-        st.success("Model loaded successfully!")
+        try:
+            model, tokenizer, device = load_model(hf_token)
+            st.success("Model loaded successfully!")
+        except Exception as e:
+            st.error(f"Model loading failed: {e}")
+            st.error("Ensure your token has access to meta-llama/Llama-3.2-1B-Instruct.")
+            st.stop()
 
-# -------------------------------
-# Generation helper function
-# -------------------------------
-def generate_sample(prompt_text):
-    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
-    with torch.inference_mode():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=150,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9
-        )
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return generated_text
-
-# -------------------------------
-# Define bias and country lists
-# -------------------------------
-biases = [
-    "Socio-economic bias", "Gender bias", "Caste bias", "Religious bias", "Ethnic bias", "Age bias",
-    "Regional stereotyping", "Language discrimination", "Political bias", "Educational bias",
-    "Occupational bias", "Disability bias", "Appearance-based bias", "Colorism", "Nationality bias",
-    "Urban-rural bias", "Immigration bias"
-]
-countries = [
-    "China", "India", "Philippines", "Vietnam", "Southeast Asia", "Europe", "Nigeria", "United States",
-    "Mexico", "Canada", "Germany", "France", "Brazil", "South Africa", "Russia", "Japan", "South Korea",
-    "Australia", "Middle East", "Latin America", "Eastern Europe", "Bangladesh", "Pakistan", "Indonesia",
-    "Turkey", "Egypt", "Kenya", "Argentina"
-]
-
-# -------------------------------
-# Streamlit UI: Demo Modes
-# -------------------------------
-mode = st.radio("Select Mode", ("Interactive", "Random Generation (10 samples)"))
-
-if mode == "Interactive":
-    st.subheader("Interactive Mode")
-    bias_input = st.text_input("Bias Category", "")
-    country_input = st.text_input("Country/Region", "")
-    if st.button("Generate Sample"):
-        if bias_input.strip() == "" or country_input.strip() == "":
-            st.error("Please provide both a bias category and a country/region.")
-        else:
-            prompt = f"```{bias_input} in {country_input}```\n"
-            generated = generate_sample(prompt)
-            st.markdown("**Generated Output:**")
-            st.text_area("", value=generated, height=200)
-            st.download_button("Download Output", generated, file_name="output.txt")
-elif mode == "Random Generation (10 samples)":
-    st.subheader("Random Generation Mode")
-    if st.button("Generate 10 Random Samples"):
-        results = []
-        for _ in range(10):
-            bias = random.choice(biases)
-            country = random.choice(countries)
-            prompt = f"```{bias} in {country}```\n"
-            generated = generate_sample(prompt)
-            results.append({"prompt": prompt, "generated": generated})
-        for i, res in enumerate(results):
-            st.markdown(f"**Sample {i+1}:**")
-            st.text_area("Prompt", value=res["prompt"], height=50)
-            st.text_area("Output", value=res["generated"], height=150)
-        df = pd.DataFrame(results)
-        csv = df.to_csv(index=False).encode("utf-8")
-        st.download_button("Download All Samples (CSV)", csv, file_name="samples.csv", mime="text/csv")
+# -------------------------------
+# Generation helper function
+# -------------------------------
+def generate_sample(prompt_text):
+    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
+    with torch.inference_mode():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=150,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9
+        )
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return generated_text
+
+# -------------------------------
+# Define bias and country lists
+# -------------------------------
+biases = [
+    "Socio-economic bias", "Gender bias", "Caste bias", "Religious bias", "Ethnic bias", "Age bias",
+    "Regional stereotyping", "Language discrimination", "Political bias", "Educational bias",
+    "Occupational bias", "Disability bias", "Appearance-based bias", "Colorism", "Nationality bias",
+    "Urban-rural bias", "Immigration bias"
+]
+
+countries = [
+    "China", "India", "Philippines", "Vietnam", "Southeast Asia", "Europe", "Nigeria", "United States",
+    "Mexico", "Canada", "Germany", "France", "Brazil", "South Africa", "Russia", "Japan", "South Korea",
+    "Australia", "Middle East", "Latin America", "Eastern Europe", "Bangladesh", "Pakistan", "Indonesia",
+    "Turkey", "Egypt", "Kenya", "Argentina"
+]
+
+# -------------------------------
+# Streamlit UI: Demo Modes
+# -------------------------------
+mode = st.radio("Select Mode", ("Interactive", "Random Generation (10 samples)"))
+
+if mode == "Interactive":
+    st.subheader("Interactive Mode")
+    bias_input = st.text_input("Bias Category", "")
+    country_input = st.text_input("Country/Region", "")
+    if st.button("Generate Sample"):
+        if bias_input.strip() == "" or country_input.strip() == "":
+            st.error("Please provide both a bias category and a country/region.")
+        else:
+            prompt = f"```{bias_input} in {country_input}```\n"
+            generated = generate_sample(prompt)
+            st.markdown("**Generated Output:**")
+            st.text_area("", value=generated, height=200)
+            st.download_button("Download Output", generated, file_name="output.txt")
+elif mode == "Random Generation (10 samples)":
+    st.subheader("Random Generation Mode")
+    if st.button("Generate 10 Random Samples"):
+        results = []
+        for _ in range(10):
+            bias = random.choice(biases)
+            country = random.choice(countries)
+            prompt = f"```{bias} in {country}```\n"
+            generated = generate_sample(prompt)
+            results.append({"prompt": prompt, "generated": generated})
+        for i, res in enumerate(results):
+            st.markdown(f"**Sample {i+1}:**")
+            st.text_area("Prompt", value=res["prompt"], height=50)
+            st.text_area("Output", value=res["generated"], height=150)
+        df = pd.DataFrame(results)
+        csv = df.to_csv(index=False).encode("utf-8")
+        st.download_button("Download All Samples (CSV)", csv, file_name="samples.csv", mime="text/csv")
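
The behavior this commit tests is the guarded login/load path: a missing, malformed, or unauthorized token should surface as an error in the UI instead of crashing the Space. The failure path can be exercised outside Streamlit with a short script; the sketch below is a hypothetical standalone check, not part of the commit, assuming only huggingface_hub (the exact exception class raised on a bad token varies across versions, so the broad except mirrors app.py's own handling).

    # Hypothetical standalone check: exercises the same login failure
    # that app.py now reports via st.sidebar.error.
    from huggingface_hub import login, whoami

    def check_token(hf_token: str) -> bool:
        try:
            login(token=hf_token)          # raises if the token is malformed or rejected
            user = whoami(token=hf_token)  # resolves the token to an account
            print(f"Authenticated as {user['name']}")
            return True
        except Exception as e:             # broad except, matching app.py's handling
            print(f"Login failed: {e}")
            return False

    if __name__ == "__main__":
        check_token("hf_not_a_real_token")  # placeholder invalid token

Run with an invalid token, this prints the same "Login failed: ..." message the sidebar shows; with a valid token that lacks access to the gated meta-llama/Llama-3.2-1B-Instruct repo, login succeeds but the model load fails, which is why the commit also wraps load_model in try/except and calls st.stop().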