File size: 10,622 Bytes
11321d2
 
3213d99
174f794
60e000e
3213d99
e66c430
8121f7d
36cf11d
ed150f6
d77692a
e66c430
 
ed150f6
3a4cf0d
 
ed150f6
3a4cf0d
77b481a
 
3cc35d6
1024625
3e591de
ea27a0d
 
ed150f6
183ac41
2bd1de7
183ac41
 
 
ed150f6
4318079
 
 
 
ed150f6
183ac41
ed150f6
21cc1f9
 
 
 
 
 
e66c430
21cc1f9
27557d4
36cf11d
d77692a
21cc1f9
 
1eeee19
c3f3656
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b57f019
3e095c1
ba699ef
3e095c1
ba699ef
642f22f
ba699ef
3e095c1
ba699ef
642f22f
 
 
ba699ef
642f22f
ba699ef
642f22f
ba699ef
3e095c1
ba699ef
3e095c1
ba699ef
3e095c1
ba699ef
642f22f
 
 
 
 
 
 
 
ba699ef
3e095c1
 
642f22f
 
 
 
 
3e095c1
ba699ef
642f22f
ba699ef
b57f019
1eeee19
60e000e
 
1eeee19
603cd8f
e66c430
ed150f6
603cd8f
8121f7d
 
 
9e94f83
 
8121f7d
 
 
 
603cd8f
8121f7d
 
 
 
 
 
 
 
 
 
 
0c69600
 
 
9e94f83
 
 
 
 
 
 
 
 
 
 
 
 
 
d441b45
 
8121f7d
 
ae6cec8
093e0f5
ae6cec8
b57f019
1458c5c
 
 
 
 
 
 
 
 
 
 
 
 
c79a5b3
 
1458c5c
 
 
 
 
17d9a5e
 
1458c5c
17d9a5e
 
 
1458c5c
17d9a5e
 
1458c5c
17d9a5e
 
1458c5c
17d9a5e
 
1458c5c
17d9a5e
 
 
ae6cec8
17d9a5e
ae6cec8
17d9a5e
36a2dcc
17d9a5e
ae6cec8
1458c5c
b57f019
 
c250125
5efe6cf
 
ae6cec8
21cc1f9
 
9011868
1eeee19
b57f019
 
af71289
 
 
a208505
b57f019
d3996e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# This is a project of Chakra Lab LLC. All rights reserved.

# Gradio app serving a content-policy classifier: a Gemma base model with a
# PEFT (LoRA) adapter, loaded in 8-bit, prompted to answer "1" (violating)
# or "0" (non-violating) for a given policy + content pair.

import gradio as gr
import os
import time

import torch
import torch.nn.functional as F
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Prefer GPU when available; fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Checkpoints tried during development (kept for reference).
#base_model_name = "google/gemma-2b"
#base_model_name = "google/gemma-7b"
base_model_name = "google/gemma-2-9b"
#adapter_model_name = "samidh/cope-g2b-2c-hs.s1.5fpc.9-sx.s1.5.9o-vl.s1.5.9-hr.s5-sh.s5.l5e5-e3-d25-r8"
#adapter_model_name = "samidh/cope-g7bq-2c-hs.s1.5fpc.9-sx.s1.5.9o-VL.s1.5.9-HR.s5-SH.s5-l5e5-e3-d25-r8"

#adapter_model_name = "cope-project/cope-g2.9b-2c-hs.s1.5fpc.9-sx.s1.5.9o-vl.s1.5.9-hr.s5-sh.s5.l5e5-e3-d0-r8"
#adapter_model_name = "samidh/cope-g2.9b-2c-hs.at-sx.at-vl.s5-sh.s5-hr.s5-e2"
#adapter_model_name = "samidh/cope.a.r09.2"
#adapter_model_name = "samidh/cope.a.r09.4.no.vl"
#adapter_model_name = "samidh/cope.a.r09.3"
adapter_model_name = "samidh/cope.a.r09.2"

# 8-bit quantization so the 9B model fits in GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

# Alternative 4-bit NF4 config kept for experimentation (currently unused).
"""
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)
"""

# Requires HF_TOKEN in the environment (gated checkpoints); raises KeyError
# at import time if it is not set.
model = AutoModelForCausalLM.from_pretrained(base_model_name,
                                             token=os.environ['HF_TOKEN'],
                                             quantization_config=bnb_config,
                                             device_map="auto")
model = PeftModel.from_pretrained(model, adapter_model_name, token=os.environ['HF_TOKEN'])
# NOTE(review): merge_and_unload() returns the merged base model, but the
# return value is discarded here, so `model` stays the PeftModel wrapper.
# Inference still runs through the wrapper -- confirm the merge is intended.
model.merge_and_unload()

