JustKiddo committed on
Commit 0cd0b89 · verified · 1 parent: b1ffab2

Create app.py

Files changed (1): app.py (+187, -0)
app.py ADDED
@@ -0,0 +1,187 @@
import gradio as gr
from huggingface_hub import InferenceClient
from datasets import load_dataset
import torch
from transformers import pipeline

class ContentFilter:
    def __init__(self):
        # Initialize toxic content detection model
        # (newer transformers versions prefer top_k=None over return_all_scores=True)
        self.toxicity_classifier = pipeline(
            'text-classification',
            model='unitary/toxic-bert',
            return_all_scores=True
        )

        # Keyword blacklist
        self.blacklist = [
            'hate', 'discriminate', 'violent',
            'offensive', 'inappropriate', 'racist',
            'sexist', 'homophobic', 'transphobic'
        ]

    def filter_toxicity(self, text, toxicity_threshold=0.7):
        """
        Detect toxic content using the pre-trained model.

        Args:
            text (str): Input text to check
            toxicity_threshold (float): Threshold for filtering

        Returns:
            dict: Filtering results
        """
        results = self.toxicity_classifier(text)[0]

        # Convert results to a label -> score dictionary
        toxicity_scores = {
            result['label']: result['score']
            for result in results
        }

        # Check if any toxic category exceeds the threshold
        is_toxic = any(
            score > toxicity_threshold
            for score in toxicity_scores.values()
        )

        return {
            'is_toxic': is_toxic,
            'toxicity_scores': toxicity_scores
        }

    def filter_keywords(self, text):
        """
        Check text against the keyword blacklist.

        Args:
            text (str): Input text to check

        Returns:
            list: Matched blacklisted keywords
        """
        matched_keywords = [
            keyword for keyword in self.blacklist
            if keyword.lower() in text.lower()
        ]

        return matched_keywords

    def comprehensive_filter(self, text):
        """
        Perform comprehensive content filtering.

        Args:
            text (str): Input text to filter

        Returns:
            dict: Comprehensive filtering results
        """
        # Toxicity model filtering
        toxicity_results = self.filter_toxicity(text)

        # Keyword blacklist filtering
        blacklisted_keywords = self.filter_keywords(text)

        # Combine results
        return {
            'toxicity': toxicity_results,
            'blacklisted_keywords': blacklisted_keywords,
            'is_safe': not toxicity_results['is_toxic'] and len(blacklisted_keywords) == 0
        }

# Initialize content filter
content_filter = ContentFilter()

# Initialize Hugging Face Inference client
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# Load dataset (optional; not used elsewhere in this app)
dataset = load_dataset("JustKiddo/KiddosVault")

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p
):
    # First, filter the incoming user message
    message_filter_result = content_filter.comprehensive_filter(message)

    # If the message is not safe, yield a warning instead of a model response
    if not message_filter_result['is_safe']:
        toxicity_details = message_filter_result['toxicity']['toxicity_scores']
        blacklisted_keywords = message_filter_result['blacklisted_keywords']

        warning_message = "Message flagged for inappropriate content. "
        warning_message += "Detected issues: "

        # Add toxicity details
        for category, score in toxicity_details.items():
            if score > 0.7:
                warning_message += f"{category} (Score: {score:.2f}), "

        # Add blacklisted keywords
        if blacklisted_keywords:
            warning_message += f"Blacklisted keywords: {', '.join(blacklisted_keywords)}"

        # respond() is a generator, so the warning must be yielded to reach the UI
        yield warning_message
        return

    # Prepare messages for chat completion
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # Generate the response, streaming tokens as they arrive
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p
    ):
        token = chunk.choices[0].delta.content
        if token:
            response += token
            yield response

# Create Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a professional and friendly therapist specializing in LGBT+ issues",
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=6144,
            value=6144,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=1,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ]
)

if __name__ == "__main__":
    demo.launch(debug=True)
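
As a quick illustration of what the filter returns, here is a minimal usage sketch for the ContentFilter class on its own, run in a Python session where the class above is already defined. The sample sentences, variable names, and expected outputs below are illustrative assumptions, not part of this commit:

# Illustrative standalone use of ContentFilter (not part of app.py)
cf = ContentFilter()

clean = cf.comprehensive_filter("What a lovely day for a walk.")
print(clean['is_safe'])                        # expected: True

flagged = cf.comprehensive_filter("That was a hate-filled, violent rant.")
print(flagged['is_safe'])                      # expected: False (keyword blacklist hits)
print(flagged['blacklisted_keywords'])         # e.g. ['hate', 'violent']
print(flagged['toxicity']['toxicity_scores'])  # per-label scores from toxic-bert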