Merge pull request #44 from Infamous-Hydra/Infamous-Hydra-patch-17
Browse files- Database/sql/warns_sql.py +327 -0
Database/sql/warns_sql.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
|
3 |
+
from sqlalchemy import (
|
4 |
+
BigInteger,
|
5 |
+
Boolean,
|
6 |
+
Column,
|
7 |
+
Integer,
|
8 |
+
String,
|
9 |
+
UnicodeText,
|
10 |
+
distinct,
|
11 |
+
func,
|
12 |
+
)
|
13 |
+
from sqlalchemy.dialects import postgresql
|
14 |
+
|
15 |
+
from Database.sql import BASE, SESSION
|
16 |
+
|
17 |
+
|
18 |
+
class Warns(BASE):
|
19 |
+
__tablename__ = "warns"
|
20 |
+
|
21 |
+
user_id = Column(BigInteger, primary_key=True)
|
22 |
+
chat_id = Column(String(14), primary_key=True)
|
23 |
+
num_warns = Column(Integer, default=0)
|
24 |
+
reasons = Column(postgresql.ARRAY(UnicodeText))
|
25 |
+
|
26 |
+
def __init__(self, user_id, chat_id):
|
27 |
+
self.user_id = user_id
|
28 |
+
self.chat_id = str(chat_id)
|
29 |
+
self.num_warns = 0
|
30 |
+
self.reasons = []
|
31 |
+
|
32 |
+
def __repr__(self):
|
33 |
+
return "<{} warns for {} in {} for reasons {}>".format(
|
34 |
+
self.num_warns,
|
35 |
+
self.user_id,
|
36 |
+
self.chat_id,
|
37 |
+
self.reasons,
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
class WarnFilters(BASE):
|
42 |
+
__tablename__ = "warn_filters"
|
43 |
+
chat_id = Column(String(14), primary_key=True)
|
44 |
+
keyword = Column(UnicodeText, primary_key=True, nullable=False)
|
45 |
+
reply = Column(UnicodeText, nullable=False)
|
46 |
+
|
47 |
+
def __init__(self, chat_id, keyword, reply):
|
48 |
+
self.chat_id = str(chat_id) # ensure string
|
49 |
+
self.keyword = keyword
|
50 |
+
self.reply = reply
|
51 |
+
|
52 |
+
def __repr__(self):
|
53 |
+
return "<Permissions for %s>" % self.chat_id
|
54 |
+
|
55 |
+
def __eq__(self, other):
|
56 |
+
return bool(
|
57 |
+
isinstance(other, WarnFilters)
|
58 |
+
and self.chat_id == other.chat_id
|
59 |
+
and self.keyword == other.keyword,
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
class WarnSettings(BASE):
|
64 |
+
__tablename__ = "warn_settings"
|
65 |
+
chat_id = Column(String(14), primary_key=True)
|
66 |
+
warn_limit = Column(Integer, default=3)
|
67 |
+
soft_warn = Column(Boolean, default=False)
|
68 |
+
|
69 |
+
def __init__(self, chat_id, warn_limit=3, soft_warn=False):
|
70 |
+
self.chat_id = str(chat_id)
|
71 |
+
self.warn_limit = warn_limit
|
72 |
+
self.soft_warn = soft_warn
|
73 |
+
|
74 |
+
def __repr__(self):
|
75 |
+
return "<{} has {} possible warns.>".format(self.chat_id, self.warn_limit)
|
76 |
+
|
77 |
+
|
78 |
+
Warns.__table__.create(checkfirst=True)
|
79 |
+
WarnFilters.__table__.create(checkfirst=True)
|
80 |
+
WarnSettings.__table__.create(checkfirst=True)
|
81 |
+
|
82 |
+
WARN_INSERTION_LOCK = threading.RLock()
|
83 |
+
WARN_FILTER_INSERTION_LOCK = threading.RLock()
|
84 |
+
WARN_SETTINGS_LOCK = threading.RLock()
|
85 |
+
|
86 |
+
WARN_FILTERS = {}
|
87 |
+
|
88 |
+
|
89 |
+
def warn_user(user_id, chat_id, reason=None):
|
90 |
+
with WARN_INSERTION_LOCK:
|
91 |
+
warned_user = SESSION.query(Warns).get((user_id, str(chat_id)))
|
92 |
+
if not warned_user:
|
93 |
+
warned_user = Warns(user_id, str(chat_id))
|
94 |
+
|
95 |
+
warned_user.num_warns += 1
|
96 |
+
if reason:
|
97 |
+
warned_user.reasons = warned_user.reasons + [
|
98 |
+
reason,
|
99 |
+
] # TODO:: double check this wizardry
|
100 |
+
|
101 |
+
reasons = warned_user.reasons
|
102 |
+
num = warned_user.num_warns
|
103 |
+
|
104 |
+
SESSION.add(warned_user)
|
105 |
+
SESSION.commit()
|
106 |
+
|
107 |
+
return num, reasons
|
108 |
+
|
109 |
+
|
110 |
+
def remove_warn(user_id, chat_id):
|
111 |
+
with WARN_INSERTION_LOCK:
|
112 |
+
removed = False
|
113 |
+
warned_user = SESSION.query(Warns).get((user_id, str(chat_id)))
|
114 |
+
|
115 |
+
if warned_user and warned_user.num_warns > 0:
|
116 |
+
warned_user.num_warns -= 1
|
117 |
+
warned_user.reasons = warned_user.reasons[:-1]
|
118 |
+
SESSION.add(warned_user)
|
119 |
+
SESSION.commit()
|
120 |
+
removed = True
|
121 |
+
|
122 |
+
SESSION.close()
|
123 |
+
return removed
|
124 |
+
|
125 |
+
|
126 |
+
def reset_warns(user_id, chat_id):
|
127 |
+
with WARN_INSERTION_LOCK:
|
128 |
+
warned_user = SESSION.query(Warns).get((user_id, str(chat_id)))
|
129 |
+
if warned_user:
|
130 |
+
warned_user.num_warns = 0
|
131 |
+
warned_user.reasons = []
|
132 |
+
|
133 |
+
SESSION.add(warned_user)
|
134 |
+
SESSION.commit()
|
135 |
+
SESSION.close()
|
136 |
+
|
137 |
+
|
138 |
+
def get_warns(user_id, chat_id):
|
139 |
+
try:
|
140 |
+
user = SESSION.query(Warns).get((user_id, str(chat_id)))
|
141 |
+
if not user:
|
142 |
+
return None
|
143 |
+
reasons = user.reasons
|
144 |
+
num = user.num_warns
|
145 |
+
return num, reasons
|
146 |
+
finally:
|
147 |
+
SESSION.close()
|
148 |
+
|
149 |
+
|
150 |
+
def add_warn_filter(chat_id, keyword, reply):
|
151 |
+
with WARN_FILTER_INSERTION_LOCK:
|
152 |
+
warn_filt = WarnFilters(str(chat_id), keyword, reply)
|
153 |
+
|
154 |
+
if keyword not in WARN_FILTERS.get(str(chat_id), []):
|
155 |
+
WARN_FILTERS[str(chat_id)] = sorted(
|
156 |
+
WARN_FILTERS.get(str(chat_id), []) + [keyword],
|
157 |
+
key=lambda x: (-len(x), x),
|
158 |
+
)
|
159 |
+
|
160 |
+
SESSION.merge(warn_filt) # merge to avoid duplicate key issues
|
161 |
+
SESSION.commit()
|
162 |
+
|
163 |
+
|
164 |
+
def remove_warn_filter(chat_id, keyword):
|
165 |
+
with WARN_FILTER_INSERTION_LOCK:
|
166 |
+
warn_filt = SESSION.query(WarnFilters).get((str(chat_id), keyword))
|
167 |
+
if warn_filt:
|
168 |
+
if keyword in WARN_FILTERS.get(str(chat_id), []): # sanity check
|
169 |
+
WARN_FILTERS.get(str(chat_id), []).remove(keyword)
|
170 |
+
|
171 |
+
SESSION.delete(warn_filt)
|
172 |
+
SESSION.commit()
|
173 |
+
return True
|
174 |
+
SESSION.close()
|
175 |
+
return False
|
176 |
+
|
177 |
+
|
178 |
+
def get_chat_warn_triggers(chat_id):
|
179 |
+
return WARN_FILTERS.get(str(chat_id), set())
|
180 |
+
|
181 |
+
|
182 |
+
def get_chat_warn_filters(chat_id):
|
183 |
+
try:
|
184 |
+
return (
|
185 |
+
SESSION.query(WarnFilters).filter(WarnFilters.chat_id == str(chat_id)).all()
|
186 |
+
)
|
187 |
+
finally:
|
188 |
+
SESSION.close()
|
189 |
+
|
190 |
+
|
191 |
+
def get_warn_filter(chat_id, keyword):
|
192 |
+
try:
|
193 |
+
return SESSION.query(WarnFilters).get((str(chat_id), keyword))
|
194 |
+
finally:
|
195 |
+
SESSION.close()
|
196 |
+
|
197 |
+
|
198 |
+
def set_warn_limit(chat_id, warn_limit):
|
199 |
+
with WARN_SETTINGS_LOCK:
|
200 |
+
curr_setting = SESSION.query(WarnSettings).get(str(chat_id))
|
201 |
+
if not curr_setting:
|
202 |
+
curr_setting = WarnSettings(chat_id, warn_limit=warn_limit)
|
203 |
+
|
204 |
+
curr_setting.warn_limit = warn_limit
|
205 |
+
|
206 |
+
SESSION.add(curr_setting)
|
207 |
+
SESSION.commit()
|
208 |
+
|
209 |
+
|
210 |
+
def set_warn_strength(chat_id, soft_warn):
|
211 |
+
with WARN_SETTINGS_LOCK:
|
212 |
+
curr_setting = SESSION.query(WarnSettings).get(str(chat_id))
|
213 |
+
if not curr_setting:
|
214 |
+
curr_setting = WarnSettings(chat_id, soft_warn=soft_warn)
|
215 |
+
|
216 |
+
curr_setting.soft_warn = soft_warn
|
217 |
+
|
218 |
+
SESSION.add(curr_setting)
|
219 |
+
SESSION.commit()
|
220 |
+
|
221 |
+
|
222 |
+
def get_warn_setting(chat_id):
|
223 |
+
try:
|
224 |
+
setting = SESSION.query(WarnSettings).get(str(chat_id))
|
225 |
+
if setting:
|
226 |
+
return setting.warn_limit, setting.soft_warn
|
227 |
+
else:
|
228 |
+
return 3, False
|
229 |
+
|
230 |
+
finally:
|
231 |
+
SESSION.close()
|
232 |
+
|
233 |
+
|
234 |
+
def num_warns():
|
235 |
+
try:
|
236 |
+
return SESSION.query(func.sum(Warns.num_warns)).scalar() or 0
|
237 |
+
finally:
|
238 |
+
SESSION.close()
|
239 |
+
|
240 |
+
|
241 |
+
def num_warn_chats():
|
242 |
+
try:
|
243 |
+
return SESSION.query(func.count(distinct(Warns.chat_id))).scalar()
|
244 |
+
finally:
|
245 |
+
SESSION.close()
|
246 |
+
|
247 |
+
|
248 |
+
def num_warn_filters():
|
249 |
+
try:
|
250 |
+
return SESSION.query(WarnFilters).count()
|
251 |
+
finally:
|
252 |
+
SESSION.close()
|
253 |
+
|
254 |
+
|
255 |
+
def num_warn_chat_filters(chat_id):
|
256 |
+
try:
|
257 |
+
return (
|
258 |
+
SESSION.query(WarnFilters.chat_id)
|
259 |
+
.filter(WarnFilters.chat_id == str(chat_id))
|
260 |
+
.count()
|
261 |
+
)
|
262 |
+
finally:
|
263 |
+
SESSION.close()
|
264 |
+
|
265 |
+
|
266 |
+
def num_warn_filter_chats():
|
267 |
+
try:
|
268 |
+
return SESSION.query(func.count(distinct(WarnFilters.chat_id))).scalar()
|
269 |
+
finally:
|
270 |
+
SESSION.close()
|
271 |
+
|
272 |
+
|
273 |
+
def __load_chat_warn_filters():
|
274 |
+
global WARN_FILTERS
|
275 |
+
try:
|
276 |
+
chats = SESSION.query(WarnFilters.chat_id).distinct().all()
|
277 |
+
for (chat_id,) in chats: # remove tuple by ( ,)
|
278 |
+
WARN_FILTERS[chat_id] = []
|
279 |
+
|
280 |
+
all_filters = SESSION.query(WarnFilters).all()
|
281 |
+
for x in all_filters:
|
282 |
+
WARN_FILTERS[x.chat_id] += [x.keyword]
|
283 |
+
|
284 |
+
WARN_FILTERS = {
|
285 |
+
x: sorted(set(y), key=lambda i: (-len(i), i))
|
286 |
+
for x, y in WARN_FILTERS.items()
|
287 |
+
}
|
288 |
+
|
289 |
+
finally:
|
290 |
+
SESSION.close()
|
291 |
+
|
292 |
+
|
293 |
+
def migrate_chat(old_chat_id, new_chat_id):
|
294 |
+
with WARN_INSERTION_LOCK:
|
295 |
+
chat_notes = (
|
296 |
+
SESSION.query(Warns).filter(Warns.chat_id == str(old_chat_id)).all()
|
297 |
+
)
|
298 |
+
for note in chat_notes:
|
299 |
+
note.chat_id = str(new_chat_id)
|
300 |
+
SESSION.commit()
|
301 |
+
|
302 |
+
with WARN_FILTER_INSERTION_LOCK:
|
303 |
+
chat_filters = (
|
304 |
+
SESSION.query(WarnFilters)
|
305 |
+
.filter(WarnFilters.chat_id == str(old_chat_id))
|
306 |
+
.all()
|
307 |
+
)
|
308 |
+
for filt in chat_filters:
|
309 |
+
filt.chat_id = str(new_chat_id)
|
310 |
+
SESSION.commit()
|
311 |
+
old_warn_filt = WARN_FILTERS.get(str(old_chat_id))
|
312 |
+
if old_warn_filt is not None:
|
313 |
+
WARN_FILTERS[str(new_chat_id)] = old_warn_filt
|
314 |
+
del WARN_FILTERS[str(old_chat_id)]
|
315 |
+
|
316 |
+
with WARN_SETTINGS_LOCK:
|
317 |
+
chat_settings = (
|
318 |
+
SESSION.query(WarnSettings)
|
319 |
+
.filter(WarnSettings.chat_id == str(old_chat_id))
|
320 |
+
.all()
|
321 |
+
)
|
322 |
+
for setting in chat_settings:
|
323 |
+
setting.chat_id = str(new_chat_id)
|
324 |
+
SESSION.commit()
|
325 |
+
|
326 |
+
|
327 |
+
__load_chat_warn_filters()
|