GitHub Actions committed on
Commit
8fa1fd3
·
1 Parent(s): 57a5b90

Sync from GitHub repo

Browse files
Files changed (1) hide show
  1. app.py +57 -3
app.py CHANGED
@@ -4,6 +4,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
4
  from concurrent.futures import ThreadPoolExecutor
5
  from datetime import datetime
6
  import threading # Added for locking
 
7
 
8
  year = datetime.now().year
9
  month = datetime.now().month
@@ -117,6 +118,7 @@ TTS_CACHE_SIZE = int(os.getenv("TTS_CACHE_SIZE", "10"))
117
  CACHE_AUDIO_SUBDIR = "cache"
118
  tts_cache = {} # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
119
  tts_cache_lock = threading.Lock()
 
120
  # Increased max_workers to 8 for concurrent generation/refill
121
  cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
122
  all_harvard_sentences = [] # Keep the full list available
@@ -433,7 +435,7 @@ def _generate_cache_entry_task(sentence):
433
  return
434
 
435
  try:
436
- models = random.sample(available_models, 2)
437
  model_a_id = models[0].id
438
  model_b_id = models[1].id
439
 
@@ -574,7 +576,7 @@ def generate_tts():
574
  if len(available_models) < 2:
575
  return jsonify({"error": "Not enough TTS models available"}), 500
576
 
577
- selected_models = random.sample(available_models, 2)
578
 
579
  try:
580
  audio_files = []
@@ -840,7 +842,7 @@ def generate_podcast():
840
  if len(available_models) < 2:
841
  return jsonify({"error": "Not enough conversational models available"}), 500
842
 
843
- selected_models = random.sample(available_models, 2)
844
 
845
  try:
846
  # Generate audio for both models concurrently
@@ -1306,6 +1308,58 @@ def get_cached_sentences():
1306
  return jsonify(cached_keys)
1307
 
1308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1309
  if __name__ == "__main__":
1310
  with app.app_context():
1311
  # Ensure ./instance and ./votes directories exist
 
4
  from concurrent.futures import ThreadPoolExecutor
5
  from datetime import datetime
6
  import threading # Added for locking
7
+ from sqlalchemy import or_ # Added for vote counting query
8
 
9
  year = datetime.now().year
10
  month = datetime.now().month
 
118
  CACHE_AUDIO_SUBDIR = "cache"
119
  tts_cache = {} # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
120
  tts_cache_lock = threading.Lock()
121
+ SMOOTHING_FACTOR_MODEL_SELECTION = 500 # For weighted random model selection
122
  # Increased max_workers to 8 for concurrent generation/refill
123
  cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
124
  all_harvard_sentences = [] # Keep the full list available
 
435
  return
436
 
437
  try:
438
+ models = get_weighted_random_models(available_models, 2, ModelType.TTS)
439
  model_a_id = models[0].id
440
  model_b_id = models[1].id
441
 
 
576
  if len(available_models) < 2:
577
  return jsonify({"error": "Not enough TTS models available"}), 500
578
 
579
+ selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS)
580
 
581
  try:
582
  audio_files = []
 
842
  if len(available_models) < 2:
843
  return jsonify({"error": "Not enough conversational models available"}), 500
844
 
845
+ selected_models = get_weighted_random_models(available_models, 2, ModelType.CONVERSATIONAL)
846
 
847
  try:
848
  # Generate audio for both models concurrently
 
1308
  return jsonify(cached_keys)
1309
 
1310
 
1311
def get_weighted_random_models(
    applicable_models: list[Model], num_to_select: int, model_type: ModelType
) -> list[Model]:
    """Pick ``num_to_select`` models at random, without replacement, slightly
    favoring models that have participated in fewer votes.

    Each model's weight is ``1 / (vote_count + SMOOTHING_FACTOR_MODEL_SELECTION)``;
    the smoothing factor keeps the bias mild and prevents zero-vote models from
    being overwhelmingly preferred.

    Assumes len(applicable_models) >= num_to_select, which should be checked
    by the caller.
    """

    def _vote_count(candidate: Model) -> int:
        # A model "participated" in a vote whether it was chosen or rejected.
        return (
            Vote.query.filter(Vote.model_type == model_type)
            .filter(
                or_(
                    Vote.chosen_model_id == candidate.id,
                    Vote.rejected_model_id == candidate.id,
                )
            )
            .count()
        )

    # Working copies so the caller's list is never mutated.
    candidates = list(applicable_models)
    pool_weights = [
        1.0 / (_vote_count(candidate) + SMOOTHING_FACTOR_MODEL_SELECTION)
        for candidate in candidates
    ]

    picked: list[Model] = []
    # Assumes num_to_select is positive and <= len(candidates);
    # callers should ensure this (e.g. len(available_models) >= 2).
    for _ in range(num_to_select):
        if not candidates:  # Safety break
            app.logger.warning("Not enough candidates left for weighted selection.")
            break

        winner = random.choices(candidates, weights=pool_weights, k=1)[0]
        picked.append(winner)

        # Remove the winner (and its weight) so the next draw is without replacement.
        try:
            pos = candidates.index(winner)
        except ValueError:
            # Should be impossible: winner was drawn from candidates.
            app.logger.error(f"Error removing model {winner.id} from weighted selection candidates.")
            break  # Avoid potential issues
        candidates.pop(pos)
        pool_weights.pop(pos)

    return picked
1361
+
1362
+
1363
  if __name__ == "__main__":
1364
  with app.app_context():
1365
  # Ensure ./instance and ./votes directories exist