deepugaur committed on
Commit
6ff3417
·
verified ·
1 Parent(s): 489c898

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -48
app.py CHANGED
@@ -1,53 +1,59 @@
1
- import librosa
2
- import numpy as np
3
-
4
- def extract_features(audio_path):
5
- y, sr = librosa.load(audio_path, sr=16000)
6
- mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
7
- return np.mean(mfccs.T, axis=0)
8
-
9
- # Example usage
10
- features = extract_features("path/to/audio/file.wav")
11
-
12
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer, MarianMTModel, MarianTokenizer
13
-
14
- # Load pre-trained models
15
- speech_recognition_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
16
- speech_recognition_tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h")
17
- translation_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
18
- translation_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
19
-
20
- from transformers import pipeline
21
-
22
- # Example inference pipeline
23
- def translate_audio(audio_path):
24
- # Speech Recognition
25
- speech_input = speech_recognition_tokenizer(extract_features(audio_path), return_tensors="pt").input_values
26
- logits = speech_recognition_model(speech_input).logits
27
- transcription = speech_recognition_tokenizer.batch_decode(torch.argmax(logits, dim=-1))[0]
28
-
29
- # Translation
30
- translated = translation_model.generate(**translation_tokenizer.prepare_seq2seq_batch(transcription, return_tensors="pt"))
31
- translation = translation_tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
32
-
33
- return translation
34
-
35
- # Save the models and tokenizer
36
- speech_recognition_model.save_pretrained("path/to/save/wav2vec2")
37
- speech_recognition_tokenizer.save_pretrained("path/to/save/wav2vec2")
38
- translation_model.save_pretrained("path/to/save/opus-mt-en-hi")
39
- translation_tokenizer.save_pretrained("path/to/save/opus-mt-en-hi")
40
-
41
- from datetime import datetime
42
  import pytz
43
-
44
- def is_after_6_pm_ist():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  ist = pytz.timezone('Asia/Kolkata')
46
  current_time = datetime.now(ist)
47
  return current_time.hour >= 18
48
 
49
- if is_after_6_pm_ist():
50
- translation = translate_audio("path/to/audio/file.wav")
51
- print(translation)
52
- else:
53
- print("The translation service is available after 6 PM IST.")
 
 
 
 
 
 
 
 
 
 
 
 
1
import io
from datetime import datetime

import numpy as np
import pytz
import torch
from flask import Flask, request, jsonify
from pydub import AudioSegment
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, MarianMTModel, MarianTokenizer
8
+
9
+ app = Flask(__name__)
10
+
11
+ # Load pre-trained models and tokenizers
12
+ asr_model_name = "facebook/wav2vec2-large-960h"
13
+ translation_model_name = "Helsinki-NLP/opus-mt-en-hi"
14
+
15
+ asr_processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
16
+ asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
17
+ translator = MarianMTModel.from_pretrained(translation_model_name)
18
+ tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
19
+
20
+ # Function to convert audio file to text
21
+ def audio_to_text(audio_file):
22
+ audio_input = AudioSegment.from_file(audio_file)
23
+ audio_array = np.array(audio_input.get_array_of_samples())
24
+ inputs = asr_processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
25
+ with torch.no_grad():
26
+ logits = asr_model(inputs.input_values).logits
27
+ predicted_ids = torch.argmax(logits, dim=-1)
28
+ transcription = asr_processor.batch_decode(predicted_ids)[0]
29
+ return transcription
30
+
31
+ # Function to translate text from English to Hindi
32
+ def translate_text(text):
33
+ inputs = tokenizer(text, return_tensors="pt", padding=True)
34
+ translated = translator.generate(**inputs)
35
+ translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
36
+ return translated_text
37
+
38
+ # Function to check if the current time is after 6 PM IST
39
+ def is_after_6pm_ist():
40
  ist = pytz.timezone('Asia/Kolkata')
41
  current_time = datetime.now(ist)
42
  return current_time.hour >= 18
43
 
44
+ @app.route('/translate', methods=['POST'])
45
+ def translate_audio():
46
+ if not is_after_6pm_ist():
47
+ return jsonify({'error': 'Service available only after 6 PM IST'}), 403
48
+
49
+ if 'audio' not in request.files:
50
+ return jsonify({'error': 'No audio file provided'}), 400
51
+
52
+ audio_file = request.files['audio']
53
+ text = audio_to_text(audio_file)
54
+ translated_text = translate_text(text)
55
+ return jsonify({'translation': translated_text})
56
+
57
+ if __name__ == '__main__':
58
+ app.run(host='0.0.0.0', port=5000)
59
+