Pijush2023 commited on
Commit
9820a59
·
verified ·
1 Parent(s): 4de005f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -29
app.py CHANGED
@@ -672,9 +672,11 @@ def generate_audio(text, description="Thomas speaks with emphasis and excitement
672
  prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
673
 
674
  set_seed(SEED)
675
- input_features = model.get_input_features(prompt.input_ids) # Ensure we have input_features
 
676
 
677
- generation = model.generate(input_features=input_features, input_ids=inputs.input_ids)
 
678
  audio_arr = generation.cpu().numpy().squeeze()
679
 
680
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
@@ -684,12 +686,6 @@ def generate_audio(text, description="Thomas speaks with emphasis and excitement
684
  logging.debug(f"Audio saved to {temp_audio_path}")
685
  return temp_audio_path
686
 
687
- def install_parler_tts():
688
- subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/huggingface/parler-tts.git"])
689
-
690
- # Call the function to install parler-tts
691
- install_parler_tts()
692
-
693
  # Check if the token is already set in the environment variables
694
  hf_token = os.getenv("HF_TOKEN")
695
 
@@ -739,7 +735,7 @@ def fetch_local_events():
739
  url = f'https://serpapi.com/search.json?engine=google_events&q=Events+in+Birmingham&hl=en&gl=us&api_key={api_key}'
740
 
741
  response = requests.get(url)
742
- if response.status_code == 200:
743
  events_results = response.json().get("events_results", [])
744
  events_html = """
745
  <h2 style="font-family: 'Georgia', serif; color: #ff0000; background-color: #f8f8f8; padding: 10px; border-radius: 10px;">Local Events</h2>
@@ -760,7 +756,7 @@ def fetch_local_events():
760
  }
761
  </style>
762
  """
763
- for index, event in enumerate(events_results):
764
  title = event.get("title", "No title")
765
  date = event.get("date", "No date")
766
  location = event.get("address", "No location")
@@ -772,11 +768,13 @@ def fetch_local_events():
772
  </div>
773
  """
774
  return events_html
775
- else:
 
776
  return "<p>Failed to fetch local events</p>"
 
777
 
778
  def fetch_local_weather():
779
- try:
780
  api_key = os.environ['WEATHER_API']
781
  url = f'https://weather.visualcrossing.com/VisualCrossingWebServices/rest/services/timeline/birmingham?unitGroup=metric&include=events%2Calerts%2Chours%2Cdays%2Ccurrent&key={api_key}'
782
  response = requests.get(url)
@@ -786,10 +784,12 @@ def fetch_local_weather():
786
  current_conditions = jsonData.get("currentConditions", {})
787
  temp_celsius = current_conditions.get("temp", "N/A")
788
 
789
- if temp_celsius != "N/A":
790
  temp_fahrenheit = int((temp_celsius * 9/5) + 32)
791
- else:
 
792
  temp_fahrenheit = "N/A"
 
793
 
794
  condition = current_conditions.get("conditions", "N/A")
795
  humidity = current_conditions.get("humidity", "N/A")
@@ -840,8 +840,10 @@ def fetch_local_weather():
840
  </style>
841
  """
842
  return weather_html
843
- except requests.exceptions.RequestException as e:
 
844
  return f"<p>Failed to fetch local weather: {e}</p>"
 
845
 
846
  def get_weather_icon(condition):
847
  condition_map = {
@@ -990,12 +992,13 @@ def generate_map(location_names):
990
 
991
  for location_name in all_addresses:
992
  geocode_result = gmaps.geocode(location_name)
993
- if geocode_result:
994
  location = geocode_result[0]['geometry']['location']
995
  folium.Marker(
996
  [location['lat'], location['lng']],
997
  tooltip=f"{geocode_result[0]['formatted_address']}"
998
  ).add_to(m)
 
999
 
1000
  map_html = m._repr_html_()
1001
  return map_html
@@ -1004,7 +1007,7 @@ def fetch_local_news():
1004
  api_key = os.environ['SERP_API']
1005
  url = f'https://serpapi.com/search.json?engine=google_news&q=birmingham headline&api_key={api_key}'
1006
  response = requests.get(url)
1007
- if response.status_code == 200:
1008
  results = response.json().get("news_results", [])
1009
  news_html = """
1010
  <h2 style="font-family: 'Georgia', serif; color: #ff0000; background-color: #f8f8f8; padding: 10px; border-radius: 10px;">Birmingham Today</h2>
@@ -1060,7 +1063,7 @@ def fetch_local_news():
1060
  </script>
1061
  <div id="news-preview" class="news-preview"></div>
1062
  """
1063
- for index, result in enumerate(results[:7]):
1064
  title = result.get("title", "No title")
1065
  link = result.get("link", "#")
1066
  snippet = result.get("snippet", "")
@@ -1071,8 +1074,10 @@ def fetch_local_news():
1071
  </div>
1072
  """
1073
  return news_html
1074
- else:
 
1075
  return "<p>Failed to fetch local news</p>"
 
1076
 
1077
  # Voice Control
1078
  import numpy as np
@@ -1097,18 +1102,22 @@ base_audio_drive = "/data/audio"
1097
  import numpy as np
1098
 
1099
  def transcribe_function(stream, new_chunk):
1100
- try:
1101
  sr, y = new_chunk[0], new_chunk[1]
1102
- except TypeError:
 
1103
  print(f"Error chunk structure: {type(new_chunk)}, content: {new_chunk}")
1104
  return stream, "", None
 
1105
 
1106
  y = y.astype(np.float32) / np.max(np.abs(y))
1107
 
1108
- if stream is not None:
1109
  stream = np.concatenate([stream, y])
1110
- else:
 
1111
  stream = y
 
1112
 
1113
  result = pipe_asr({"array": stream, "sampling_rate": sr}, return_timestamps=False)
1114
 
@@ -1151,16 +1160,18 @@ def generate_audio_elevenlabs(text):
1151
  }
1152
  }
1153
  response = requests.post(tts_url, headers=headers, json=data, stream=True)
1154
- if response.ok:
1155
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
1156
- for chunk in response.iter_content(chunk_size=1024):
1157
  f.write(chunk)
1158
  temp_audio_path = f.name
1159
  logging.debug(f"Audio saved to {temp_audio_path}")
1160
  return temp_audio_path
1161
- else:
 
1162
  logging.error(f"Error generating audio: {response.text}")
1163
  return None
 
1164
 
1165
  # Stable Diffusion setup
1166
  pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
@@ -1188,8 +1199,8 @@ def update_images():
1188
 
1189
  with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1190
 
1191
- with gr.Row():
1192
- with gr.Column():
1193
  state = gr.State()
1194
 
1195
  chatbot = gr.Chatbot([], elem_id="RADAR:Channel 94.1", bubble_full_width=False)
@@ -1212,13 +1223,14 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1212
  # gr.Markdown("<h1 style='color: red;'>Map</h1>", elem_id="location-markdown")
1213
  # location_output = gr.HTML()
1214
  # bot_msg.then(show_map_if_details, [chatbot, choice], [location_output, location_output])
 
1215
 
1216
  # with gr.Column():
1217
  # weather_output = gr.HTML(value=fetch_local_weather())
1218
  # news_output = gr.HTML(value=fetch_local_news())
1219
  # news_output = gr.HTML(value=fetch_local_events())
1220
 
1221
- with gr.Column():
1222
 
1223
  image_output_1 = gr.Image(value=generate_image(hardcoded_prompt_1), width=400, height=400)
1224
  image_output_2 = gr.Image(value=generate_image(hardcoded_prompt_2), width=400, height=400)
@@ -1227,6 +1239,7 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1227
 
1228
  refresh_button = gr.Button("Refresh Images")
1229
  refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
 
1230
 
1231
  demo.queue()
1232
  demo.launch(share=True)
 
672
  prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
673
 
674
  set_seed(SEED)
675
+ input_ids = inputs.input_ids.to(device)
676
+ prompt_input_ids = prompt.input_ids.to(device)
677
 
678
+ # Use the generate method to get the audio features
679
+ generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
680
  audio_arr = generation.cpu().numpy().squeeze()
681
 
682
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
 
686
  logging.debug(f"Audio saved to {temp_audio_path}")
687
  return temp_audio_path
688
 
 
 
 
 
 
 
689
  # Check if the token is already set in the environment variables
690
  hf_token = os.getenv("HF_TOKEN")
691
 
 
735
  url = f'https://serpapi.com/search.json?engine=google_events&q=Events+in+Birmingham&hl=en&gl=us&api_key={api_key}'
736
 
737
  response = requests.get(url)
738
+ if response.status_code == 200){
739
  events_results = response.json().get("events_results", [])
740
  events_html = """
741
  <h2 style="font-family: 'Georgia', serif; color: #ff0000; background-color: #f8f8f8; padding: 10px; border-radius: 10px;">Local Events</h2>
 
756
  }
757
  </style>
758
  """
759
+ for index, event in enumerate(events_results){
760
  title = event.get("title", "No title")
761
  date = event.get("date", "No date")
762
  location = event.get("address", "No location")
 
768
  </div>
769
  """
770
  return events_html
771
+ }
772
+ else{
773
  return "<p>Failed to fetch local events</p>"
774
+ }
775
 
776
  def fetch_local_weather():
777
+ try{
778
  api_key = os.environ['WEATHER_API']
779
  url = f'https://weather.visualcrossing.com/VisualCrossingWebServices/rest/services/timeline/birmingham?unitGroup=metric&include=events%2Calerts%2Chours%2Cdays%2Ccurrent&key={api_key}'
780
  response = requests.get(url)
 
784
  current_conditions = jsonData.get("currentConditions", {})
785
  temp_celsius = current_conditions.get("temp", "N/A")
786
 
787
+ if temp_celsius != "N/A"){
788
  temp_fahrenheit = int((temp_celsius * 9/5) + 32)
789
+ }
790
+ else{
791
  temp_fahrenheit = "N/A"
792
+ }
793
 
794
  condition = current_conditions.get("conditions", "N/A")
795
  humidity = current_conditions.get("humidity", "N/A")
 
840
  </style>
841
  """
842
  return weather_html
843
+ }
844
+ catch (requests.exceptions.RequestException as e){
845
  return f"<p>Failed to fetch local weather: {e}</p>"
846
+ }
847
 
848
  def get_weather_icon(condition):
849
  condition_map = {
 
992
 
993
  for location_name in all_addresses:
994
  geocode_result = gmaps.geocode(location_name)
995
+ if geocode_result){
996
  location = geocode_result[0]['geometry']['location']
997
  folium.Marker(
998
  [location['lat'], location['lng']],
999
  tooltip=f"{geocode_result[0]['formatted_address']}"
1000
  ).add_to(m)
1001
+ }
1002
 
1003
  map_html = m._repr_html_()
1004
  return map_html
 
1007
  api_key = os.environ['SERP_API']
1008
  url = f'https://serpapi.com/search.json?engine=google_news&q=birmingham headline&api_key={api_key}'
1009
  response = requests.get(url)
1010
+ if response.status_code == 200){
1011
  results = response.json().get("news_results", [])
1012
  news_html = """
1013
  <h2 style="font-family: 'Georgia', serif; color: #ff0000; background-color: #f8f8f8; padding: 10px; border-radius: 10px;">Birmingham Today</h2>
 
1063
  </script>
1064
  <div id="news-preview" class="news-preview"></div>
1065
  """
1066
+ for index, result in enumerate(results[:7]){
1067
  title = result.get("title", "No title")
1068
  link = result.get("link", "#")
1069
  snippet = result.get("snippet", "")
 
1074
  </div>
1075
  """
1076
  return news_html
1077
+ }
1078
+ else{
1079
  return "<p>Failed to fetch local news</p>"
1080
+ }
1081
 
