File size: 15,897 Bytes
0dc9888
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317


from asyncio import CancelledError
from concurrent.futures import Future, ThreadPoolExecutor
import os
import re
import threading
import time
from common.dequeue import Dequeue
from channel.channel import Channel
from bridge.reply import *
from bridge.context import *
from config import conf
from common.log import logger
from plugins import *
try:
    from voice.audio_convert import any_to_wav
except Exception as e:
    pass

# 抽象类, 它包含了与消息通道无关的通用处理逻辑
class ChatChannel(Channel):
    """Base channel with message-handling logic shared by concrete channels.

    Incoming contexts are queued per session by ``produce``; a background
    thread running ``consume`` submits them to ``handler_pool``, where
    ``_handle`` generates, decorates and sends the reply.

    NOTE(review): every attribute below is defined at class level, so it is
    shared by all instances and subclasses unless shadowed — confirm this is
    intended.
    """
    # User name of the logged-in account.
    name = None
    # User id of the logged-in account.
    user_id = None
    # Futures submitted to the thread pool, keyed by session_id. Used when a
    # session is reset to cancel futures that have not started running yet;
    # futures that are already running are not cancelled.
    futures = {}
    # Per-session concurrency state, keyed by session_id. ``produce`` stores a
    # two-item list here: [pending-context queue, bounded semaphore].
    sessions = {}
    # Guards access to ``sessions``. This is a plain (non-reentrant) Lock and
    # the done-callback built by ``thread_pool_callback`` acquires it too.
    lock = threading.Lock()
    # Thread pool that runs the message handlers.
    handler_pool = ThreadPoolExecutor(max_workers=8)

    def __init__(self):
        """Start the background thread that runs ``self.consume``.

        The thread is a daemon thread, so it does not keep the interpreter
        alive at shutdown.
        """
        # ``Thread.setDaemon()`` is deprecated since Python 3.10; passing
        # ``daemon=True`` to the constructor is the supported equivalent.
        _thread = threading.Thread(target=self.consume, daemon=True)
        _thread.start()
        

    # 根据消息构造context,消息内容相关的触发项写在这里
    def _compose_context(self, ctype: ContextType, content, **kwargs):
        """Build a ``Context`` for an incoming message, or reject the message.

        Content-dependent trigger rules (white lists, prefixes, keywords,
        @-mentions) are applied here.

        Args:
            ctype: Type of the incoming message.
            content: Message payload. For ``ContextType.TEXT`` it is handled
                as a string.
            **kwargs: Stored as ``context.kwargs``. The ``msg`` and
                ``isgroup`` entries are read here; ``origin_ctype``,
                ``receiver`` and ``desire_rtype`` are honoured when present.
                (Presumably ``Context`` item access is backed by these
                kwargs — verify against ``bridge.context``.)

        Returns:
            The populated context, or ``None`` when the message must be
            ignored (self message, group not white-listed, quoted message,
            or no prefix/keyword/mention match).
        """
        context = Context(ctype, content)
        context.kwargs = kwargs
        # origin_ctype is absent the first time a message comes in. It exists
        # because a voice message produces two nested contexts: first
        # voice -> text, then text -> reply. During the second step it tells
        # whether the prefix check may be skipped (private voice messages).
        if 'origin_ctype' not in context:
            context['origin_ctype'] = ctype
        # receiver is absent the first time a message comes in.
        first_in = 'receiver' not in context
        # Group matching: decide session_id and receiver.
        if first_in:
            config = conf()
            cmsg = context['msg']
            if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True):
                logger.debug("[WX]self message skipped")
                return None
            if context["isgroup"]:
                group_name = cmsg.other_user_nickname
                group_id = cmsg.other_user_id

                group_name_white_list = config.get('group_name_white_list', [])
                group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
                if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
                    group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
                    # Default: one session per speaker; configured groups
                    # share a single session for the whole group.
                    session_id = cmsg.actual_user_id
                    if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
                        session_id = group_id
                else:
                    return None
                context['session_id'] = session_id
                context['receiver'] = group_id
            else:
                context['session_id'] = cmsg.other_user_id
                context['receiver'] = cmsg.other_user_id

        # Content matching: decide whether to respond and clean up content.
        if ctype == ContextType.TEXT:
            if first_in and "」\n- - - - - - -" in content:  # first pass: drop quoted (reference) messages
                logger.debug("[WX]reference query skipped")
                return None

            if context["isgroup"]:  # group chat
                # Check trigger prefix / keyword.
                match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
                match_contain = check_contain(content, conf().get('group_chat_keyword'))
                flag = False
                if match_prefix is not None or match_contain is not None:
                    flag = True
                    if match_prefix:
                        content = content.replace(match_prefix, '', 1).strip()
                if context['msg'].is_at:
                    logger.info("[WX]receive group at")
                    if not conf().get("group_at_off", False):
                        flag = True
                    # Escape the name: it is arbitrary user text and may
                    # contain regex metacharacters such as "(", "+" or "*".
                    escaped_name = re.escape(str(self.name))
                    pattern = f'@{escaped_name}(\u2005|\u0020)'
                    content = re.sub(pattern, r'', content)

                if not flag:
                    if context["origin_ctype"] == ContextType.VOICE:
                        logger.info("[WX]receive group voice, but checkprefix didn't match")
                    return None
            else:  # private chat
                match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
                if match_prefix is not None:  # custom prefix matched: strip the prefix and surrounding whitespace
                    content = content.replace(match_prefix, '', 1).strip()
                elif context["origin_ctype"] == ContextType.VOICE:  # private voice messages may skip the prefix
                    pass
                else:
                    return None

            img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
            if img_match_prefix:
                content = content.replace(img_match_prefix, '', 1).strip()
                context.type = ContextType.IMAGE_CREATE
            else:
                context.type = ContextType.TEXT
            context.content = content
            if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
                context['desire_rtype'] = ReplyType.VOICE
        elif context.type == ContextType.VOICE:
            if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
                context['desire_rtype'] = ReplyType.VOICE

        return context

    def _handle(self, context: Context):
        """Run one context through the pipeline: generate, decorate, send.

        Returns immediately when ``context`` is ``None`` or has empty content.
        """
        has_payload = context is not None and context.content
        if not has_payload:
            return

        # Step 1: build the reply.
        logger.debug('[WX] ready to handle context: {}'.format(context))
        generated = self._generate_reply(context)

        # Step 2: decorate (wrap) the reply.
        logger.debug('[WX] ready to decorate reply: {}'.format(generated))
        decorated = self._decorate_reply(context, generated)

        # Step 3: send the reply.
        self._send_reply(context, decorated)

    def _generate_reply(self, context: Context, reply: Reply = None) -> Reply:
        """Produce the reply for ``context``.

        Plugins get the first chance through the ``ON_HANDLE_CONTEXT`` event.
        Unless the event reports ``is_pass()``, text and image-creation
        contexts are answered with ``build_reply_content``; voice contexts
        are transcribed with ``build_voice_to_text`` and, when the
        transcription is text, fed back through ``_compose_context`` and this
        method.

        Args:
            context: The context to answer.
            reply: Initial reply object handed to plugins. A fresh ``Reply()``
                is created per call when omitted.

        Returns:
            The resulting reply, or ``None`` when the context type is unknown
            or the transcribed voice message is rejected by
            ``_compose_context``.
        """
        # A ``Reply()`` default in the signature would be created once and
        # shared by every call, so an in-place change would leak between calls.
        if reply is None:
            reply = Reply()
        e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
            'channel': self, 'context': context, 'reply': reply}))
        reply = e_context['reply']
        if not e_context.is_pass():
            logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
            if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:  # text and image messages
                reply = super().build_reply_content(context.content, context)
            elif context.type == ContextType.VOICE:  # voice messages
                cmsg = context['msg']
                cmsg.prepare()
                file_path = context.content
                wav_path = os.path.splitext(file_path)[0] + '.wav'
                try:
                    any_to_wav(file_path, wav_path)
                except Exception as e:  # conversion failed: use the raw file, some APIs accept it as-is
                    logger.warning("[WX]any to wav error, use raw path. " + str(e))
                    wav_path = file_path
                # Speech recognition.
                reply = super().build_voice_to_text(wav_path)
                # Remove temp files. Best effort: each file is removed on its
                # own so one failure does not leave the other behind.
                temp_paths = [file_path] if wav_path == file_path else [file_path, wav_path]
                for temp_path in temp_paths:
                    try:
                        os.remove(temp_path)
                    except Exception as e:
                        logger.debug("[WX]delete temp file error: " + str(e))

                if reply.type == ReplyType.TEXT:
                    new_context = self._compose_context(
                        ContextType.TEXT, reply.content, **context.kwargs)
                    if new_context:
                        reply = self._generate_reply(new_context)
                    else:
                        return
            else:
                logger.error('[WX] unknown context type: {}'.format(context.type))
                return
        return reply

    def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
        """Post-process ``reply`` before it is sent.

        Plugins get the first chance through the ``ON_DECORATE_REPLY`` event.
        Unless the event reports ``is_pass()``:

        * a reply type listed in ``self.NOT_SUPPORT_REPLYTYPE`` becomes an
          ERROR reply naming the unsupported type;
        * a TEXT reply is converted to voice when the context asks for it,
          otherwise it gets the configured reply prefix (and an @-mention in
          group chats);
        * ERROR/INFO replies get a ``[type]`` header line;
        * image and voice replies pass through unchanged.

        Returns:
            The decorated reply, or ``None`` when ``reply`` is empty or has an
            unknown type.
        """
        if reply and reply.type:
            e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
                'channel': self, 'context': context, 'reply': reply}))
            reply = e_context['reply']
            desire_rtype = context.get('desire_rtype')
            if not e_context.is_pass() and reply and reply.type:

                if reply.type in self.NOT_SUPPORT_REPLYTYPE:
                    # Remember the offending type before overwriting it, so
                    # the message names that type and not ERROR.
                    unsupported_type = reply.type
                    logger.error("[WX]reply type not support: " + str(unsupported_type))
                    reply.type = ReplyType.ERROR
                    reply.content = "不支持发送的消息类型: " + str(unsupported_type)

                if reply.type == ReplyType.TEXT:
                    reply_text = reply.content
                    if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
                        # Convert to voice and decorate the result again.
                        reply = super().build_text_to_voice(reply.content)
                        return self._decorate_reply(context, reply)
                    if context['isgroup']:
                        reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
                        reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
                    else:
                        reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
                    reply.content = reply_text
                elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
                    reply.content = "["+str(reply.type)+"]\n" + reply.content
                elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
                    pass
                else:
                    logger.error('[WX] unknown reply type: {}'.format(reply.type))
                    return
            # A plugin may have replaced the reply with None, so check it
            # before reading ``reply.type``.
            if desire_rtype and reply and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
                logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
            return reply

    def _send_reply(self, context: Context, reply: Reply):
        """Emit ``ON_SEND_REPLY`` and send the reply unless a plugin took over.

        Nothing happens when ``reply`` is empty or has no type, either before
        or after the plugin event, or when the event reports ``is_pass()``.
        """
        if not (reply and reply.type):
            return

        manager = PluginManager()
        payload = {'channel': self, 'context': context, 'reply': reply}
        e_context = manager.emit_event(EventContext(Event.ON_SEND_REPLY, payload))
        reply = e_context['reply']

        if e_context.is_pass() or not (reply and reply.type):
            return
        logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context))
        self._send(reply, context)

    def _send(self, reply: Reply, context: Context, retry_cnt = 0):
        """Call ``self.send`` and retry on failure.

        A ``NotImplementedError`` is logged and not retried. Any other
        exception is logged with its traceback and the send is retried while
        ``retry_cnt`` is below 2, sleeping ``3 + 3 * retry_cnt`` seconds
        before each retry.
        """
        try:
            self.send(reply, context)
        except Exception as err:
            logger.error('[WX] sendMsg error: {}'.format(str(err)))
            if not isinstance(err, NotImplementedError):
                logger.exception(err)
                may_retry = retry_cnt < 2
                if may_retry:
                    time.sleep(3 + 3 * retry_cnt)
                    self._send(reply, context, retry_cnt + 1)

    def thread_pool_callback(self, session_id):
        """Build the done-callback for a future submitted for ``session_id``.

        The returned callable logs how the worker ended (exception,
        cancellation or success) and then releases the session's semaphore
        while holding ``self.lock``.

        Args:
            session_id: Key into ``self.sessions`` whose semaphore is released.

        Returns:
            A one-argument callable suitable for ``Future.add_done_callback``.
        """
        # ``Future.exception()`` raises concurrent.futures.CancelledError for
        # a cancelled future. Since Python 3.8 that is a different class from
        # asyncio.CancelledError, so import the right one locally.
        from concurrent.futures import CancelledError as FutureCancelledError

        def func(worker: Future):
            try:
                worker_exception = worker.exception()
                if worker_exception:
                    # There is no active exception here, so pass the
                    # exception explicitly to get its traceback logged.
                    logger.error("Worker return exception: {}".format(worker_exception), exc_info=worker_exception)
            except FutureCancelledError:
                logger.info("Worker cancelled, session_id = {}".format(session_id))
            except Exception as e:
                logger.exception("Worker raise exception: {}".format(e))
            finally:
                # Always give the permit back, or the session would never be
                # scheduled again.
                with self.lock:
                    self.sessions[session_id][1].release()
        return func

    def produce(self, context: Context):
        """Queue ``context`` on its session, creating the session if needed.

        Text contexts whose content starts with "#" go to the front of the
        queue so that admin commands are handled first; everything else is
        appended at the back.
        """
        sid = context['session_id']
        with self.lock:
            if sid not in self.sessions:
                pending = Dequeue()
                permits = conf().get("concurrency_in_session", 1)
                self.sessions[sid] = [pending, threading.BoundedSemaphore(permits)]
            is_command = context.type == ContextType.TEXT and context.content.startswith("#")
            session_queue = self.sessions[sid][0]
            if is_command:
                session_queue.putleft(context)  # admin commands jump the queue
            else:
                session_queue.put(context)

    # 消费者函数,单独线程,用于从消息队列中取出消息并处理
    def consume(self):
        """Consumer loop: move queued contexts into the handler thread pool.

        Runs forever, polling every 0.1 seconds. For each session whose
        semaphore can be acquired without blocking:

        * if a context is queued, it is submitted to ``handler_pool`` (the
          permit is released later by the future's done-callback);
        * otherwise, if no other permit is in use, the idle session is
          removed;
        * otherwise the permit is returned.

        NOTE(review): the idle check reads the semaphore's private
        ``_initial_value`` / ``_value`` attributes, which are implementation
        details of ``threading.BoundedSemaphore``.
        """
        while True:
            submitted = []
            with self.lock:
                session_ids = list(self.sessions.keys())
                for session_id in session_ids:
                    context_queue, semaphore = self.sessions[session_id]
                    if semaphore.acquire(blocking=False):  # a session may only be deleted once its workers are done
                        if not context_queue.empty():
                            context = context_queue.get()
                            logger.debug("[WX] consume context: {}".format(context))
                            future: Future = self.handler_pool.submit(self._handle, context)
                            self.futures.setdefault(session_id, []).append(future)
                            submitted.append((session_id, future))
                        elif semaphore._initial_value == semaphore._value + 1:
                            # Only this acquire holds a permit, so every task
                            # of the session has finished.
                            self.futures[session_id] = [t for t in self.futures.get(session_id, []) if not t.done()]
                            assert len(self.futures[session_id]) == 0, "thread pool error"
                            del self.sessions[session_id]
                        else:
                            semaphore.release()
            # Attach callbacks only after releasing self.lock. If a future is
            # already done, add_done_callback runs the callback right here,
            # and the callback takes self.lock itself; doing that inside the
            # ``with`` block would deadlock on the non-reentrant lock.
            for session_id, future in submitted:
                future.add_done_callback(self.thread_pool_callback(session_id))
            time.sleep(0.1)

    # 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
    def cancel_session(self, session_id):
        """Cancel the pending work of one session.

        Drops every queued context of the session and cancels its futures.
        Only futures that have not started running can be cancelled; running
        ones are unaffected. Unknown session ids are ignored.

        Args:
            session_id: Key of the session in ``self.sessions``.
        """
        with self.lock:
            if session_id not in self.sessions:
                return
            # The session may have been produced but not consumed yet, in
            # which case it has no futures entry; create an empty one.
            to_cancel = list(self.futures.setdefault(session_id, []))
            cnt = self.sessions[session_id][0].qsize()
            if cnt > 0:
                logger.info("Cancel {} messages in session {}".format(cnt, session_id))
            self.sessions[session_id][0] = Dequeue()
        # Cancel outside the lock: cancelling a pending future runs its
        # done-callback in this thread, and that callback takes self.lock.
        for future in to_cancel:
            future.cancel()
    
    def cancel_all_session(self):
        """Cancel the pending work of every session.

        Drops every queued context of each session and cancels its futures.
        Only futures that have not started running can be cancelled; running
        ones are unaffected.
        """
        to_cancel = []
        with self.lock:
            for session_id in self.sessions:
                # A session that was produced but not consumed yet has no
                # futures entry; create an empty one.
                to_cancel.extend(self.futures.setdefault(session_id, []))
                cnt = self.sessions[session_id][0].qsize()
                if cnt > 0:
                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
                self.sessions[session_id][0] = Dequeue()
        # Cancel outside the lock: cancelling a pending future runs its
        # done-callback in this thread, and that callback takes self.lock.
        for future in to_cancel:
            future.cancel()
    

def check_prefix(content, prefix_list):
    """Return the first entry of ``prefix_list`` that ``content`` starts with.

    Args:
        content: String to inspect.
        prefix_list: Candidate prefixes, tried in order. ``None`` or an empty
            list means nothing can match.

    Returns:
        The matching prefix, or ``None`` when there is no match. An empty
        string in ``prefix_list`` matches any content and is returned as
        ``""``, so callers must compare the result against ``None``.
    """
    # An unset config key yields None; treat it like an empty list.
    if not prefix_list:
        return None
    for prefix in prefix_list:
        if content.startswith(prefix):
            return prefix
    return None

def check_contain(content, keyword_list):
    """Tell whether ``content`` contains any entry of ``keyword_list``.

    Returns:
        ``True`` when at least one keyword occurs in ``content``; ``None``
        (not ``False``) when none does or when ``keyword_list`` is empty or
        ``None``.
    """
    if keyword_list and any(content.find(word) != -1 for word in keyword_list):
        return True
    return None