# This is a project of Chakra Lab LLC. All rights reserved.
import gradio as gr
import os
import time
import torch
import torch.nn.functional as F
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#base_model_name = "google/gemma-2b"
#base_model_name = "google/gemma-7b"
base_model_name = "google/gemma-2-9b"
#adapter_model_name = "samidh/cope-g2b-2c-hs.s1.5fpc.9-sx.s1.5.9o-vl.s1.5.9-hr.s5-sh.s5.l5e5-e3-d25-r8"
#adapter_model_name = "samidh/cope-g7bq-2c-hs.s1.5fpc.9-sx.s1.5.9o-VL.s1.5.9-HR.s5-SH.s5-l5e5-e3-d25-r8"
#adapter_model_name = "cope-project/cope-g2.9b-2c-hs.s1.5fpc.9-sx.s1.5.9o-vl.s1.5.9-hr.s5-sh.s5.l5e5-e3-d0-r8"
#adapter_model_name = "samidh/cope-g2.9b-2c-hs.at-sx.at-vl.s5-sh.s5-hr.s5-e2"
#adapter_model_name = "samidh/cope.a.r09.2"
#adapter_model_name = "samidh/cope.a.r09.4.no.vl"
#adapter_model_name = "samidh/cope.a.r09.3"
adapter_model_name = "samidh/cope.a.r09.2"
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

# Alternative 4-bit (NF4) quantization config, kept for reference but unused:
"""
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)
"""
model = AutoModelForCausalLM.from_pretrained(base_model_name,
                                             token=os.environ['HF_TOKEN'],
                                             quantization_config=bnb_config,
                                             device_map="auto")
model = PeftModel.from_pretrained(model, adapter_model_name, token=os.environ['HF_TOKEN'])
model = model.merge_and_unload()  # fold the LoRA adapter weights into the base model
# device_map="auto" has already placed the model; .to() is unsupported for
# 8-bit bitsandbytes models, so no explicit device move is needed here.
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
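# Note (assumption worth stating): Gemma's tokenizer ships with a dedicated
# <pad> token, which the batched path below relies on when padding
# variable-length inputs.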
#gr.load("models/samidh/cope-gemma2b-hs-2c-skr-s1.5.9.d25", hf_token=os.environ['HF_TOKEN']).launch()
PROMPT = """
INSTRUCTIONS
============
Examine the given POLICY and determine if the given CONTENT meets the criteria for ANY of the LABELS. Answer "1" if yes, and "0" if no.
POLICY
======
{policy}
CONTENT
=======
{content}
ANSWER
======
"""
DEFAULT_POLICY = """# Policy
## Overview
This policy is designed to determine whether or not content is hate speech.
## Definition of Terms
- Protected Classes: Individuals or groups distinguished by race (e.g. Black, White), ethnicity (e.g. Hispanic), nationality (e.g. Canadian, Argentinian), caste (e.g. Brahmins, Kshatriyas, Vaishyas), immigration status (e.g. Immigrants), religion (e.g. Catholics, Jews), disability (e.g. blind people, deaf people), disease status (e.g. people with AIDS), sexual orientation (e.g. gays, bisexuals), transgender status (e.g. trans people, cis people), gender identity (e.g. women, men), gender expression (e.g. masculine, feminine), or sex (e.g. male, female).
- Derogatory Slur: Any term or phrase used pejoratively or contemptuously towards a member or members of a protected class, including those historically reclaimed but used in a derogatory context.
- Non-Protected Classes: Categories defined by political party membership (e.g. Democrats, Tories), political beliefs (e.g. fascists, communists), hobbies (e.g. gamers, knitters), criminal history (e.g. criminals, convicts), or profession (e.g. doctors, farmers).
## Interpretation of Language
- Terminology or phrasing that can be construed in various ways IS hate speech UNLESS context clearly indicates it unambiguously IS NOT hate speech. So, for example, attacks on a religion as an institution (e.g. "Hinduism is a garbage religion") SHOULD be understood as attacks on the followers of that religion.
## Definition of Labels
### (HS): Hate Speech
#### Includes
- Direct and Indirect Violence Advocacy: Any form of expression that celebrates, justifies, incites, or threatens physical or psychological harm against individuals or groups identified by their membership in a protected class.
- Dehumanization and Vilification: Content suggesting that individuals or groups, explicitly identified by their membership in a specific protected class, are subhuman, inherently lesser beings, or using animalistic descriptions in a way that promotes disdain or hate.
- Derogatory and Dehumanizing Language: Use of slurs, epithets, or any derogatory language aimed at belittling, humiliating, or inciting hatred against individuals or groups explicitly identified by their membership in a specific protected class.
- Explicit and Implicit Discrimination Advocacy: Promoting exclusion, segregation, or denial of rights against individuals or groups explicitly identified by their membership in a specific protected class.
- Collective Attribution of Negative Actions: Assigning collective blame or advocating collective punishment based on the actions or perceived characteristics of individuals or groups identified by their membership in a protected class.
- Inferiority and Superiority Claims: Statements that categorically assign inferiority or superiority, moral or intellectual, to individuals or groups identified by their membership in a protected class.
- Denial or Distortion of Historical Atrocities: Denying, grossly trivializing, or distorting documented atrocities against groups identified by their membership in a protected class, undermining their significance or the suffering of their members.
- Conspiracy Theories: Propagating unfounded allegations that individuals or groups, identified by their membership in a protected class, are responsible for serious harms or controlling significant institutions to the detriment of society.
#### Excludes
- Attacks on Non-Protected Classes: Content that attacks or criticizes individuals or groups identified by their membership in a Non-Protected Class, EVEN if that attack is violent, threatening, or otherwise hateful (e.g. "Criminals should all be rounded up and shot!").
- Criticism of Beliefs and Institutions: Constructive critique or discussion of political ideologies, religious doctrines, or institutions without resorting to hate speech or targeting individuals or groups identified by their membership in a protected class.
- Neutrally Reporting Historical Events: Neutral and descriptive reporting on, or discussion of, factual events in the past that could be construed as negative about individuals or groups identified by their membership in a protected class.
- Pushing Back on Hateful Language: Content where the writer pushes back on, condemns, questions, criticizes, or mocks a different person's hateful language or ideas.
- Artistic and Educational Content: Content with legitimate artistic, educational, or documentary value that discusses or portrays issues related to hate speech in a context clearly aimed at enlightening or informing without promoting hate.
"""
DEFAULT_CONTENT = "Put your content sample here."
# Function to make a single prediction
def predict(content, policy):
    print(f"New request at {time.time()}")
    input_text = PROMPT.format(policy=policy, content=content)
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

    with torch.inference_mode():
        outputs = model(input_ids)

    # Get logits for the last token
    logits = outputs.logits[:, -1, :]

    # Earlier full-vocabulary approach, kept for reference:
    """
    # Apply softmax to get probabilities
    probabilities = F.softmax(logits, dim=-1)

    # Get the predicted token ID
    predicted_token_id = torch.argmax(logits, dim=-1).item()

    # Decode the predicted token
    decoded_output = tokenizer.decode([predicted_token_id])

    # Get the probability of the predicted token
    predicted_prob = probabilities[0, predicted_token_id].item()

    # Function to get probability for a specific token
    def get_token_probability(token):
        token_id = tokenizer.encode(token, add_special_tokens=False)[0]
        return probabilities[0, token_id].item()

    predicted_prob_0 = get_token_probability('0')
    predicted_prob_1 = get_token_probability('1')
    """

    # Get token IDs for "0" and "1"
    token_id_0 = tokenizer.encode("0", add_special_tokens=False)[0]
    token_id_1 = tokenizer.encode("1", add_special_tokens=False)[0]

    # Extract logits for "0" and "1"
    binary_logits = logits[:, [token_id_0, token_id_1]]

    # Apply softmax to get probabilities for these two tokens
    probabilities = F.softmax(binary_logits, dim=-1)
    predicted_prob_0 = probabilities[0, 0].item()
    predicted_prob_1 = probabilities[0, 1].item()

    decoded_output = '0' if predicted_prob_0 > predicted_prob_1 else '1'

    if decoded_output == '1':
        return f'VIOLATING (P: {predicted_prob_1:.2f})'
    else:
        return f'NON-VIOLATING (P: {predicted_prob_0:.2f})'
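# Hedged usage sketch: calling predict() directly (commented out so it does not
# run at import time; the probability shown is illustrative, not real output).
# label = predict("Put your content sample here.", DEFAULT_POLICY)
# print(label)  # e.g. 'NON-VIOLATING (P: 0.97)'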
# Function to make predictions in batches
def predict_batch(contents, policies):
    print(f"New batch request at {time.time()}")
    input_texts = [
        PROMPT.format(policy=policy, content=content)
        for content, policy in zip(contents, policies)
    ]

    # Left-pad so that logits[:, -1, :] is the final real token of every
    # sequence, and keep the attention mask so padding is ignored.
    tokenizer.padding_side = "left"
    inputs = tokenizer(
        input_texts,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(model.device)

    with torch.inference_mode():
        outputs = model(**inputs)

    # Get logits for the last token of each sequence
    logits = outputs.logits[:, -1, :]

    # Get token IDs for "0" and "1"
    token_id_0 = tokenizer.encode("0", add_special_tokens=False)[0]
    token_id_1 = tokenizer.encode("1", add_special_tokens=False)[0]

    # Extract logits for "0" and "1"
    binary_logits = logits[:, [token_id_0, token_id_1]]

    # Apply softmax to get probabilities for these two tokens
    probabilities = F.softmax(binary_logits, dim=-1)
    probs_0 = probabilities[:, 0].cpu().numpy()
    probs_1 = probabilities[:, 1].cpu().numpy()

    results = []
    for prob_0, prob_1 in zip(probs_0, probs_1):
        if prob_1 > prob_0:
            output = f'VIOLATING (P: {prob_1:.2f})'
        else:
            output = f'NON-VIOLATING (P: {prob_0:.2f})'
        results.append(output)
    print(results)

    # Gradio batch functions return one list per output component
    return [results]
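# Sketch of Gradio's batch=True contract as used here (example inputs assumed):
# concurrent single requests are gathered into parallel lists, so a call looks like
#   predict_batch(["post a", "post b"], [DEFAULT_POLICY, DEFAULT_POLICY])
# and must return one list per output component, i.e. [[label_a, label_b]].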
# Create Gradio interface
iface = gr.Interface(
    fn=predict_batch,
    inputs=[gr.Textbox(label="Content", lines=2, value=DEFAULT_CONTENT),
            gr.Textbox(label="Policy", lines=10, value=DEFAULT_POLICY)],
    outputs=[gr.Textbox(label="Result")],
    batch=True,
    max_batch_size=4,
    title="CoPE Dev (Unstable)",
    description="See if the given content violates your given policy."
)

#iface.queue(
#    default_concurrency_limit=2
#)

# Launch the app
iface.launch()