WajeehAzeemX commited on
Commit
327ec66
·
1 Parent(s): f177433

whisper ar2

Browse files
Files changed (2) hide show
  1. __pycache__/app.cpython-310.pyc +0 -0
  2. app.py +6 -8
__pycache__/app.cpython-310.pyc ADDED
Binary file (2 kB). View file
 
app.py CHANGED
@@ -5,23 +5,22 @@ import torchaudio
5
  from transformers import AutoProcessor, pipeline
6
  import io
7
  from pydub import AudioSegment
8
- from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
9
  import numpy as np
10
  import uvicorn
11
  import time
12
  app = FastAPI()
 
13
 
14
  # Device configuration
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  print(device)
17
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
18
-
19
  # Load the model and processor
20
- model_id = "WajeehAzeemX/whisper-small-ar2_onnx"
21
- model = ORTModelForSpeechSeq2Seq.from_pretrained(
22
- model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
23
  )
24
- processor = AutoProcessor.from_pretrained(model_id)
25
 
26
 
27
  pipe = pipeline(
@@ -29,7 +28,6 @@ pipe = pipeline(
29
  model=model,
30
  tokenizer=processor.tokenizer,
31
  feature_extractor=processor.feature_extractor,
32
- torch_dtype=torch_dtype,
33
  )
34
 
35
  @app.post("/transcribe/")
 
5
  from transformers import AutoProcessor, pipeline
6
  import io
7
  from pydub import AudioSegment
8
+ # from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
9
  import numpy as np
10
  import uvicorn
11
  import time
12
  app = FastAPI()
13
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
14
 
15
  # Device configuration
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  print(device)
 
 
18
  # Load the model and processor
19
+ model_id = "whitefox123/whisper-small-ar2"
20
+ model = WhisperForConditionalGeneration.from_pretrained(
21
+ model_id
22
  )
23
+ processor = WhisperProcessor.from_pretrained(model_id)
24
 
25
 
26
  pipe = pipeline(
 
28
  model=model,
29
  tokenizer=processor.tokenizer,
30
  feature_extractor=processor.feature_extractor,
 
31
  )
32
 
33
  @app.post("/transcribe/")