cstr committed on
Commit
16bc3e4
·
verified ·
1 Parent(s): 8f3e46d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -665
app.py CHANGED
@@ -1,670 +1,60 @@
1
- #python app.py
2
  import gradio as gr
3
  import os
4
- import pandas as pd
5
- import requests
6
- from pathlib import Path
7
- import ctranslate2
8
  import time
9
- import logging
10
- import transformers
11
- import json
12
- import io
13
- from tqdm import tqdm
14
  import subprocess
15
- from huggingface_hub import snapshot_download, upload_file, HfApi, create_repo
16
 
17
- # Function to download a Parquet file from a specified URL
18
- def download_parquet(url, local_path):
19
- response = requests.get(url, stream=True)
20
- if response.status_code == 200:
21
- with open(local_path, 'wb') as file:
22
- for chunk in response.iter_content(chunk_size=1024):
23
- file.write(chunk)
24
- print("File downloaded successfully.")
25
- else:
26
- print(f"Failed to download file, status code: {response.status_code}")
27
-
28
- # Function to convert Parquet files to JSONL format
29
- def convert_parquet_to_jsonl_polars(input_file, output_dir, override=False):
30
- output_dir_path = Path(output_dir)
31
- output_dir_path.mkdir(parents=True, exist_ok=True)
32
-
33
- input_path = Path(input_file)
34
- output_file_path = output_dir_path / input_path.with_suffix(".jsonl").name
35
-
36
- if output_file_path.exists() and not override:
37
- print(f"Skipping because output exists already: {output_file_path}")
38
- else:
39
- df = pl.read_parquet(input_path)
40
- df.write_ndjson(output_file_path)
41
- print(f"Data written to {output_file_path}")
42
-
43
- def convert_parquet_to_jsonl(parquet_filename, jsonl_filename):
44
- try:
45
- # Read the parquet file
46
- df = pd.read_parquet(parquet_filename)
47
- logger.info(f"Read Parquet file {parquet_filename} successfully.")
48
-
49
- # Convert the dataframe to a JSON string and handle Unicode characters and forward slashes
50
- json_str = df.to_json(orient='records', lines=True, force_ascii=False)
51
- logger.info(f"Converted Parquet file to JSON string.")
52
-
53
- # Replace escaped forward slashes if needed
54
- json_str = json_str.replace('\\/', '/')
55
-
56
- # Write the modified JSON string to the JSONL file
57
- jsonl_filename += '/train.jsonl'
58
- logger.info(f"Attempting to save to {jsonl_filename}")
59
- with open(jsonl_filename, 'w', encoding='utf-8') as file:
60
- file.write(json_str)
61
- logger.info(f"Data saved to {jsonl_filename}")
62
- except Exception as e:
63
- logger.error(f"Failed to convert Parquet to JSONL: {e}")
64
- raise
65
-
66
- # Function to count lines in a JSONL file
67
- def count_lines_in_jsonl(file_path):
68
- with open(file_path, 'r', encoding='utf-8') as file:
69
- line_count = sum(1 for _ in file)
70
- return line_count
71
-
72
- def parse_range_specification(range_specification, file_length):
73
- line_indices = []
74
- ranges = range_specification.split(',')
75
- for r in ranges:
76
- if '-' in r:
77
- parts = r.split('-')
78
- start = int(parts[0]) - 1 if parts[0] else 0
79
- end = int(parts[1]) - 1 if parts[1] else file_length - 1
80
- if start < 0 or end >= file_length:
81
- logging.error(f"Range {r} is out of bounds.")
82
- continue # Skip ranges that are out of bounds
83
- line_indices.extend(range(start, end + 1))
84
- else:
85
- single_line = int(r) - 1
86
- if single_line < 0 or single_line >= file_length:
87
- logging.error(f"Line number {r} is out of bounds.")
88
- continue # Skip line numbers that are out of bounds
89
- line_indices.append(single_line)
90
- return line_indices
91
-
92
- def translate_text(text, translator, tokenizer, target_language):
93
- """
94
- Translates the given text from English to German using CTranslate2 and the WMT21 model,
95
- with special handling for newlines and segmenting text longer than 500 characters.
96
- Ensures sequences of newlines (\n\n, \n\n\n, etc.) are accurately reproduced.
97
- """
98
- try:
99
- segments = []
100
- newline_sequences = [] # To store sequences of newlines
101
- segment = ""
102
-
103
- i = 0
104
- while i < len(text):
105
- # Collect sequences of newlines
106
- if text[i] == '\n':
107
- newline_sequence = '\n'
108
- while i + 1 < len(text) and text[i + 1] == '\n':
109
- newline_sequence += '\n'
110
- i += 1
111
- if segment:
112
- segments.append(segment) # Add the preceding text segment
113
- segment = ""
114
- newline_sequences.append(newline_sequence) # Store the newline sequence
115
- else:
116
- segment += text[i]
117
- # If segment exceeds 500 characters, or if we reach the end of the text, process it
118
- if len(segment) >= 500 or i == len(text) - 1:
119
- end_index = max(segment.rfind('.', 0, 500), segment.rfind('?', 0, 500), segment.rfind('!', 0, 500))
120
- if end_index != -1 and len(segment) > 500:
121
- # Split at the last punctuation within the first 500 characters
122
- segments.append(segment[:end_index+1])
123
- segment = segment[end_index+1:].lstrip()
124
- else:
125
- # No suitable punctuation or end of text, add the whole segment
126
- segments.append(segment)
127
- segment = ""
128
- i += 1
129
-
130
- # Translate the collected text segments
131
- translated_segments = []
132
- for segment in segments:
133
- source = tokenizer.convert_ids_to_tokens(tokenizer.encode(segment))
134
- target_prefix = [tokenizer.lang_code_to_token[target_language]]
135
- results = translator.translate_batch([source], target_prefix=[target_prefix])
136
- target = results[0].hypotheses[0][1:]
137
- translated_segment = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
138
- translated_segments.append(translated_segment)
139
-
140
- # Reassemble the translated text with original newline sequences
141
- translated_text = ""
142
- for i, segment in enumerate(translated_segments):
143
- translated_text += segment
144
- if i < len(newline_sequences):
145
- translated_text += newline_sequences[i] # Insert the newline sequence
146
-
147
- return translated_text.strip()
148
-
149
- except Exception as e:
150
- logging.error(f"An error occurred during translation: {e}")
151
- return None
152
-
153
- def translate_item_ufb(item, raw_file_path, translator, tokenizer, target_language):
154
- try:
155
- # Translate the prompt directly since it's a string
156
- translated_prompt = translate_text(item['prompt'], translator, tokenizer)
157
-
158
- # Translate the chosen and rejected contents
159
- translated_chosen = []
160
- for choice in item['chosen']:
161
- translated_content = translate_text(choice['content'], translator, tokenizer, target_language)
162
- translated_chosen.append({'content': translated_content, 'role': choice['role']})
163
-
164
- translated_rejected = []
165
- for choice in item['rejected']:
166
- translated_content = translate_text(choice['content'], translator, tokenizer, target_language)
167
- translated_rejected.append({'content': translated_content, 'role': choice['role']})
168
-
169
- # Write the raw response to a backup file
170
- with open(raw_file_path, 'a', encoding='utf-8') as raw_file:
171
- raw_file.write(f"Prompt: {translated_prompt}\n")
172
- raw_file.write(f"Chosen: {json.dumps(translated_chosen, ensure_ascii=False)}\n")
173
- raw_file.write(f"Rejected: {json.dumps(translated_rejected, ensure_ascii=False)}\n\n")
174
-
175
- logging.info("Translation request successful.")
176
- # Update the original item with the translated fields
177
- item['prompt'] = translated_prompt
178
- item['chosen'] = translated_chosen
179
- item['rejected'] = translated_rejected
180
- return item
181
-
182
- except Exception as e:
183
- logging.error(f"An error occurred during translation: {e}")
184
- return None
185
-
186
- def validate_item_ufb(item):
187
- # Check basic required fields including 'prompt' as a simple string
188
- required_fields = ['source', 'prompt', 'chosen', 'rejected']
189
- for field in required_fields:
190
- if field not in item:
191
- logging.warning(f"Missing required field: {field}")
192
- return False
193
- if field == 'prompt' and not isinstance(item['prompt'], str):
194
- logging.warning("Prompt must be a string.")
195
- return False
196
-
197
- # Check 'chosen' and 'rejected' which should be lists of dictionaries
198
- for field in ['chosen', 'rejected']:
199
- if not isinstance(item[field], list) or not item[field]:
200
- logging.warning(f"No entries or incorrect type for section: {field}")
201
- return False
202
- for idx, message in enumerate(item[field]):
203
- if 'content' not in message or 'role' not in message:
204
- logging.warning(f"Missing 'content' or 'role' field in {field} at index {idx}")
205
- return False
206
- if not isinstance(message['content'], str) or not isinstance(message['role'], str):
207
- logging.warning(f"Invalid type for 'content' or 'role' field in {field} at index {idx}")
208
- return False
209
-
210
- return True
211
-
212
-
213
-
214
- def translate_item_mix(item, raw_file_path, translator, tokenizer, target_language):
215
- """
216
- Translates the relevant fields in the given item from English to German using CTranslate2 and the WMT21 model,
217
- and saves the raw response to a backup file.
218
- """
219
- #print ("translating:", item)
220
- try:
221
- # Translate each part of the prompt separately and preserve the order
222
- translated_prompts = []
223
- for message in item['prompt']:
224
- translated_content = translate_text(message['content'], translator, tokenizer, target_language)
225
- translated_prompts.append({'content': translated_content, 'role': message['role']})
226
-
227
- # Translate the chosen and rejected contents
228
- translated_chosen_content = translate_text(item['chosen'][0]['content'], translator, tokenizer, target_language)
229
- translated_rejected_content = translate_text(item['rejected'][0]['content'], translator, tokenizer, target_language)
230
-
231
- # Write the raw response to a backup file
232
- with open(raw_file_path, 'a', encoding='utf-8') as raw_file:
233
- raw_file.write("Prompt content:\n")
234
- for translated_prompt in translated_prompts:
235
- raw_file.write(f"{translated_prompt['role']}: {translated_prompt['content']}\n")
236
- raw_file.write(f"Chosen content: {translated_chosen_content}\n")
237
- raw_file.write(f"Rejected content: {translated_rejected_content}\n\n")
238
-
239
- logging.info("Translation request successful.")
240
- except Exception as e:
241
- logging.error(f"An error occurred during translation: {e}")
242
- return None
243
-
244
- # Update the original item with the translated fields
245
- item['prompt'] = translated_prompts
246
- item['chosen'][0]['content'] = translated_chosen_content
247
- item['rejected'][0]['content'] = translated_rejected_content
248
-
249
- logging.info("Translation processing successful.")
250
- return item
251
-
252
- def validate_item_mix(item):
253
- """
254
- Validates the structure, presence, and content of required fields in the given item,
255
- allowing for multiple elements in the 'prompt' field for multi-turn conversations.
256
- """
257
- required_fields = ['dataset', 'prompt', 'chosen', 'rejected']
258
- for field in required_fields:
259
- if field not in item:
260
- logging.warning(f"Missing required field: {field}")
261
- return False
262
-
263
- # Check for at least one element in 'prompt' and exactly one element in 'chosen' and 'rejected'
264
- if len(item['prompt']) < 1 or len(item['chosen']) != 1 or len(item['rejected']) != 1:
265
- logging.warning("Invalid number of elements in 'prompt', 'chosen', or 'rejected' field.")
266
- return False
267
-
268
- # Validate 'content' and 'role' fields in all messages of 'prompt', and single elements of 'chosen' and 'rejected'
269
- for choice in item['prompt'] + item['chosen'] + item['rejected']:
270
- if 'content' not in choice or 'role' not in choice:
271
- logging.warning("Missing 'content' or 'role' field in choice.")
272
- return False
273
- if not isinstance(choice['content'], str) or not isinstance(choice['role'], str):
274
- logging.warning("Invalid type for 'content' or 'role' field in choice.")
275
- return False
276
-
277
- return True
278
-
279
- def translate_item_ufb_cached(item, raw_file_path, translator, tokenizer, target_language):
280
- try:
281
- translated_texts = {} # Cache to store translated texts
282
-
283
- # Translate the prompt if necessary (which is a user input and can appear again)
284
- if item['prompt'] not in translated_texts:
285
- translated_prompt = translate_text(item['prompt'], translator, tokenizer, target_language)
286
- translated_texts[item['prompt']] = translated_prompt
287
- else:
288
- translated_prompt = translated_texts[item['prompt']]
289
-
290
- # Helper function to handle content translation with caching
291
- def get_translated_content(content):
292
- if content not in translated_texts:
293
- translated_texts[content] = translate_text(content, translator, tokenizer, target_language)
294
- return translated_texts[content]
295
-
296
- # Process translations for chosen and rejected sections
297
- def translate_interactions(interactions):
298
- translated_interactions = []
299
- for interaction in interactions:
300
- translated_content = get_translated_content(interaction['content'])
301
- translated_interactions.append({'content': translated_content, 'role': interaction['role']})
302
- return translated_interactions
303
-
304
- translated_chosen = translate_interactions(item['chosen'])
305
- translated_rejected = translate_interactions(item['rejected'])
306
-
307
- # Write the raw response to a backup file
308
- with open(raw_file_path, 'a', encoding='utf-8') as raw_file:
309
- raw_file.write(f"Prompt: {translated_prompt}\n")
310
- raw_file.write(f"Chosen: {json.dumps(translated_chosen, ensure_ascii=False)}\n")
311
- raw_file.write(f"Rejected: {json.dumps(translated_rejected, ensure_ascii=False)}\n\n")
312
-
313
- logging.info("Translation request successful.")
314
- # Update the original item with the translated fields
315
- item['prompt'] = translated_prompt
316
- item['chosen'] = translated_chosen
317
- item['rejected'] = translated_rejected
318
- return item
319
-
320
- except Exception as e:
321
- logging.error(f"An error occurred during translation: {e}")
322
- return None
323
-
324
- def validate_item_ufb_cached(item):
325
- # Check basic required fields
326
- required_fields = ['source', 'prompt', 'chosen', 'rejected']
327
- for field in required_fields:
328
- if field not in item:
329
- logging.warning(f"Missing required field: {field}")
330
- return False
331
-
332
- # Ensure 'prompt' is a string
333
- if not isinstance(item['prompt'], str):
334
- logging.warning("Prompt must be a string.")
335
- return False
336
-
337
- # Check 'chosen' and 'rejected' which should be lists of dictionaries
338
- for field in ['chosen', 'rejected']:
339
- if not isinstance(item[field], list) or not item[field]:
340
- logging.warning(f"No entries or incorrect type for section: {field}")
341
- return False
342
- for idx, message in enumerate(item[field]):
343
- if 'content' not in message or 'role' not in message:
344
- logging.warning(f"Missing 'content' or 'role' field in {field} at index {idx}")
345
- return False
346
- if not isinstance(message['content'], str) or not isinstance(message['role'], str):
347
- logging.warning(f"Invalid type for 'content' or 'role' field in {field} at index {idx}")
348
- return False
349
-
350
- return True
351
-
352
- def process_file(input_file_path, output_file_path, raw_file_path, line_indices, translator, tokenizer, model_type, target_language):
353
- try:
354
- # Assigning validation and translation functions based on model_type
355
- if model_type == "mix":
356
- print ("translating a mix-style model...")
357
- validate_item = validate_item_mix
358
- translate_item = translate_item_mix
359
- elif model_type == "ufb_cached":
360
- print ("translating an ufb_cached-style model...")
361
- validate_item = validate_item_ufb_cached
362
- translate_item = translate_item_ufb_cached # def translate_item_ufb(item, raw_file_path, translator, tokenizer):
363
- elif model_type == "ufb":
364
- print ("translating an ultrafeedback-style model...")
365
- validate_item = validate_item_ufb
366
- translate_item = translate_item_ufb # def translate_item_ufb(item, raw_file_path, translator, tokenizer):
367
- else:
368
- raise ValueError(f"Unsupported model_type: {model_type}")
369
-
370
- with open(input_file_path, 'r', encoding='utf-8') as file:
371
- data_points = [json.loads(line) for line in file]
372
-
373
- failed_items = []
374
- failed_items_indices = []
375
-
376
- for index in tqdm(line_indices, desc="Processing lines", unit="item"):
377
- item = data_points[index]
378
-
379
- # Validate the item structure
380
- if not validate_item(item):
381
- logging.warning("Skipping item due to invalid structure.")
382
- failed_items.append(item)
383
- continue
384
-
385
- # Translate the relevant fields in the item
386
- translated_item = None
387
- retry_count = 0
388
- while translated_item is None and retry_count < 3:
389
- print ("going to translate the item...")
390
- translated_item = translate_item(item, raw_file_path, translator, tokenizer, target_language)
391
- retry_count += 1
392
- if translated_item is None:
393
- logging.warning(f"Translation failed for item. Retry attempt: {retry_count}")
394
- time.sleep(1)
395
-
396
- if translated_item is not None:
397
- translated_item['index'] = index
398
- with open(output_file_path, 'a', encoding='utf-8') as file:
399
- file.write(json.dumps(translated_item, ensure_ascii=False) + "\n")
400
- else:
401
- failed_items_indices.append(index)
402
- failed_items.append(item)
403
- logging.error("Translation failed after multiple attempts. Skipping item.")
404
-
405
- # Validate the translated item structure
406
- if not validate_item(translated_item):
407
- logging.warning("Skipping translated item due to invalid structure.")
408
- failed_items.append(item)
409
- continue
410
-
411
- with open('failed_items.jsonl', 'w', encoding='utf-8') as file:
412
- for item in failed_items:
413
- file.write(json.dumps(item, ensure_ascii=False) + "\n")
414
-
415
- failed_items_str = generate_failed_items_str(failed_items_indices)
416
- with open('failed_items_index.txt', 'w', encoding='utf-8') as f:
417
- f.write(failed_items_str)
418
-
419
- logging.info("Translation completed successfully.")
420
-
421
- except Exception as e:
422
- logging.error(f"An error occurred: {e}")
423
-
424
- def generate_failed_items_str(indices):
425
- """
426
- Converts a list of failed item indices into a string.
427
- """
428
- if not indices:
429
- return ""
430
-
431
- # Sort the list of indices and initialize the first range
432
- indices.sort()
433
- range_start = indices[0]
434
- current = range_start
435
- ranges = []
436
-
437
- for i in indices[1:]:
438
- if i == current + 1:
439
- current = i
440
- else:
441
- if range_start == current:
442
- ranges.append(f"{range_start}")
443
- else:
444
- ranges.append(f"{range_start}-{current}")
445
- range_start = current = i
446
-
447
- # Add the last range
448
- if range_start == current:
449
- ranges.append(f"{range_start}")
450
- else:
451
- ranges.append(f"{range_start}-{current}")
452
-
453
- return ",".join(ranges)
454
-
455
- # Function to upload the output file to Hugging Face
456
- def upload_output_to_huggingface(output_file_path, repo_name, token):
457
- api = HfApi()
458
-
459
- # Check if the repository exists
460
- try:
461
- print ("checking repo:", repo_name)
462
- api.repo_info(repo_id=repo_name, repo_type="dataset", token=token)
463
- except Exception as e:
464
- if "404" in str(e):
465
- # Create the repository if it doesn't exist
466
- print ("creating it...")
467
- create_repo(repo_id=repo_name, repo_type="dataset", token=token)
468
- print(f"Created repository: {repo_name}")
469
- else:
470
- print(f"Failed to check repository existence: {e}")
471
- return
472
-
473
- # Upload the file to the repository
474
- try:
475
- print ("starting dataset upload from:", output_file_path)
476
- upload_file(
477
- path_or_fileobj=output_file_path,
478
- path_in_repo=output_file_path,
479
- repo_id=repo_name,
480
- repo_type="dataset",
481
- token=token
482
- )
483
- print(f"Uploaded {output_file_path} to Hugging Face repository: {repo_name}")
484
- except Exception as e:
485
- print(f"Failed to upload {output_file_path} to Hugging Face: {e}")
486
- raise
487
-
488
- def translate_dataset(train_url, local_parquet_path, input_file_path, output_file_path, raw_file_path, range_specification, model_type, output_dir, output_repo_name, token, translator, tokenizer, target_language):
489
- try:
490
- # Download the Parquet file
491
- download_parquet(train_url, local_parquet_path)
492
- except Exception as e:
493
- logging.error(f"Failed to download the Parquet file from {train_url}: {e}")
494
- return
495
-
496
- try:
497
- # Convert the downloaded Parquet file to JSONL
498
- convert_parquet_to_jsonl(local_parquet_path, output_dir)
499
- except Exception as e:
500
- logging.error(f"Failed to convert Parquet to JSONL: {e}")
501
- return
502
-
503
- try:
504
- # Rename the JSONL file using subprocess to ensure correct handling
505
- subprocess.run(["mv", f"{output_dir}/train.jsonl", input_file_path], check=True)
506
- except subprocess.CalledProcessError as e:
507
- logging.error(f"Failed to rename the file from 'train.jsonl' to {input_file_path}: {e}")
508
- return
509
-
510
- try:
511
- # Count lines in the JSONL file to validate contents
512
- line_count = count_lines_in_jsonl(input_file_path)
513
- logging.info(f"Number of lines in the file: {line_count}")
514
- except Exception as e:
515
- logging.error(f"Failed to count lines in {input_file_path}: {e}")
516
- return
517
-
518
- try:
519
- # Parse the range specification for processing specific lines
520
- line_indices = parse_range_specification(range_specification, file_length=line_count)
521
- if not line_indices:
522
- logging.error("No valid line indices to process. Please check the range specifications.")
523
- return
524
- except Exception as e:
525
- logging.error(f"Error parsing range specification '{range_specification}': {e}")
526
- return
527
-
528
- try:
529
- # Process the file with specified model type and line indices
530
- process_file(input_file_path, output_file_path, raw_file_path, line_indices, translator, tokenizer, model_type, target_language)
531
- except Exception as e:
532
- logging.error(f"Failed to process the file {input_file_path}: {e}")
533
- return
534
-
535
- try:
536
- # Upload the output file to Hugging Face repository
537
- upload_output_to_huggingface(output_file_path, output_repo_name, token)
538
- except Exception as e:
539
- logging.error(f"Failed to upload {output_file_path} to Hugging Face: {e}")
540
-
541
- # Setup logging configuration
542
- log_stream = io.StringIO()
543
- logging.basicConfig(level=logging.INFO,
544
- format='%(asctime)s - %(levelname)s - %(message)s',
545
- handlers=[
546
- logging.FileHandler("translation.log", mode='a'),
547
- logging.StreamHandler(log_stream)
548
- ])
549
- logger = logging.getLogger(__name__)
550
-
551
- # Main function to handle the translation workflow
552
- # Main function to handle the translation workflow
553
- def main(dataset_url, model_type, output_dataset_name, range_specification, target_language, token: gr.OAuthToken | None, profile: gr.OAuthProfile | None):
554
- try:
555
- # Login to Hugging Face
556
- if token is None or profile is None or token.token is None or profile.username is None:
557
- return "### You must be logged in to use this service."
558
-
559
- if token:
560
- logger.info("Logged in to Hugging Face")
561
-
562
- # Configuration and paths
563
- tokenizer_name = "facebook/wmt21-dense-24-wide-en-x"
564
- model_repo_name = "cstr/wmt21ct2_int8" # Repository to download the model from
565
-
566
- # Download the model snapshot from Hugging Face
567
- model_path = snapshot_download(repo_id=model_repo_name, token=token.token)
568
- logger.info(f"Model downloaded to: {model_path}")
569
-
570
- # Load the CTranslate2 model
571
- translator = ctranslate2.Translator(model_path, device="auto")
572
- logger.info("CTranslate2 model loaded successfully.")
573
-
574
- # Load the tokenizer
575
- tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
576
- tokenizer.src_lang = "en"
577
- tokenizer.tgt_lang = target_language # Set target language
578
- logger.info("Tokenizer loaded successfully.")
579
-
580
- # Define the task based on user input
581
- task = {
582
- "url": dataset_url,
583
- "local_path": "train.parquet",
584
- "input_file": f"{model_type}_en.jsonl",
585
- "output_file": f"{model_type}_{target_language}.jsonl", # Include target language in the filename
586
- "raw_file": f"{model_type}_{target_language}_raw.jsonl",
587
- "range_spec": range_specification,
588
- "model_type": model_type,
589
- "target_language": target_language # Include target language in the task
590
- }
591
-
592
- # Call the translate_dataset function with the provided parameters
593
- translate_dataset(
594
- train_url=task["url"],
595
- local_parquet_path=task["local_path"],
596
- input_file_path=task["input_file"],
597
- output_file_path=task["output_file"],
598
- output_dir=".",
599
- output_repo_name=output_dataset_name,
600
- raw_file_path=task["raw_file"],
601
- token=token.token,
602
- range_specification=task["range_spec"],
603
- model_type=task["model_type"],
604
- translator=translator,
605
- tokenizer=tokenizer,
606
- target_language=task["target_language"] # Pass the target language
607
- )
608
- logger.info("Dataset translation completed!")
609
- return "Dataset translation completed!\n\n### Logs:\n" + log_stream.getvalue()
610
- else:
611
- return "Login failed. Please try again."
612
- except Exception as e:
613
- logger.error(f"An error occurred in the main function: {e}")
614
- return f"An error occurred: {e}\n\n### Logs:\n{log_stream.getvalue()}"
615
-
616
-
617
- # Gradio interface setup
618
- gradio_title = "🧐 WMT21 Dataset Translation"
619
- gradio_desc = """This tool translates english datasets using the WMT21 translation model.
620
- ## 💭 What Does This Tool Do:
621
- - Translates datasets (as parquet files) with structures based on the selected model type (see below).
622
- - The translation model (facebook/wmt21-dense-24-wide-en-x) supports as target languages: Hausa (ha), Icelandic (is), Japanese (ja), Czech (cs), Russian (ru), Chinese (zh), German (de)
623
- - Uploads the translated dataset as jsonl to Hugging Face.
624
- - At the moment, this works only on CPU, and therefore is very very slow."""
625
- datasets_desc = """## 📊 Dataset Types:
626
- Note: additional fields will be kept (untranslated), an additional index field is added, which makes it easier to verify results, i.a.
627
- - **mix**:
628
- - `prompt`: List of dictionaries with 'content' and 'role' fields (multi-turn conversation).
629
- - `chosen`: Single dictionary with 'content' and 'role' fields.
630
- - `rejected`: Single dictionary with 'content' and 'role' fields.
631
- - **ufb_cached**:
632
- - `prompt`: String (user input).
633
- - `chosen`: List of dictionaries with 'content' and 'role' fields.
634
- - `rejected`: List of dictionaries with 'content' and 'role' fields.
635
- - **ufb**:
636
- - like ufb_cached, but we do not check for already translated strings
637
- ## 🛠️ Backend:
638
- The translation model is int8 quantized from facebook/wmt21-dense-24-wide-en-x and runs via ctranslate2 on the Hugging Face Hub."""
639
-
640
- # Define the theme
641
- theme = gr.themes.Soft(text_size="lg", spacing_size="lg")
642
-
643
- with gr.Blocks(theme=theme) as demo:
644
- gr.HTML(f"""<h1 align="center" id="space-title">{gradio_title}</h1>""")
645
- gr.Markdown(gradio_desc)
646
-
647
- with gr.Row(variant="panel"):
648
- gr.Markdown(value="## 🚀 Login to Hugging Face"),
649
- gr.LoginButton(min_width=380)
650
-
651
- gr.Markdown(value="🚨 **This is needed to upload the resulting dataset.**")
652
-
653
- with gr.Row(equal_height=False):
654
- with gr.Column():
655
- dataset_url = gr.Textbox(label="Input Dataset URL", lines=2, placeholder = "https://huggingface.co/datasets/alvarobartt/dpo-mix-7k-simplified/resolve/main/data/train-00000-of-00001.parquet?download=true")
656
- model_type = gr.Dropdown(choices=["mix", "ufb_cached", "ufb"], label="Dataset Type")
657
- output_dataset_name = gr.Textbox(label="Output Dataset Name", lines=1, placeholder = "cstr/translated_datasets")
658
- range_specification = gr.Textbox(label="Range Specification", lines=1, placeholder="e.g., 1-100")
659
- target_language = gr.Dropdown(choices=["ha", "is", "ja", "cs", "ru", "zh", "de"], label="Target Language") # New dropdown for target language
660
-
661
- with gr.Column():
662
- output = gr.Markdown(label="Output")
663
-
664
- submit_btn = gr.Button("Translate Dataset", variant="primary")
665
- submit_btn.click(main, inputs=[dataset_url, model_type, output_dataset_name, range_specification, target_language], outputs=output)
666
-
667
-
668
- gr.Markdown(datasets_desc)
669
-
670
- demo.queue(max_size=10).launch(share=True, show_api=True)
 
 
1
  import gradio as gr
