Akash190104 committed
Commit aa0eeb0 · 1 Parent(s): 360a42b

fixing the login error

Files changed (1): app.py (+120, -86)
app.py CHANGED
@@ -4,109 +4,143 @@ import pandas as pd
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
+from huggingface_hub import login, whoami
+
+# Title and description
+st.title("Space Turtle 101 Demo")
+st.markdown(
+    """
+    This demo generates adversarial prompts based on a bias category and country/region.
+    The model (meta-llama/Llama-3.2-1B-Instruct) is gated. Please enter your Hugging Face API token below to access it.
+    """
+)

 # -------------------------------
-# Cache the model to load only once
+# Sidebar: Hugging Face API Token Input & Auto Login
 # -------------------------------
-@st.cache_resource(show_spinner=True)
-def load_model():
-    # Determine device
+hf_token = st.sidebar.text_input("Enter your Hugging Face API Token", type="password")
+if hf_token:
+    try:
+        login(token=hf_token)
+        user_info = whoami()
+        st.sidebar.success(f"Logged in as: {user_info['name']}")
+    except Exception as e:
+        st.sidebar.error(f"Login failed: {e}")
+        hf_token = None
+
+# -------------------------------
+# Device Selection: CUDA > MPS > CPU
+# -------------------------------
+def get_device():
     if torch.cuda.is_available():
-        device = "cuda"
+        return "cuda"
     elif torch.backends.mps.is_available():
-        device = "mps"
+        return "mps"
     else:
-        device = "cpu"
-
-    # Load the base model without offloading to avoid meta parameters issues
+        return "cpu"
+
+# -------------------------------
+# Function: Load Model with Caching
+# -------------------------------
+@st.cache_resource(show_spinner=True)
+def load_model(hf_token):
+    device = get_device()
+    # Load the gated base model with your token
     base_model = AutoModelForCausalLM.from_pretrained(
         "meta-llama/Llama-3.2-1B-Instruct",
         trust_remote_code=True,
-        torch_dtype=torch.float16
+        torch_dtype=torch.float16,
+        use_auth_token=hf_token
+    )
+    # Load the tokenizer from your adapter repo and set the pad token
+    tokenizer = AutoTokenizer.from_pretrained(
+        "Akash190104/space_turtle_101",
+        use_fast=False,
+        use_auth_token=hf_token
     )
-
-    # Load tokenizer from your adapter repo and set pad token if needed
-    tokenizer = AutoTokenizer.from_pretrained("Akash190104/space_turtle_101", use_fast=False)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

-    # Load the LoRA adapter
-    model = PeftModel.from_pretrained(base_model, "Akash190104/space_turtle_101")
+    # Load the LoRA adapter with the token
+    model = PeftModel.from_pretrained(
+        base_model,
+        "Akash190104/space_turtle_101",
+        use_auth_token=hf_token
+    )
     model.to(device)
-
     return model, tokenizer, device

-with st.spinner("Loading model, please wait..."):
-    model, tokenizer, device = load_model()
-
-# -------------------------------
-# Define helper function for generation
-# -------------------------------
-def generate_sample(prompt_text):
-    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
-    with torch.inference_mode():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=150,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9
-        )
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return generated_text
-
-# -------------------------------
-# Define bias and country lists
-# -------------------------------
-biases = [
-    "Socio-economic bias", "Gender bias", "Caste bias", "Religious bias", "Ethnic bias", "Age bias",
-    "Regional stereotyping", "Language discrimination", "Political bias", "Educational bias",
-    "Occupational bias", "Disability bias", "Appearance-based bias", "Colorism", "Nationality bias",
-    "Urban-rural bias", "Immigration bias"
-]
+if not hf_token:
+    st.warning("Please enter your Hugging Face API Token in the sidebar to load the model.")
+else:
+    with st.spinner("Loading model, please wait..."):
+        model, tokenizer, device = load_model(hf_token)
+    st.success("Model loaded successfully!")

-countries = [
-    "China", "India", "Philippines", "Vietnam", "Southeast Asia", "Europe", "Nigeria", "United States",
-    "Mexico", "Canada", "Germany", "France", "Brazil", "South Africa", "Russia", "Japan", "South Korea",
-    "Australia", "Middle East", "Latin America", "Eastern Europe", "Bangladesh", "Pakistan", "Indonesia",
-    "Turkey", "Egypt", "Kenya", "Argentina"
-]
+# -------------------------------
+# Generation helper function
+# -------------------------------
+def generate_sample(prompt_text):
+    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
+    with torch.inference_mode():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=150,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9
+        )
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return generated_text

