Akash190104 committed
Commit 360a42b · 1 Parent(s): d58d5d0

first commit

Files changed (2)
  1. app.py +112 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,112 @@
+ import streamlit as st
+ import random
+ import pandas as pd
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+
+ # -------------------------------
+ # Cache the model so it loads only once
+ # -------------------------------
+ @st.cache_resource(show_spinner=True)
+ def load_model():
+     # Determine the best available device
+     if torch.cuda.is_available():
+         device = "cuda"
+     elif torch.backends.mps.is_available():
+         device = "mps"
+     else:
+         device = "cpu"
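+     # Note: float16 weights can be slow or partially unsupported on CPU;
+     # loading in float32 when device == "cpu" may be a safer fallback.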
+
+     # Load the base model without offloading, to avoid meta-parameter issues
+     base_model = AutoModelForCausalLM.from_pretrained(
+         "meta-llama/Llama-3.2-1B-Instruct",
+         trust_remote_code=True,
+         torch_dtype=torch.float16
+     )
+
+     # Load the tokenizer from the adapter repo and set a pad token if needed
+     tokenizer = AutoTokenizer.from_pretrained("Akash190104/space_turtle_101", use_fast=False)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Load the LoRA adapter on top of the base model
+     model = PeftModel.from_pretrained(base_model, "Akash190104/space_turtle_101")
+     model.to(device)
+
+     return model, tokenizer, device
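+
+ # Assumption: "meta-llama/Llama-3.2-1B-Instruct" is a gated repo on the
+ # Hugging Face Hub, so the download above likely requires a valid access
+ # token (e.g. an HF_TOKEN secret configured for this Space).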
+
+ with st.spinner("Loading model, please wait..."):
+     model, tokenizer, device = load_model()
+
+ # -------------------------------
+ # Define helper function for generation
+ # -------------------------------
+ def generate_sample(prompt_text):
+     inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
+     with torch.inference_mode():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=150,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9,
+             pad_token_id=tokenizer.pad_token_id  # avoid the missing-pad-token warning at generation time
+         )
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
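+     # Note: generate() on a decoder-only model returns the prompt followed by
+     # the completion; slicing outputs[0][inputs["input_ids"].shape[-1]:] before
+     # decoding would keep only the newly generated text.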
+     return generated_text
+
+ # -------------------------------
+ # Define bias and country lists
+ # -------------------------------
+ biases = [
+     "Socio-economic bias", "Gender bias", "Caste bias", "Religious bias", "Ethnic bias", "Age bias",
+     "Regional stereotyping", "Language discrimination", "Political bias", "Educational bias",
+     "Occupational bias", "Disability bias", "Appearance-based bias", "Colorism", "Nationality bias",
+     "Urban-rural bias", "Immigration bias"
+ ]
+
+ countries = [
+     "China", "India", "Philippines", "Vietnam", "Southeast Asia", "Europe", "Nigeria", "United States",
+     "Mexico", "Canada", "Germany", "France", "Brazil", "South Africa", "Russia", "Japan", "South Korea",
+     "Australia", "Middle East", "Latin America", "Eastern Europe", "Bangladesh", "Pakistan", "Indonesia",
+     "Turkey", "Egypt", "Kenya", "Argentina"
+ ]
+
+ # -------------------------------
+ # Streamlit UI
+ # -------------------------------
+ st.title("Space Turtle 101 Demo")
+ st.markdown("Generate adversarial prompts based on bias category and country/region.")
+
+ mode = st.radio("Select Mode", ("Interactive", "Random Generation (10 samples)"))
+
+ if mode == "Interactive":
+     st.subheader("Interactive Mode")
+     bias_input = st.text_input("Bias Category", "")
+     country_input = st.text_input("Country/Region", "")
+     if st.button("Generate Sample"):
+         if bias_input.strip() == "" or country_input.strip() == "":
+             st.error("Please provide both a bias category and a country/region.")
+         else:
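+             # Assumption: the triple-backtick "{bias} in {country}" wrapper
+             # mirrors the prompt format the space_turtle_101 adapter was
+             # fine-tuned on.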
+             prompt = f"```{bias_input} in {country_input}```\n"
+             generated = generate_sample(prompt)
+             st.markdown("**Generated Output:**")
+             st.text_area("Generated Output", value=generated, height=200,
+                          label_visibility="collapsed")  # Streamlit warns on empty labels; hide the label instead
+             st.download_button("Download Output", generated, file_name="output.txt")
+ elif mode == "Random Generation (10 samples)":
+     st.subheader("Random Generation Mode")
+     if st.button("Generate 10 Random Samples"):
+         results = []
+         for _ in range(10):
+             bias = random.choice(biases)
+             country = random.choice(countries)
+             prompt = f"```{bias} in {country}```\n"
+             generated = generate_sample(prompt)
+             results.append({"prompt": prompt, "generated": generated})
+         for i, res in enumerate(results):
+             st.markdown(f"**Sample {i+1}:**")
+             st.text_area("Prompt", value=res["prompt"], height=68, key=f"prompt_{i}")  # unique keys prevent DuplicateWidgetID errors; 68px is Streamlit's minimum text_area height
+             st.text_area("Output", value=res["generated"], height=150, key=f"output_{i}")
+         df = pd.DataFrame(results)
+         csv = df.to_csv(index=False).encode("utf-8")
+         st.download_button("Download All Samples (CSV)", csv, file_name="samples.csv", mime="text/csv")
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ streamlit
+ torch
+ transformers
+ peft
+ pandas
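
To try this commit locally, the standard Streamlit workflow should apply: `pip install -r requirements.txt`, then `streamlit run app.py`. Note that all five requirements are unpinned, so a fresh install will resolve to the latest version of each package.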