maolin.liu
commited on
Commit
·
7ec99bc
1
Parent(s):
ef7f04e
[bugfix]Do not serialize string.
Browse files- consumer/asr.py +4 -4
- consumer/base.py +6 -4
consumer/asr.py
CHANGED
@@ -3,7 +3,7 @@ import io
|
|
3 |
import logging
|
4 |
import os
|
5 |
from pathlib import Path
|
6 |
-
from typing import Literal
|
7 |
|
8 |
from faster_whisper import WhisperModel
|
9 |
from pydantic import BaseModel, Field, ValidationError, model_validator
|
@@ -58,7 +58,7 @@ class TranscribeConsumer(BasicMessageReceiver):
|
|
58 |
def setup_message_sender(self):
|
59 |
self.sender = BasicMessageSender()
|
60 |
|
61 |
-
def send_message(self, message: dict):
|
62 |
routing_key = 'transcribe-output'
|
63 |
# headers = Headers(job_id=f'{uuid.uuid4()}', priority=Priority.NORMAL)
|
64 |
self.sender.send_message(
|
@@ -71,11 +71,11 @@ class TranscribeConsumer(BasicMessageReceiver):
|
|
71 |
def send_success_message(self, uuid: str, transcribed_text):
|
72 |
message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
|
73 |
transcribed_text=transcribed_text)
|
74 |
-
self.send_message(message.
|
75 |
|
76 |
def send_fail_message(self, uuid: str, error: str):
|
77 |
message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
|
78 |
-
self.send_message(message.
|
79 |
|
80 |
def consume(self, channel, method, properties, message):
|
81 |
body = self.decode_message(message)
|
|
|
3 |
import logging
|
4 |
import os
|
5 |
from pathlib import Path
|
6 |
+
from typing import Literal, Union
|
7 |
|
8 |
from faster_whisper import WhisperModel
|
9 |
from pydantic import BaseModel, Field, ValidationError, model_validator
|
|
|
58 |
def setup_message_sender(self):
|
59 |
self.sender = BasicMessageSender()
|
60 |
|
61 |
+
def send_message(self, message: Union[dict, str]):
|
62 |
routing_key = 'transcribe-output'
|
63 |
# headers = Headers(job_id=f'{uuid.uuid4()}', priority=Priority.NORMAL)
|
64 |
self.sender.send_message(
|
|
|
71 |
def send_success_message(self, uuid: str, transcribed_text):
|
72 |
message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
|
73 |
transcribed_text=transcribed_text)
|
74 |
+
self.send_message(message.model_dump_json())
|
75 |
|
76 |
def send_fail_message(self, uuid: str, error: str):
|
77 |
message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
|
78 |
+
self.send_message(message.model_dump_json())
|
79 |
|
80 |
def consume(self, channel, method, properties, message):
|
81 |
body = self.decode_message(message)
|
consumer/base.py
CHANGED
@@ -6,7 +6,7 @@ import os
|
|
6 |
import ssl
|
7 |
import time
|
8 |
from enum import Enum
|
9 |
-
from typing import Dict, Optional, Literal
|
10 |
|
11 |
import msgpack
|
12 |
import pika
|
@@ -140,11 +140,13 @@ class BasicPikaClient:
|
|
140 |
class BasicMessageSender(BasicPikaClient):
|
141 |
message_encoding_type: Literal['bytes', 'json'] = 'json'
|
142 |
|
143 |
-
def encode_message(self, body: Dict, encoding_type: str = "bytes"):
|
144 |
if encoding_type == "bytes":
|
145 |
return msgpack.packb(body)
|
146 |
elif encoding_type == "json":
|
147 |
-
|
|
|
|
|
148 |
else:
|
149 |
raise NotImplementedError
|
150 |
|
@@ -152,7 +154,7 @@ class BasicMessageSender(BasicPikaClient):
|
|
152 |
self,
|
153 |
exchange_name: str,
|
154 |
routing_key: str,
|
155 |
-
body: Dict,
|
156 |
headers: Optional[Headers],
|
157 |
):
|
158 |
body = self.encode_message(body=body, encoding_type=self.message_encoding_type)
|
|
|
6 |
import ssl
|
7 |
import time
|
8 |
from enum import Enum
|
9 |
+
from typing import Dict, Optional, Literal, Union
|
10 |
|
11 |
import msgpack
|
12 |
import pika
|
|
|
140 |
class BasicMessageSender(BasicPikaClient):
|
141 |
message_encoding_type: Literal['bytes', 'json'] = 'json'
|
142 |
|
143 |
+
def encode_message(self, body: Union[Dict, str], encoding_type: str = "bytes"):
|
144 |
if encoding_type == "bytes":
|
145 |
return msgpack.packb(body)
|
146 |
elif encoding_type == "json":
|
147 |
+
if isinstance(body, dict):
|
148 |
+
return json.dumps(body)
|
149 |
+
return body
|
150 |
else:
|
151 |
raise NotImplementedError
|
152 |
|
|
|
154 |
self,
|
155 |
exchange_name: str,
|
156 |
routing_key: str,
|
157 |
+
body: Union[Dict, str],
|
158 |
headers: Optional[Headers],
|
159 |
):
|
160 |
body = self.encode_message(body=body, encoding_type=self.message_encoding_type)
|