-# -------------------------------
-# Streamlit UI
-# -------------------------------
-st.title("Space Turtle 101 Demo")
-st.markdown("Generate adversarial prompts based on bias category and country/region.")
+# -------------------------------
+# Define bias and country lists
+# -------------------------------
+biases = [
+    "Socio-economic bias", "Gender bias", "Caste bias", "Religious bias", "Ethnic bias", "Age bias",
+    "Regional stereotyping", "Language discrimination", "Political bias", "Educational bias",
+    "Occupational bias", "Disability bias", "Appearance-based bias", "Colorism", "Nationality bias",
+    "Urban-rural bias", "Immigration bias"
+]
+countries = [
+    "China", "India", "Philippines", "Vietnam", "Southeast Asia", "Europe", "Nigeria", "United States",
+    "Mexico", "Canada", "Germany", "France", "Brazil", "South Africa", "Russia", "Japan", "South Korea",
+    "Australia", "Middle East", "Latin America", "Eastern Europe", "Bangladesh", "Pakistan", "Indonesia",
+    "Turkey", "Egypt", "Kenya", "Argentina"
+]

-mode = st.radio("Select Mode", ("Interactive", "Random Generation (10 samples)"))
+# -------------------------------
+# Streamlit UI: Demo Modes
+# -------------------------------
+mode = st.radio("Select Mode", ("Interactive", "Random Generation (10 samples)"))

-if mode == "Interactive":
-    st.subheader("Interactive Mode")
-    bias_input = st.text_input("Bias Category", "")
-    country_input = st.text_input("Country/Region", "")
-    if st.button("Generate Sample"):
-        if bias_input.strip() == "" or country_input.strip() == "":
-            st.error("Please provide both a bias category and a country/region.")
-        else:
-            prompt = f"```{bias_input} in {country_input}```\n"
-            generated = generate_sample(prompt)
-            st.markdown("**Generated Output:**")
-            st.text_area("", value=generated, height=200)
-            st.download_button("Download Output", generated, file_name="output.txt")
-elif mode == "Random Generation (10 samples)":
-    st.subheader("Random Generation Mode")
-    if st.button("Generate 10 Random Samples"):
-        results = []
-        for _ in range(10):
-            bias = random.choice(biases)
-            country = random.choice(countries)
-            prompt = f"```{bias} in {country}```\n"
-            generated = generate_sample(prompt)
-            results.append({"prompt": prompt, "generated": generated})
-        for i, res in enumerate(results):
-            st.markdown(f"**Sample {i+1}:**")
-            st.text_area("Prompt", value=res["prompt"], height=50)
-            st.text_area("Output", value=res["generated"], height=150)
-        df = pd.DataFrame(results)
-        csv = df.to_csv(index=False).encode("utf-8")
-        st.download_button("Download All Samples (CSV)", csv, file_name="samples.csv", mime="text/csv")
+if mode == "Interactive":
+    st.subheader("Interactive Mode")
+    bias_input = st.text_input("Bias Category", "")
+    country_input = st.text_input("Country/Region", "")
+    if st.button("Generate Sample"):
+        if bias_input.strip() == "" or country_input.strip() == "":
+            st.error("Please provide both a bias category and a country/region.")
+        else:
+            prompt = f"```{bias_input} in {country_input}```\n"
+            generated = generate_sample(prompt)
+            st.markdown("**Generated Output:**")
+            st.text_area("", value=generated, height=200)
+            st.download_button("Download Output", generated, file_name="output.txt")
+elif mode == "Random Generation (10 samples)":
+    st.subheader("Random Generation Mode")
+    if st.button("Generate 10 Random Samples"):
+        results = []
+        for _ in range(10):
+            bias = random.choice(biases)
+            country = random.choice(countries)
+            prompt = f"```{bias} in {country}```\n"
+            generated = generate_sample(prompt)
+            results.append({"prompt": prompt, "generated": generated})
+        for i, res in enumerate(results):
+            st.markdown(f"**Sample {i+1}:**")
+            st.text_area("Prompt", value=res["prompt"], height=50)
+            st.text_area("Output", value=res["generated"], height=150)
+        df = pd.DataFrame(results)
+        csv = df.to_csv(index=False).encode("utf-8")
+        st.download_button("Download All Samples (CSV)", csv, file_name="samples.csv", mime="text/csv")