2
  import os
 
 
 
 
3
  import time
4
+ import sys
 
 
 
 
5
  import subprocess
 
6
 
7
# Clone and install faster-whisper from GitHub.
_FASTER_WHISPER_DIR = "./faster-whisper"

# Skip the clone when the checkout is already present: `git clone` exits with a
# non-zero status if the destination directory exists, and with check=True that
# would abort the app on every restart after the first one.
if not os.path.isdir(_FASTER_WHISPER_DIR):
    subprocess.run(
        ["git", "clone", "https://github.com/SYSTRAN/faster-whisper.git", _FASTER_WHISPER_DIR],
        check=True,
    )

# Invoke pip through the running interpreter so the package is installed into
# the same environment that executes this script (a bare `pip` on PATH may
# belong to a different Python installation).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-e", _FASTER_WHISPER_DIR],
    check=True,
)

# Add the faster-whisper directory to the Python path so the freshly installed
# package can be imported in this same process.
sys.path.append(_FASTER_WHISPER_DIR)
13
+
14
+ from faster_whisper import WhisperModel
15
+ from faster_whisper.transcribe import BatchedInferencePipeline
16
+
17
# Loaded pipelines keyed by model id, so the expensive model load happens once
# per process instead of once per request.
_BATCHED_MODEL_CACHE = {}


