maolin.liu commited on
Commit
7ec99bc
·
1 Parent(s): ef7f04e

[bugfix]Do not serialize string.

Browse files
Files changed (2) hide show
  1. consumer/asr.py +4 -4
  2. consumer/base.py +6 -4
consumer/asr.py CHANGED
@@ -3,7 +3,7 @@ import io
3
  import logging
4
  import os
5
  from pathlib import Path
6
- from typing import Literal
7
 
8
  from faster_whisper import WhisperModel
9
  from pydantic import BaseModel, Field, ValidationError, model_validator
@@ -58,7 +58,7 @@ class TranscribeConsumer(BasicMessageReceiver):
58
  def setup_message_sender(self):
59
  self.sender = BasicMessageSender()
60
 
61
- def send_message(self, message: dict):
62
  routing_key = 'transcribe-output'
63
  # headers = Headers(job_id=f'{uuid.uuid4()}', priority=Priority.NORMAL)
64
  self.sender.send_message(
@@ -71,11 +71,11 @@ class TranscribeConsumer(BasicMessageReceiver):
71
  def send_success_message(self, uuid: str, transcribed_text):
72
  message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
73
  transcribed_text=transcribed_text)
74
- self.send_message(message.model_dump())
75
 
76
  def send_fail_message(self, uuid: str, error: str):
77
  message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
78
- self.send_message(message.model_dump())
79
 
80
  def consume(self, channel, method, properties, message):
81
  body = self.decode_message(message)
 
3
  import logging
4
  import os
5
  from pathlib import Path
6
+ from typing import Literal, Union
7
 
8
  from faster_whisper import WhisperModel
9
  from pydantic import BaseModel, Field, ValidationError, model_validator
 
58
  def setup_message_sender(self):
59
  self.sender = BasicMessageSender()
60
 
61
+ def send_message(self, message: Union[dict, str]):
62
  routing_key = 'transcribe-output'
63
  # headers = Headers(job_id=f'{uuid.uuid4()}', priority=Priority.NORMAL)
64
  self.sender.send_message(
 
71
  def send_success_message(self, uuid: str, transcribed_text):
72
  message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
73
  transcribed_text=transcribed_text)
74
+ self.send_message(message.model_dump_json())
75
 
76
  def send_fail_message(self, uuid: str, error: str):
77
  message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
78
+ self.send_message(message.model_dump_json())
79
 
80
  def consume(self, channel, method, properties, message):
81
  body = self.decode_message(message)
consumer/base.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import ssl
7
  import time
8
  from enum import Enum
9
- from typing import Dict, Optional, Literal
10
 
11
  import msgpack
12
  import pika
@@ -140,11 +140,13 @@ class BasicPikaClient:
140
  class BasicMessageSender(BasicPikaClient):
141
  message_encoding_type: Literal['bytes', 'json'] = 'json'
142
 
143
- def encode_message(self, body: Dict, encoding_type: str = "bytes"):
144
  if encoding_type == "bytes":
145
  return msgpack.packb(body)
146
  elif encoding_type == "json":
147
- return json.dumps(body)
 
 
148
  else:
149
  raise NotImplementedError
150
 
@@ -152,7 +154,7 @@ class BasicMessageSender(BasicPikaClient):
152
  self,
153
  exchange_name: str,
154
  routing_key: str,
155
- body: Dict,
156
  headers: Optional[Headers],
157
  ):
158
  body = self.encode_message(body=body, encoding_type=self.message_encoding_type)
 
6
  import ssl
7
  import time
8
  from enum import Enum
9
+ from typing import Dict, Optional, Literal, Union
10
 
11
  import msgpack
12
  import pika
 
140
  class BasicMessageSender(BasicPikaClient):
141
  message_encoding_type: Literal['bytes', 'json'] = 'json'
142
 
143
+ def encode_message(self, body: Union[Dict, str], encoding_type: str = "bytes"):
144
  if encoding_type == "bytes":
145
  return msgpack.packb(body)
146
  elif encoding_type == "json":
147
+ if isinstance(body, dict):
148
+ return json.dumps(body)
149
+ return body
150
  else:
151
  raise NotImplementedError
152
 
 
154
  self,
155
  exchange_name: str,
156
  routing_key: str,
157
+ body: Union[Dict, str],
158
  headers: Optional[Headers],
159
  ):
160
  body = self.encode_message(body=body, encoding_type=self.message_encoding_type)