AshDavid12 committed
Commit bdd9100 · Parent: c7d5cac

original infer-ivrit

Files changed (3)
  1. Dockerfile +4 -2
  2. infer.py +81 -64
  3. requirements.txt +3 -0
Dockerfile CHANGED
@@ -26,6 +26,8 @@ RUN pip install --no-cache-dir -r requirements.txt
 
 # Copy the current directory contents into the container at /app
 COPY . .
-
-# Command to run the Python transcription script directly
-CMD ["python3","-u", "/infer.py"]
+# Expose port 8080 for FastAPI
+EXPOSE 8080
+
+# Run FastAPI with Uvicorn
+CMD ["uvicorn", "infer:app", "--host", "0.0.0.0", "--port", "8080"]
infer.py CHANGED
@@ -1,94 +1,111 @@
+import base64
 import faster_whisper
-import requests
 import tempfile
-import os
+import torch
+import requests
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from typing import Optional
 
-# Load the faster-whisper model that supports Hebrew
-model = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d4")
-
-# URL of the audio file (replace this with the actual URL of your audio)
-audio_url = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"
-
-# Download the audio file from the URL
-response = requests.get(audio_url)
-if response.status_code != 200:
-    raise Exception("Failed to download audio file")
-
-# Create a temporary file to store the audio
-with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio_file:
-    tmp_audio_file.write(response.content)
-    tmp_audio_file_path = tmp_audio_file.name
-
-# Perform the transcription
-segments, info = model.transcribe(tmp_audio_file_path, language="he")
-
-# Print transcription results
-for segment in segments:
-    print(f"[{segment.start:.2f}s - {segment.end:.2f}s] {segment.text}")
-
-# Clean up the temporary file
-os.remove(tmp_audio_file_path)
-
-# import torch
-# from transformers import WhisperProcessor, WhisperForConditionalGeneration
-# import requests
-# import soundfile as sf
-# import io
-
-# # Load the Whisper model and processor from Hugging Face Model Hub
-# model_name = "openai/whisper-base"
-# processor = WhisperProcessor.from_pretrained(model_name)
-# model = WhisperForConditionalGeneration.from_pretrained(model_name)
-#
-# # Use GPU if available, otherwise use CPU
-# device = "cuda" if torch.cuda.is_available() else "cpu"
-# model.to(device)
-#
-# # URL of the audio file
-# audio_url = "https://www.signalogic.com/melp/EngSamples/Orig/male.wav"
-#
-# # Download the audio file
-# response = requests.get(audio_url)
-# audio_data = io.BytesIO(response.content)
-#
-# # Read the audio using soundfile
-# audio_input, _ = sf.read(audio_data)
-#
-# # Preprocess the audio for Whisper
-# inputs = processor(audio_input, return_tensors="pt", sampling_rate=16000)
-# attention_mask = inputs['input_features'].ne(processor.tokenizer.pad_token_id).long()
-#
-# # Move inputs and attention mask to the correct device
-# inputs = {key: value.to(device) for key, value in inputs.items()}
-# attention_mask = attention_mask.to(device)
-#
-# # Generate the transcription with attention mask
-# with torch.no_grad():
-#     predicted_ids = model.generate(
-#         inputs["input_features"],
-#         attention_mask=attention_mask  # Pass attention mask explicitly
-#     )
-# # Decode the transcription
-# transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-#
-# # Print the transcription result
-# print("Transcription:", transcription)
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+model_name = 'ivrit-ai/faster-whisper-v2-d4'
+model = faster_whisper.WhisperModel(model_name, device=device)
+
+# Maximum data size: 200MB
+MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
+
+app = FastAPI()
+
+
+class InputData(BaseModel):
+    type: str
+    data: Optional[str] = None  # Used for blob input
+    url: Optional[str] = None  # Used for url input
+    api_key: Optional[str] = None
+
+
+def download_file(url, max_size_bytes, output_filename, api_key=None):
+    """
+    Download a file from a given URL with size limit and optional API key.
+    """
+    try:
+        headers = {}
+        if api_key:
+            headers['Authorization'] = f'Bearer {api_key}'
+
+        response = requests.get(url, stream=True, headers=headers)
+        response.raise_for_status()
+
+        file_size = int(response.headers.get('Content-Length', 0))
+
+        if file_size > max_size_bytes:
+            return False
+
+        downloaded_size = 0
+        with open(output_filename, 'wb') as file:
+            for chunk in response.iter_content(chunk_size=8192):
+                downloaded_size += len(chunk)
+                if downloaded_size > max_size_bytes:
+                    return False
+                file.write(chunk)
+
+        return True
+
+    except requests.RequestException as e:
+        print(f"Error downloading file: {e}")
+        return False
+
+
+@app.post("/transcribe")
+async def transcribe(input_data: InputData):
+    datatype = input_data.type
+    if not datatype:
+        raise HTTPException(status_code=400, detail="datatype field not provided. Should be 'blob' or 'url'.")
+
+    if datatype not in ['blob', 'url']:
+        raise HTTPException(status_code=400, detail=f"datatype should be 'blob' or 'url', but is {datatype} instead.")
+
+    api_key = input_data.api_key
+
+    with tempfile.TemporaryDirectory() as d:
+        audio_file = f'{d}/audio.mp3'
+
+        if datatype == 'blob':
+            if not input_data.data:
+                raise HTTPException(status_code=400, detail="Missing 'data' for 'blob' input.")
+            mp3_bytes = base64.b64decode(input_data.data)
+            open(audio_file, 'wb').write(mp3_bytes)
+        elif datatype == 'url':
+            if not input_data.url:
+                raise HTTPException(status_code=400, detail="Missing 'url' for 'url' input.")
+            success = download_file(input_data.url, MAX_PAYLOAD_SIZE, audio_file, api_key)
+            if not success:
+                raise HTTPException(status_code=400, detail=f"Error downloading data from {input_data.url}")
+
+        result = transcribe_core(audio_file)
+        return {"result": result}
+
+
+def transcribe_core(audio_file):
+    print('Transcribing...')
+
+    ret = {'segments': []}
+    segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
+    for s in segs:
+        words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
+        seg = {
+            'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
+            'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words
+        }
+        print(seg)
+        ret['segments'].append(seg)
+
+    return ret
+
+
+# Make sure Uvicorn starts correctly when deployed
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8080)
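The other input path accepts the audio inline: for type 'blob', the client base64-encodes the file and sends it in the 'data' field. A minimal client sketch, assuming a local file named sample.mp3 and a server at localhost:8080 (both placeholders, not part of the commit); the request and response shapes follow InputData and transcribe_core above:

# Hypothetical client for the 'blob' input type (a sketch; 'sample.mp3'
# and the localhost address are illustrative placeholders).
import base64
import requests

with open("sample.mp3", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://localhost:8080/transcribe",
    json={"type": "blob", "data": encoded},
)
resp.raise_for_status()
for seg in resp.json()["result"]["segments"]:
    print(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['text']}")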
requirements.txt CHANGED
@@ -4,4 +4,7 @@ requests
 transformers
 soundfile
 faster-whisper
+torch
+uvicorn
+fastapi
 