1082
  # Voice Control
1083
  import numpy as np
 
1102
  import numpy as np
1103
 
1104
  def transcribe_function(stream, new_chunk):
1105
+ try{
1106
  sr, y = new_chunk[0], new_chunk[1]
1107
+ }
1108
+ catch (TypeError){
1109
  print(f"Error chunk structure: {type(new_chunk)}, content: {new_chunk}")
1110
  return stream, "", None
1111
+ }
1112
 
1113
  y = y.astype(np.float32) / np.max(np.abs(y))
1114
 
1115
+ if stream is not None{
1116
  stream = np.concatenate([stream, y])
1117
+ }
1118
+ else{
1119
  stream = y
1120
+ }
1121
 
1122
  result = pipe_asr({"array": stream, "sampling_rate": sr}, return_timestamps=False)
1123
 
 
1160
  }
1161
  }
1162
  response = requests.post(tts_url, headers=headers, json=data, stream=True)
1163
+ if response.ok{
1164
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
1165
+ for chunk in response.iter_content(chunk_size=1024){
1166
  f.write(chunk)
1167
  temp_audio_path = f.name
1168
  logging.debug(f"Audio saved to {temp_audio_path}")
1169
  return temp_audio_path
1170
+ }
1171
+ else{
1172
  logging.error(f"Error generating audio: {response.text}")
1173
  return None
1174
+ }
1175
 
1176
  # Stable Diffusion setup
1177
  pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
 
1199
 
1200
  with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1201
 
1202
+ with gr.Row(){
1203
+ with gr.Column(){
1204
  state = gr.State()
1205
 
1206
  chatbot = gr.Chatbot([], elem_id="RADAR:Channel 94.1", bubble_full_width=False)
 
1223
  # gr.Markdown("<h1 style='color: red;'>Map</h1>", elem_id="location-markdown")
1224
  # location_output = gr.HTML()
1225
  # bot_msg.then(show_map_if_details, [chatbot, choice], [location_output, location_output])
1226
+ }
1227
 
1228
  # with gr.Column():
1229
  # weather_output = gr.HTML(value=fetch_local_weather())
1230
  # news_output = gr.HTML(value=fetch_local_news())
1231
  # news_output = gr.HTML(value=fetch_local_events())
1232
 
1233
+ with gr.Column(){
1234
 
1235
  image_output_1 = gr.Image(value=generate_image(hardcoded_prompt_1), width=400, height=400)
1236
  image_output_2 = gr.Image(value=generate_image(hardcoded_prompt_2), width=400, height=400)
 
1239
 
1240
  refresh_button = gr.Button("Refresh Images")
1241
  refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
1242
+ }
1243
 
1244
  demo.queue()
1245
  demo.launch(share=True)