def _get_batched_model(model_id="cstr/whisper-large-v3-turbo-int8_float32"):
    """Return the batched pipeline for ``model_id``, loading it on first use."""
    pipeline = _BATCHED_MODEL_CACHE.get(model_id)
    if pipeline is None:
        model = WhisperModel(model_id, device="auto", compute_type="int8")
        pipeline = BatchedInferencePipeline(model=model)
        _BATCHED_MODEL_CACHE[model_id] = pipeline
    return pipeline


def transcribe_audio(audio_path, batch_size):
    """Transcribe an audio file and report the transcript plus timing metrics.

    Args:
        audio_path: Path of the audio file to transcribe. ``None`` or an empty
            string (no upload) yields a short notice instead of an exception.
        batch_size: Batch size forwarded to the batched pipeline; coerced to
            ``int`` because UI sliders may deliver a float.

    Returns:
        A text report containing the timestamped transcription, detected
        language, durations, transcription time, real-time factor and the
        audio file size in MB.
    """
    if not audio_path:
        return "No audio file provided. Please upload an audio file."

    batched_model = _get_batched_model()

    # Benchmark transcription time. faster-whisper returns `segments` as a lazy
    # generator: the actual decoding only runs while it is being iterated, so
    # the timer must stay open until every segment has been consumed.
    start_time = time.perf_counter()
    segments, info = batched_model.transcribe(audio_path, batch_size=int(batch_size))
    segment_lines = [
        f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}\n"
        for segment in segments
    ]
    transcription_time = time.perf_counter() - start_time

    transcription = "".join(segment_lines)

    # Calculate metrics; guard the division in case the clock reports no
    # elapsed time.
    if transcription_time > 0:
        real_time_factor = info.duration / transcription_time
    else:
        real_time_factor = float("inf")
    audio_file_size = os.path.getsize(audio_path) / (1024 * 1024)  # Size in MB

    # Prepare output
    report_parts = [
        f"Transcription:\n\n{transcription}\n",
        f"\nLanguage: {info.language}, Probability: {info.language_probability:.2f}\n",
        f"Duration: {info.duration:.2f}s, Duration after VAD: {info.duration_after_vad:.2f}s\n",
        f"Transcription time: {transcription_time:.2f} seconds\n",
        f"Real-time factor: {real_time_factor:.2f}x\n",
        f"Audio file size: {audio_file_size:.2f} MB",
    ]
    return "".join(report_parts)
46
+
47
# Gradio interface
# The example entry is only offered when the referenced file is really on disk:
# the path below is a placeholder, and handing Gradio a non-existent example
# file can make the app fail while it prepares the examples at startup.
_EXAMPLE_AUDIO_PATH = "path/to/example/audio.mp3"
_interface_examples = (
    [[_EXAMPLE_AUDIO_PATH, 16]] if os.path.isfile(_EXAMPLE_AUDIO_PATH) else None
)

iface = gr.Interface(
    fn=transcribe_audio,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),
        gr.Slider(minimum=1, maximum=32, step=1, value=16, label="Batch Size"),
    ],
    outputs=gr.Textbox(label="Transcription and Metrics"),
    title="Faster Whisper Transcription (GitHub Version)",
    description="Upload an audio file to transcribe using Faster Whisper (GitHub version). Adjust the batch size for performance tuning.",
    examples=_interface_examples,
)

iface.launch()