venkatks515 committed on
Commit
82db251
·
1 Parent(s): 9cf557b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import time
4
+ import librosa
5
+ import soundfile
6
+ import nemo.collections.asr as nemo_asr
7
+ import tempfile
8
+ import os
9
+ import uuid
10
+
11
+ from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
12
+ import torch
13
+
14
+ # PersistDataset -----
15
+ import os
16
+ import csv
17
+ import gradio as gr
18
+ from gradio import inputs, outputs
19
+ import huggingface_hub
20
+ from huggingface_hub import Repository, hf_hub_download, upload_file
21
+ from datetime import datetime
22
+
23
+ # ---------------------------------------------
24
+ # Dataset and Token links - change awacke1 to your own HF id, and add a HF_TOKEN copy to your repo for write permissions
25
+ # This should allow you to save your results to your own Dataset hosted on HF.
26
+
27
# Hugging Face dataset that receives the transcription results.
DATASET_REPO_URL = "https://huggingface.co/datasets/awacke1/ASRLive.csv"
DATASET_REPO_ID = "awacke1/ASRLive.csv"
DATA_FILENAME = "ASRLive.csv"
# Local directory used both as the download cache and as the clone target.
# It was referenced below without ever being defined, which raised a
# NameError that the old bare `except:` silently swallowed.
DATA_DIRNAME = "data"
DATA_FILE = os.path.join(DATA_DIRNAME, DATA_FILENAME)
# Token read from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

PersistToDataset = False
#PersistToDataset = True # uncomment to save inference output to ASRLive.csv dataset

if PersistToDataset:
    try:
        # NOTE(review): no repo_type is passed, so this presumably looks the
        # file up in a model repo rather than the dataset - confirm before
        # relying on this download.
        hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=DATA_FILENAME,
            cache_dir=DATA_DIRNAME,
            force_filename=DATA_FILENAME,
        )
    except Exception as err:
        # Best effort: a failed download must not stop the app, but the
        # reason is reported instead of being hidden.
        print(f"file not found: {err}")
    # Clone of the dataset repo; store_message() appends to DATA_FILE inside
    # it and pushes through this object.
    repo = Repository(
        local_dir=DATA_DIRNAME, clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
    )
49
+
50
def store_message(name: str, message: str):
    """Append a record to the CSV data file and return the file's rows as text.

    When both ``name`` and ``message`` are non-empty, a row holding the
    stripped values and the current timestamp is appended to ``DATA_FILE``
    and the change is pushed through the module-level ``repo``. That object
    is only created when ``PersistToDataset`` is enabled.

    The file is then read back with ``csv.DictReader``, which treats its
    first line as the header.

    Args:
        name: Value for the ``name`` column.
        message: Value for the ``message`` column.

    Returns:
        One line per data row, with the row's values joined by commas and
        each line terminated by ``"\\r\\n"``. Empty string when the file has
        no data rows.
    """
    if name and message:
        # newline="" lets the csv module control line endings itself.
        with open(DATA_FILE, "a", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["name", "message", "time"])
            writer.writerow(
                {"name": name.strip(), "message": message.strip(), "time": str(datetime.now())}
            )
        # Publish the appended row to the dataset repository.
        repo.push_to_hub()

    with open(DATA_FILE, "r", newline="") as csvfile:
        reader = csv.DictReader(csvfile)
        # Each row is a dict. The original added it straight onto a str,
        # which raised TypeError, so render the values as text instead.
        lines = [",".join(str(value) for value in row.values()) for row in reader]
    return "".join(line + "\r\n" for line in lines)
67
+
68
+ # main -------------------------
69
# Checkpoint name shared by the Blenderbot tokenizer and model weights.
mname = "facebook/blenderbot-400M-distill"
tokenizer = BlenderbotTokenizer.from_pretrained(mname)
# NOTE(review): the name `model` is rebound to the NeMo ASR model further
# down in this file, so the Blenderbot weights loaded here are not reachable
# through `model` afterwards.
model = BlenderbotForConditionalGeneration.from_pretrained(mname)
72
+
73
def take_last_tokens(inputs, note_history, history, max_tokens=128):
    """Trim the inputs and the histories once the token window is exceeded.

    Args:
        inputs: Mapping with ``'input_ids'`` and ``'attention_mask'``
            tensors. The sequence length is read from ``shape[1]`` and only
            row 0 is kept when trimming. The mapping is updated in place.
        note_history: Single-element list whose string joins the notes with
            the ``'</s> <s>'`` separator.
        history: Sequence of past entries.
        max_tokens: Positive number of trailing tokens to keep. Defaults to
            128, the value that used to be hard-coded.

    Returns:
        ``(inputs, note_history, history)``. When the input is longer than
        ``max_tokens``, both tensors are cut to their last ``max_tokens``
        entries, the two oldest notes are dropped from ``note_history`` and
        the oldest entry is dropped from ``history``. Otherwise all three
        are returned unchanged.
    """
    if inputs['input_ids'].shape[1] <= max_tokens:
        return inputs, note_history, history

    # Rebuild each tensor from the tail of row 0, as a one-row batch.
    for key in ('input_ids', 'attention_mask'):
        inputs[key] = torch.tensor([inputs[key][0][-max_tokens:].tolist()])

    separator = '</s> <s>'
    note_history = [separator.join(note_history[0].split(separator)[2:])]
    history = history[1:]
    return inputs, note_history, history