# NOTE(review): with device_map="auto" + 8-bit quantization, accelerate has
# already placed the weights; this .to(device) is presumably a no-op -- verify.
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

#gr.load("models/samidh/cope-gemma2b-hs-2c-skr-s1.5.9.d25", hf_token=os.environ['HF_TOKEN']).launch()

# Prompt template filled with {policy} and {content}. The fine-tuned model is
# expected to emit a single "1"/"0" token immediately after the ANSWER header;
# do not edit the wording/layout without re-validating against the adapter.
PROMPT = """
INSTRUCTIONS
============

Examine the given POLICY and determine if the given CONTENT meets the criteria for ANY of the LABELS. Answer "1" if yes, and "0" if no.

POLICY
======

{policy}


CONTENT
=======

{content}


ANSWER
======

"""

# Default policy pre-filled in the UI: a markdown hate-speech policy with
# term definitions, an interpretation rule, and include/exclude label criteria.
DEFAULT_POLICY = """# Policy

## Overview

This policy is designed to determine whether or not content is hate speech.

## Definition of Terms

- Protected Classes: Individuals or groups distinguished by race (e.g. Black, White), ethnicity (e.g. Hispanic), nationality (e.g. Canadian, Argentinian), caste (e.g. Brahmins, Kshatriyas, Vaishyas), immigration status (e.g. Immigrants), religion (e.g. Catholics, Jews), disability (e.g. blind people, deaf people), disease status (e.g. people with AIDS), sexual orientation (e.g. gays, bisexuals), transgender status (e.g. trans people, cis people), gender identity (e.g. women, men), gender expression (e.g. masculine, feminine), or sex (e.g. male, female).
- Derogatory Slur: Any term or phrase used pejoratively or contemptuously towards a member or members of a protected class, including those historically reclaimed but used in a derogatory context.
- Non-Protected Classes: Categories defined by, political party membership (e.g. Democrats, Tories), political beliefs (e.g. fascists, communists), hobbies (e.g. games, knitters), criminal history (e.g. criminals, convicts), or profession (e.g. doctors, farmers).

## Interpretation of Language:

- Terminology or phrasing that can be construed in various ways IS hate speech UNLESS context clearly indicates it unambiguously IS NOT hate speech. So, for example, attacks on a religion as an institution (e.g. "Hinduism is a garbage religion") SHOULD be understood as attacks on the followers of that religion.

## Definition of Labels

### (HS): Hate Speech

#### Includes

- Direct and Indirect Violence Advocacy: Any form of expression that celebrates, justifies, incites, or threatens physical or psychological harm against individuals or groups identified by their membership in a protected class.
- Dehumanization and Vilification: Content suggesting that individuals or groups, explicitly identified by their membership in a specific protected class, are subhuman, inherently lesser beings, or using animalistic descriptions in a way that promotes disdain or hate.
- Derogatory and Dehumanizing Language: Use of slurs, epithets, or any derogatory language aimed at belittling, humiliating, or inciting hatred against individuals or groups explicitly identified by their membership in a specific protected class.
- Explicit and Implicit Discrimination Advocacy: Promoting exclusion, segregation, or denial of rights against  individuals or groups explicitly identified by their membership in a specific protected class.
- Collective Attribution of Negative Actions: Assigning collective blame or advocating collective punishment based on the actions or perceived characteristics of individuals or groups identified by their membership in a protected class.
- Inferiority and Superiority Claims: Statements that categorically assign inferiority or superiority, moral or intellectual, to individuals or groups identified by their membership in a protected class.
- Denial or Distortion of Historical Atrocities: Denying, grossly trivializing, or distorting documented atrocities against groups identified by their membership in a protected class, undermining their significance or the suffering of their members.
- Conspiracy Theories: Propagating unfounded allegations that individuals or groups, identified by their membership in a protected class, are responsible for serious harms or controlling significant institutions to the detriment of society.

#### Excludes

- Attacks on Non-Protected Classes: Content that attacks or criticizes individuals or groups identified by their membership in a Non-Protected Class, EVEN if that attack is violent, threatening, or otherwise hateful (e.g. "Criminals should all be rounded up and shot!").
- Criticism of Beliefs and Institutions: Constructive critique or discussion of political ideologies, religious doctrines, or institutions without resorting to hate speech or targeting individuals or groups identified by their membership in a protected class.
- Neutrally Reporting Historical Events: Neutrally and descriptively reporting or discussion of factual events in the past that could be construed as negative about individuals or groups identified by their membership in a protected class.
- Pushing Back on Hateful Language: Content where the writer pushes back on, condemns, questions, criticizes, or mocks a different person's hateful language or ideas.
- Artistic and Educational Content: Content with legitimate artistic, educational, or documentary value that discusses or portrays issues related to hate speech in a context clearly aimed at enlightening or informing without promoting hate.
"""

