abdullahmubeen10 committed on
Commit
eb62ff9
·
verified ·
1 Parent(s): b344ee7

Update Demo.py

Browse files
Files changed (1) hide show
  1. Demo.py +8 -16
Demo.py CHANGED
@@ -65,21 +65,13 @@ def create_pipeline(model):
65
  def fit_data(pipeline, fed_data):
66
  """Fit the data into the pipeline and return the transcription."""
67
  data, sampling_rate = librosa.load(fed_data, sr=16000)
68
- data = [float(x) for x in data]
69
-
70
- schema = StructType([
71
- StructField("audio_content", ArrayType(FloatType())),
72
- StructField("sampling_rate", LongType())
73
- ])
74
-
75
- df = pd.DataFrame({
76
- "audio_content": [data],
77
- "sampling_rate": [sampling_rate]
78
- })
79
 
80
- spark_df = spark.createDataFrame(df, schema)
81
- pipeline_df = pipeline.fit(spark_df).transform(spark_df)
82
- return pipeline_df.select("text.result")
 
83
 
84
  def save_uploadedfile(uploadedfile, path):
85
  """Save the uploaded file to the specified path."""
@@ -119,7 +111,7 @@ st.sidebar.markdown("""
119
  """, unsafe_allow_html=True)
120
 
121
  # Load examples
122
- AUDIO_FILE_PATH = "inputs"
123
  audio_files = sorted(os.listdir(AUDIO_FILE_PATH))
124
 
125
  selected_audio = st.selectbox("Select an audio", audio_files)
@@ -146,4 +138,4 @@ pipeline = create_pipeline(model)
146
  output = fit_data(pipeline, selected_audio)
147
 
148
  st.subheader(f"Transcription:")
149
- st.markdown(f"**{output[0]}**")
 
65
  def fit_data(pipeline, fed_data):
66
  """Fit the data into the pipeline and return the transcription."""
67
  data, sampling_rate = librosa.load(fed_data, sr=16000)
68
+ data = data.tolist()
69
+ spark_df = spark.createDataFrame([[data]], ["audio_content"])
 
 
 
 
 
 
 
 
 
70
 
71
+ model = pipeline.fit(spark_df)
72
+ lp = LightPipeline(model)
73
+ lp_result = lp.fullAnnotate(data)[0]
74
+ return lp_result
75
 
76
  def save_uploadedfile(uploadedfile, path):
77
  """Save the uploaded file to the specified path."""
 
111
  """, unsafe_allow_html=True)
112
 
113
  # Load examples
114
+ AUDIO_FILE_PATH = "/content/Wav2Vec2ForCTC/inputs"
115
  audio_files = sorted(os.listdir(AUDIO_FILE_PATH))
116
 
117
  selected_audio = st.selectbox("Select an audio", audio_files)
 
138
  output = fit_data(pipeline, selected_audio)
139
 
140
  st.subheader(f"Transcription:")
141
+ st.markdown(f"{(output['text'][0].result).title()}")