File size: 5,414 Bytes
0cd0b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c04a4
0cd0b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a42b9fc
 
0cd0b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c04a4
0cd0b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce90c8f
0cd0b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import re

import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import InferenceClient
from transformers import pipeline

class ContentFilter:
    """Screen text with a toxicity classifier and a keyword blacklist."""

    def __init__(self):
        # Pre-trained toxicity classification pipeline, configured to report
        # a score for every label instead of only the best one.
        # NOTE(review): recent transformers releases deprecate
        # `return_all_scores` in favour of `top_k=None` — confirm against the
        # installed version before switching.
        self.toxicity_classifier = pipeline(
            'text-classification',
            model='unitary/toxic-bert',
            return_all_scores=True
        )

        # Keyword blacklist; matched case-insensitively at the start of a word.
        self.blacklist = [
            'hate', 'discriminate', 'violent',
            'offensive', 'inappropriate', 'racist',
            'sexist', 'homophobic', 'transphobic'
        ]

    def filter_toxicity(self, text, toxicity_threshold=0.5):
        """
        Detect toxic content using the pre-trained model.

        Args:
            text (str): Input text to check
            toxicity_threshold (float): A label scoring strictly above this
                value marks the text as toxic

        Returns:
            dict: ``{'is_toxic': bool, 'toxicity_scores': {label: score}}``
        """
        results = self.toxicity_classifier(text)

        # The classifier is expected to wrap the per-label list in an outer
        # list for a single input; accept a flat list as well so a differently
        # configured classifier does not break the score extraction.
        if results and isinstance(results[0], list):
            results = results[0]

        toxicity_scores = {
            result['label']: result['score']
            for result in results
        }

        is_toxic = any(
            score > toxicity_threshold
            for score in toxicity_scores.values()
        )

        return {
            'is_toxic': is_toxic,
            'toxicity_scores': toxicity_scores
        }

    def filter_keywords(self, text):
        """
        Check text against the keyword blacklist.

        A keyword matches when it appears at the start of a word, so
        "hateful" matches 'hate' but "chateau" does not, and "nonviolent"
        does not match 'violent'.

        Args:
            text (str): Input text to check; empty or None yields no matches

        Returns:
            list: Matched blacklisted keywords, in blacklist order
        """
        if not text:
            return []

        # Lower-case once instead of once per keyword.
        lowered = text.lower()

        return [
            keyword for keyword in self.blacklist
            # `(?<!\w)` rejects matches that begin in the middle of a word.
            if re.search(r'(?<!\w)' + re.escape(keyword.lower()), lowered)
        ]

    def comprehensive_filter(self, text, toxicity_threshold=0.5):
        """
        Perform comprehensive content filtering.

        Args:
            text (str): Input text to filter
            toxicity_threshold (float): Forwarded to ``filter_toxicity``

        Returns:
            dict: ``'toxicity'`` (result of ``filter_toxicity``),
                ``'blacklisted_keywords'`` (result of ``filter_keywords``) and
                ``'is_safe'`` (True only when neither check flagged the text)
        """
        toxicity_results = self.filter_toxicity(text, toxicity_threshold)
        blacklisted_keywords = self.filter_keywords(text)

        return {
            'toxicity': toxicity_results,
            'blacklisted_keywords': blacklisted_keywords,
            'is_safe': not toxicity_results['is_toxic'] and not blacklisted_keywords
        }

# Build the shared content filter once at import time; constructing it creates
# the toxicity classification pipeline. `respond` uses this instance to screen
# every incoming message.
content_filter = ContentFilter()

# Hugging Face inference client that `respond` streams completions from.
# Previous model choice, kept for reference:
#client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# NOTE(review): `respond` calls `client.chat_completion`; confirm that
# "google-t5/t5-small" is served with chat-completion support.
client = InferenceClient("google-t5/t5-small")

# Load dataset (optional)
# NOTE(review): `dataset` is not referenced anywhere else in the file as
# shown; confirm it is needed, since it is loaded on every import.
dataset = load_dataset("JustKiddo/KiddosVault")

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p
):
    """
    Stream a chat reply for a user message, refusing flagged input.

    The message is first screened with the module-level ``content_filter``.
    If it is flagged, a single warning string is yielded and no completion is
    requested. Otherwise the conversation is sent to the module-level
    ``client`` and the reply is streamed back.

    Args:
        message (str): Latest user message
        history (list[tuple[str, str]]): Earlier (user, assistant) turns;
            empty entries within a turn are skipped
        system_message (str): System prompt placed first in the conversation
        max_tokens (int): Forwarded to ``client.chat_completion``
        temperature (float): Forwarded to ``client.chat_completion``
        top_p (float): Forwarded to ``client.chat_completion``

    Yields:
        str: Either the warning text, or the reply accumulated so far
    """
    # Scores above this value are listed in the warning message.
    score_cutoff = 0.5

    # First, filter the incoming user message
    filter_result = content_filter.comprehensive_filter(message)

    if not filter_result['is_safe']:
        toxicity_details = filter_result['toxicity']['toxicity_scores']
        blacklisted_keywords = filter_result['blacklisted_keywords']

        issues = [
            f"{category} (Score: {score:.2f})"
            for category, score in toxicity_details.items()
            if score > score_cutoff
        ]
        if blacklisted_keywords:
            issues.append(
                f"Blacklisted keywords: {', '.join(blacklisted_keywords)}"
            )

        # This function is a generator (it yields below), so the warning has
        # to be yielded too: a `return value` here would end the generator
        # without ever handing the text to the caller.
        # Joining the parts also avoids a dangling ", " after the last issue.
        yield (
            "Message flagged for inappropriate content. "
            "Detected issues: " + ", ".join(issues)
        )
        return

    # Prepare messages for chat completion
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    # Generate response; `chunk` is deliberately not named `message` so the
    # parameter is not shadowed.
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p
    ):
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        # A streamed chunk may carry no text (content of None); appending
        # that to a str would raise TypeError.
        if token:
            response += token
        yield response

# Create Gradio interface
# `respond` is the chat handler. The additional inputs below are listed in the
# same order as the extra parameters of `respond` (system_message, max_tokens,
# temperature, top_p), which is how their values reach the handler.
demo = gr.ChatInterface(
    respond, 
    additional_inputs=[
        # System prompt -> `system_message`
        gr.Textbox(
            value="You are a professional and friendly assistant.", 
            label="System message"
        ),
        # Generation length limit -> `max_tokens`
        # NOTE(review): the default equals the maximum (6144); confirm the
        # selected model accepts a limit this large.
        gr.Slider(
            minimum=1, 
            maximum=6144, 
            value=6144, 
            step=1, 
            label="Max new tokens"
        ),
        # Sampling temperature -> `temperature`
        gr.Slider(
            minimum=0.1, 
            maximum=4.0, 
            value=1, 
            step=0.1, 
            label="Temperature"
        ),
        # Nucleus sampling cutoff -> `top_p`
        gr.Slider(
            minimum=0.1, 
            maximum=1.0, 
            value=0.95, 
            step=0.05, 
            label="Top-p (nucleus sampling)"
        ),
    ]
)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(debug=True)