# Placeholder text for the content input box.
DEFAULT_CONTENT = "Put your content sample here."

# Function to make predictions
def predict(content, policy):
    """Classify a single piece of content against a policy.

    Formats `policy` and `content` into PROMPT, runs one forward pass, and
    renormalizes the next-token logits over just the "0" and "1" answer
    tokens to get a binary verdict with a probability.

    Args:
        content: Content sample to classify (str).
        policy: Policy text the content is judged against (str).

    Returns:
        A human-readable verdict string, either
        'VIOLATING (P: x.xx)' or 'NON-Violating (P: x.xx)'.
    """
    print(f"New request at {time.time()}")

    input_text = PROMPT.format(policy=policy, content=content)
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

    with torch.inference_mode():
        outputs = model(input_ids)

        # Logits for the token that would follow the prompt (batch of 1).
        logits = outputs.logits[:, -1, :]

        # Token ids for the two possible answers.
        token_id_0 = tokenizer.encode("0", add_special_tokens=False)[0]
        token_id_1 = tokenizer.encode("1", add_special_tokens=False)[0]

        # Restrict to the "0"/"1" logits and renormalize, so the reported
        # probability is P(answer | answer in {"0", "1"}).
        binary_logits = logits[:, [token_id_0, token_id_1]]
        probabilities = F.softmax(binary_logits, dim=-1)

        predicted_prob_0 = probabilities[0, 0].item()
        predicted_prob_1 = probabilities[0, 1].item()

        decoded_output = '0' if predicted_prob_0 > predicted_prob_1 else '1'

    if decoded_output == '1':
        return f'VIOLATING (P: {predicted_prob_1:.2f})'
    else:
        return f'NON-Violating (P: {predicted_prob_0:.2f})'


# Function to make predictions in batches
def predict_batch(contents, policies):
    """Classify a batch of content samples against their policies.

    Same decision rule as `predict`, vectorized over a batch: format each
    (policy, content) pair into PROMPT, run one forward pass, and compare
    the renormalized probabilities of the "0"/"1" answer tokens.

    Args:
        contents: List of content strings.
        policies: List of policy strings, parallel to `contents`.

    Returns:
        A single-element list containing the list of verdict strings
        (Gradio batch functions return one list per output component).
    """
    print(f"New batch request at {time.time()}")

    input_texts = [
        PROMPT.format(policy=policy, content=content)
        for content, policy in zip(contents, policies)
    ]

    # Keep the full encoding (not just input_ids): with padding=True the
    # attention mask is required so (a) the model ignores padding positions
    # and (b) we can read each row's logits at its *last real* token --
    # indexing [:, -1, :] would land on padding for every sequence shorter
    # than the longest one when the tokenizer pads on the right.
    encoded = tokenizer(
        input_texts,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(model.device)

    with torch.inference_mode():
        outputs = model(**encoded)

    # Index of the last non-padding token in every row.
    last_token_idx = encoded.attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(outputs.logits.size(0), device=outputs.logits.device)

    # Next-token logits at each sequence's true end: shape (batch, vocab).
    logits = outputs.logits[batch_idx, last_token_idx, :]

    # Token ids for the two possible answers.
    token_id_0 = tokenizer.encode("0", add_special_tokens=False)[0]
    token_id_1 = tokenizer.encode("1", add_special_tokens=False)[0]

    # Restrict to the "0"/"1" logits and renormalize.
    binary_logits = logits[:, [token_id_0, token_id_1]]
    probabilities = F.softmax(binary_logits, dim=-1)

    probs_0 = probabilities[:, 0].cpu().numpy()
    probs_1 = probabilities[:, 1].cpu().numpy()

    results = []
    for prob_0, prob_1 in zip(probs_0, probs_1):
        if prob_1 > prob_0:
            output = f'VIOLATING (P: {prob_1:.2f})'
        else:
            output = f'NON-Violating (P: {prob_0:.2f})'
        results.append(output)

    print(results)
    return [results]

# Create Gradio interface
# batch=True: Gradio groups up to max_batch_size concurrent requests and
# calls predict_batch with one list per input component (contents, policies).
iface = gr.Interface(
    fn=predict_batch,
    inputs=[gr.Textbox(label="Content", lines=2, value=DEFAULT_CONTENT), 
            gr.Textbox(label="Policy", lines=10, value=DEFAULT_POLICY)], 
    outputs=[gr.Textbox(label="Result")],
    batch=True,
    max_batch_size=4,
    title="CoPE Dev (Unstable)",
    description="See if the given content violates your given policy."
)

#iface.queue(
#    default_concurrency_limit=2
#)

# Launch the app
# Blocking call: starts the web server and serves requests until stopped.
iface.launch()