81
+
82
def add_note_to_history(note, note_history):
    """Add ``note`` to ``note_history`` and return the joined history.

    The note is appended to the caller's list in place. The return value is
    a new single-element list whose string joins every entry with the
    ``'</s> <s>'`` separator.
    """
    separator = '</s> <s>'
    note_history.append(note)
    joined = separator.join(note_history)
    return [joined]
86
+
87
+
88
+
89
# Sample rate (Hz) used when audio is prepared and written for the ASR model.
# NOTE(review): presumably the rate the checkpoint below expects - confirm.
SAMPLE_RATE = 16000

# Speech-recognition model. This rebinds the module-level name `model`.
model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
    "nvidia/stt_en_conformer_transducer_xlarge"
)
# NOTE(review): passing None presumably selects the default decoding
# configuration - confirm against the NeMo documentation.
model.change_decoding_strategy(None)
model.eval()
93
+
94
def process_audio_file(file):
    """Load an audio file as mono samples at ``SAMPLE_RATE``.

    The previous version called ``librosa.load`` without ``sr``. That call
    resamples to the library default rate first, so the audio was always
    resampled a second time to reach ``SAMPLE_RATE``. Requesting the target
    rate directly does a single resampling pass.

    Args:
        file: Path (or file-like object) accepted by ``librosa.load``.

    Returns:
        The mono audio samples at ``SAMPLE_RATE``.
    """
    data, _ = librosa.load(file, sr=SAMPLE_RATE, mono=True)
    return data
100
+
101
+
102
def transcribe(audio, state = ""):
    """Transcribe one audio chunk and append the text to the running state.

    Args:
        audio: Path of the recorded audio file. When it is missing (``None``
            or empty) nothing is transcribed and the state is returned as is.
        state: Text accumulated by earlier calls. ``None`` is treated as an
            empty string.

    Returns:
        ``(state, state)``: the updated text twice, once for the textbox
        output and once for the state output of the interface.
    """
    if state is None:
        state = ""
    # No recording to process; the audio loader would fail on a missing
    # path, so return the unchanged state instead.
    if not audio:
        return state, state

    audio_data = process_audio_file(audio)
    with tempfile.TemporaryDirectory() as tmpdir:
        # The model is given a file path, so write the prepared samples to a
        # uniquely named temporary wav file.
        audio_path = os.path.join(tmpdir, f'audio_{uuid.uuid4()}.wav')
        soundfile.write(audio_path, audio_data, SAMPLE_RATE)
        transcriptions = model.transcribe([audio_path])
        # The result may be a 2-tuple; its first element holds the
        # transcription list in that case.
        if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
            transcriptions = transcriptions[0]
        # Exactly one file was submitted, so take the single result.
        transcriptions = transcriptions[0]

    if PersistToDataset:
        # Save to the dataset (needs HF_TOKEN) and show the stored rows.
        ret = store_message(transcriptions, state)
        state = state + transcriptions + " " + ret
    else:
        state = state + transcriptions
    return state, state
120
+
121
# Build the live demo and start it. `transcribe` takes the streamed
# microphone recording (as a file path) plus the session state, and returns
# the text for the textbox together with the new state.
gr.Interface(
    fn=transcribe,
    inputs=[gr.Audio(source="microphone", type='filepath', streaming=True), "state"],
    outputs=["textbox", "state"],
    layout="horizontal",
    theme="huggingface",
    title="🗣️ASR-Live🧠Memory💾",
    description="Live Automatic Speech Recognition (ASR) with Memory💾 Dataset.",
    allow_flagging='never',
    live=True,
    article=f"Result Output Saved to Memory💾 Dataset: [{DATASET_REPO_URL}]({DATASET_REPO_URL})",
).launch(debug=True)