Vaibhav Srivastav committed on
Commit
255495b
·
0 Parent(s):

Squash for release.

Browse files
.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.whl filter=lfs diff=lfs merge=lfs -text
37
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio_cached_examples/
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # poetry
100
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104
+ #poetry.lock
105
+
106
+ # pdm
107
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108
+ #pdm.lock
109
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110
+ # in version control.
111
+ # https://pdm.fming.dev/#use-with-ide
112
+ .pdm.toml
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.5.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.12.0
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.7.0
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ ["types-python-slugify", "types-requests", "types-PyYAML"]
33
+ - repo: https://github.com/psf/black
34
+ rev: 23.11.0
35
+ hooks:
36
+ - id: black
37
+ language_version: python3.10
38
+ args: ["--line-length", "119"]
39
+ - repo: https://github.com/kynan/nbstripout
40
+ rev: 0.6.1
41
+ hooks:
42
+ - id: nbstripout
43
+ args:
44
+ [
45
+ "--extra-keys",
46
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
47
+ ]
48
+ - repo: https://github.com/nbQA-dev/nbQA
49
+ rev: 1.7.0
50
+ hooks:
51
+ - id: nbqa-black
52
+ - id: nbqa-pyupgrade
53
+ args: ["--py37-plus"]
54
+ - id: nbqa-isort
55
+ args: ["--float-to-top"]
.vscode/settings.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "editor.formatOnSave": true,
3
+ "files.insertFinalNewline": false,
4
+ "[python]": {
5
+ "editor.defaultFormatter": "ms-python.black-formatter",
6
+ "editor.formatOnType": true,
7
+ "editor.codeActionsOnSave": {
8
+ "source.organizeImports": true
9
+ }
10
+ },
11
+ "[jupyter]": {
12
+ "files.insertFinalNewline": false
13
+ },
14
+ "black-formatter.args": [
15
+ "--line-length=119"
16
+ ],
17
+ "isort.args": ["--profile", "black"],
18
+ "flake8.args": [
19
+ "--max-line-length=119"
20
+ ],
21
+ "ruff.lint.args": [
22
+ "--line-length=119"
23
+ ],
24
+ "notebook.output.scrolling": true,
25
+ "notebook.formatOnCellExecution": true
26
+ }
Dockerfile ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && \
4
+ apt-get upgrade -y && \
5
+ apt-get install -y --no-install-recommends \
6
+ git \
7
+ git-lfs \
8
+ wget \
9
+ curl \
10
+ # python build dependencies \
11
+ build-essential \
12
+ libssl-dev \
13
+ zlib1g-dev \
14
+ libbz2-dev \
15
+ libreadline-dev \
16
+ libsqlite3-dev \
17
+ libncursesw5-dev \
18
+ xz-utils \
19
+ tk-dev \
20
+ libxml2-dev \
21
+ libxmlsec1-dev \
22
+ libffi-dev \
23
+ liblzma-dev \
24
+ # gradio dependencies \
25
+ ffmpeg \
26
+ # fairseq2 dependencies \
27
+ libsndfile-dev && \
28
+ apt-get clean && \
29
+ rm -rf /var/lib/apt/lists/*
30
+
31
+ RUN useradd -m -u 1000 user
32
+ USER user
33
+ ENV HOME=/home/user \
34
+ PATH=/home/user/.local/bin:${PATH}
35
+ WORKDIR ${HOME}/app
36
+
37
+ RUN curl https://pyenv.run | bash
38
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
39
+ ARG PYTHON_VERSION=3.10.13
40
+ RUN pyenv install ${PYTHON_VERSION} && \
41
+ pyenv global ${PYTHON_VERSION} && \
42
+ pyenv rehash && \
43
+ pip install --no-cache-dir -U pip setuptools wheel && \
44
+ pip install "huggingface-hub==0.19.3" "hf-transfer==0.1.4"
45
+
46
+ COPY --chown=1000 . ${HOME}/app
47
+ RUN pip install -r ${HOME}/app/requirements.txt && \
48
+ pip install fairseq2 --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/pt2.1.0/cu121 && \
49
+ pip install ${HOME}/app/whl/seamless_communication-1.0.0-py3-none-any.whl
50
+
51
+ ENV PYTHONPATH=${HOME}/app \
52
+ PYTHONUNBUFFERED=1 \
53
+ HF_HUB_ENABLE_HF_TRANSFER=1 \
54
+ GRADIO_ALLOW_FLAGGING=never \
55
+ GRADIO_NUM_PORTS=1 \
56
+ GRADIO_SERVER_NAME=0.0.0.0 \
57
+ GRADIO_THEME=huggingface \
58
+ TQDM_POSITION=-1 \
59
+ TQDM_MININTERVAL=1 \
60
+ SYSTEM=spaces
61
+ CMD ["python", "app.py"]
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Seamless Expressive
3
+ emoji: 🏃
4
+ colorFrom: red
5
+ colorTo: blue
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import pathlib
5
+ import tempfile
6
+
7
+ import gradio as gr
8
+ import torch
9
+ import torchaudio
10
+ from fairseq2.assets import InProcAssetMetadataProvider, asset_store
11
+ from fairseq2.data import Collater, SequenceData, VocabularyInfo
12
+ from fairseq2.data.audio import (
13
+ AudioDecoder,
14
+ WaveformToFbankConverter,
15
+ WaveformToFbankOutput,
16
+ )
17
+
18
+ from seamless_communication.inference import SequenceGeneratorOptions
19
+ from fairseq2.generation import NGramRepeatBlockProcessor
20
+ from fairseq2.memory import MemoryBlock
21
+ from fairseq2.typing import DataType, Device
22
+ from huggingface_hub import snapshot_download
23
+ from seamless_communication.inference import BatchedSpeechOutput, Translator, SequenceGeneratorOptions
24
+ from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
25
+ from seamless_communication.models.unity import (
26
+ UnitTokenizer,
27
+ load_gcmvn_stats,
28
+ load_unity_text_tokenizer,
29
+ load_unity_unit_tokenizer,
30
+ )
31
+ from torch.nn import Module
32
+ from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import PretsselGenerator
33
+
34
+ from utils import LANGUAGE_CODE_TO_NAME
35
+
36
+ DESCRIPTION = """\
37
+ # Seamless Expressive
38
+
39
+
40
+ [SeamlessExpressive](https://github.com/facebookresearch/seamless_communication) is a speech-to-speech translation model that captures certain underexplored aspects of prosody such as speech rate and pauses, while preserving the style of one's voice and high content translation quality.
41
+ """
42
+
43
+ CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()
44
+
45
+ CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
46
+ if not CHECKPOINTS_PATH.exists():
47
+ snapshot_download(repo_id="facebook/seamless-expressive", repo_type="model", local_dir=CHECKPOINTS_PATH)
48
+ snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
49
+
50
+ # Ensure that we do not have any other environment resolvers and always return
51
+ # "demo" for demo purposes.
52
+ asset_store.env_resolvers.clear()
53
+ asset_store.env_resolvers.append(lambda: "demo")
54
+
55
+ # Construct an `InProcAssetMetadataProvider` with environment-specific metadata
56
+ # that just overrides the regular metadata for "demo" environment. Note the "@demo" suffix.
57
+ demo_metadata = [
58
+ {
59
+ "name": "seamless_expressivity@demo",
60
+ "checkpoint": f"file://{CHECKPOINTS_PATH}/m2m_expressive_unity.pt",
61
+ "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
62
+ },
63
+ {
64
+ "name": "vocoder_pretssel@demo",
65
+ "checkpoint": f"file://{CHECKPOINTS_PATH}/pretssel_melhifigan_wm-final.pt",
66
+ },
67
+ {
68
+ "name": "seamlessM4T_v2_large@demo",
69
+ "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
70
+ "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
71
+ },
72
+ ]
73
+
74
+ asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
75
+
76
+ LANGUAGE_NAME_TO_CODE = {v: k for k, v in LANGUAGE_CODE_TO_NAME.items()}
77
+
78
+
79
# Use half precision on the first GPU when CUDA is available; otherwise fall
# back to full precision on the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if _cuda_available else "cpu")
dtype = torch.float16 if _cuda_available else torch.float32
85
+
86
+
87
+ MODEL_NAME = "seamless_expressivity"
88
+ VOCODER_NAME = "vocoder_pretssel"
89
+
90
+ # Used for ASR only: transcribes the source speech so the transcript can be passed along for the toxicity (mintox) check.
91
+ m4t_translator = Translator(
92
+ model_name_or_card="seamlessM4T_v2_large",
93
+ vocoder_name_or_card=None,
94
+ device=device,
95
+ dtype=dtype,
96
+ )
97
+ unit_tokenizer = load_unity_unit_tokenizer(MODEL_NAME)
98
+
99
+ _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(VOCODER_NAME)
100
+ gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
101
+ gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
102
+
103
+ translator = Translator(
104
+ MODEL_NAME,
105
+ vocoder_name_or_card=None,
106
+ device=device,
107
+ dtype=dtype,
108
+ apply_mintox=False,
109
+ )
110
+
111
+ text_generation_opts = SequenceGeneratorOptions(
112
+ beam_size=5,
113
+ unk_penalty=torch.inf,
114
+ soft_max_seq_len=(0, 200),
115
+ step_processor=NGramRepeatBlockProcessor(
116
+ ngram_size=10,
117
+ ),
118
+ )
119
+ m4t_text_generation_opts = SequenceGeneratorOptions(
120
+ beam_size=5,
121
+ unk_penalty=torch.inf,
122
+ soft_max_seq_len=(1, 200),
123
+ step_processor=NGramRepeatBlockProcessor(
124
+ ngram_size=10,
125
+ ),
126
+ )
127
+
128
+ pretssel_generator = PretsselGenerator(
129
+ VOCODER_NAME,
130
+ vocab_info=unit_tokenizer.vocab_info,
131
+ device=device,
132
+ dtype=dtype,
133
+ )
134
+
135
+ decode_audio = AudioDecoder(dtype=torch.float32, device=device)
136
+
137
+ convert_to_fbank = WaveformToFbankConverter(
138
+ num_mel_bins=80,
139
+ waveform_scale=2**15,
140
+ channel_last=True,
141
+ standardize=False,
142
+ device=device,
143
+ dtype=dtype,
144
+ )
145
+
146
+
147
def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
    """Store two normalized variants of the filterbank features in ``data``.

    ``data["fbank"]`` is replaced by the features standardized with their own
    mean and standard deviation taken along dim 0, and ``data["gcmvn_fbank"]``
    is set to the same raw features standardized with the module-level
    ``gcmvn_mean`` / ``gcmvn_std``. The mapping is modified in place and
    returned.
    """
    raw = data["fbank"]
    own_std, own_mean = torch.std_mean(raw, dim=0)
    # Both variants are derived from the raw features captured above, so the
    # order of the two assignments does not matter.
    data["gcmvn_fbank"] = (raw - gcmvn_mean) / gcmvn_std
    data["fbank"] = (raw - own_mean) / own_std
    return data
153
+
154
+
155
+ collate = Collater(pad_value=0, pad_to_multiple=1)
156
+
157
+
158
+ AUDIO_SAMPLE_RATE = 16000
159
+ MAX_INPUT_AUDIO_LENGTH = 10 # in seconds
160
+
161
+
162
def remove_prosody_tokens_from_text(text):
    """Return ``text`` without prosody markers and with whitespace normalized.

    The only prosody tokens are emphasis ('*') and pause ('='). Once they are
    removed, runs of whitespace are collapsed to a single space and leading
    and trailing whitespace is dropped.
    """
    without_markers = text.translate(str.maketrans("", "", "*="))
    return " ".join(without_markers.split())
167
+
168
+
169
def preprocess_audio(input_audio_path: str) -> None:
    """Resample an audio file to ``AUDIO_SAMPLE_RATE`` and cap its length.

    The file at ``input_audio_path`` is loaded, resampled, truncated to the
    first ``MAX_INPUT_AUDIO_LENGTH`` seconds when it is longer (a Gradio
    warning is issued in that case), and written back to the same path.
    """
    waveform, source_rate = torchaudio.load(input_audio_path)
    resampled = torchaudio.functional.resample(
        waveform,
        orig_freq=source_rate,
        new_freq=AUDIO_SAMPLE_RATE,
    )

    sample_budget = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    too_long = resampled.shape[1] > sample_budget
    if too_long:
        resampled = resampled[:, :sample_budget]
        gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")

    # Overwrites the input file with the resampled (and possibly trimmed) audio.
    torchaudio.save(input_audio_path, resampled, sample_rate=AUDIO_SAMPLE_RATE)
177
+
178
+
179
def run(
    input_audio_path: str,
    source_language: str,
    target_language: str,
) -> tuple[str, str]:
    """Translate the speech in an audio file into another language.

    Args:
        input_audio_path: Path of the input audio file. The file is rewritten
            in place by ``preprocess_audio`` before it is decoded.
        source_language: Display name of the spoken language; must be a key
            of ``LANGUAGE_NAME_TO_CODE``.
        target_language: Display name of the language to translate into; must
            be a key of ``LANGUAGE_NAME_TO_CODE``.

    Returns:
        A pair ``(wav_path, text)``: the path of a temporary ``.wav`` file
        holding the generated speech (it is not deleted here), and the
        translated text with prosody markers removed.

    Raises:
        gr.Error: If no input audio was provided.
        KeyError: If either language name is not in ``LANGUAGE_NAME_TO_CODE``.
    """
    # The audio component passes None when nothing was recorded or uploaded;
    # report that to the user instead of failing inside the audio loader.
    if not input_audio_path:
        raise gr.Error("Please provide an input audio file.")

    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]

    preprocess_audio(input_audio_path)

    block = MemoryBlock(pathlib.Path(input_audio_path).read_bytes())
    example = decode_audio(block)

    example = convert_to_fbank(example)
    example = normalize_fbank(example)
    example = collate(example)

    # get transcription for mintox
    source_sentences, _ = m4t_translator.predict(
        input=example["fbank"],
        task_str="S2TT",  # get source text
        tgt_lang=source_language_code,
        text_generation_opts=m4t_text_generation_opts,
    )
    source_text = str(source_sentences[0])

    prosody_encoder_input = example["gcmvn_fbank"]
    text_output, unit_output = translator.predict(
        example["fbank"],
        "S2ST",
        tgt_lang=target_language_code,
        src_lang=source_language_code,
        text_generation_opts=text_generation_opts,
        unit_generation_ngram_filtering=False,
        duration_factor=1.0,
        prosody_encoder_input=prosody_encoder_input,
        src_text=source_text,  # for mintox check
    )
    speech_output = pretssel_generator.predict(
        unit_output.units,
        tgt_lang=target_language_code,
        prosody_encoder_input=prosody_encoder_input,
    )

    # Reserve a unique output path and close the handle before writing, so the
    # file is never written by name while still held open by this process.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        output_path = f.name
    torchaudio.save(
        output_path,
        speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
        sample_rate=speech_output.sample_rate,
    )

    text_out = remove_prosody_tokens_from_text(str(text_output[0]))

    return output_path, text_out
234
+
235
+
236
+ TARGET_LANGUAGE_NAMES = [
237
+ "English",
238
+ "French",
239
+ "German",
240
+ "Spanish",
241
+ ]
242
+
243
+ with gr.Blocks(css="style.css") as demo:
244
+ gr.Markdown(DESCRIPTION)
245
+ gr.DuplicateButton(
246
+ value="Duplicate Space for private use",
247
+ elem_id="duplicate-button",
248
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
249
+ )
250
+ with gr.Row():
251
+ with gr.Column():
252
+ with gr.Group():
253
+ input_audio = gr.Audio(label="Input speech", type="filepath")
254
+ source_language = gr.Dropdown(
255
+ label="Source language",
256
+ choices=TARGET_LANGUAGE_NAMES,
257
+ value="English",
258
+ )
259
+ target_language = gr.Dropdown(
260
+ label="Target language",
261
+ choices=TARGET_LANGUAGE_NAMES,
262
+ value="French",
263
+ )
264
+ btn = gr.Button()
265
+ with gr.Column():
266
+ with gr.Group():
267
+ output_audio = gr.Audio(label="Translated speech")
268
+ output_text = gr.Textbox(label="Translated text")
269
+
270
+ gr.Examples(
271
+ examples=[
272
+ ["assets/Excited-Es.wav", "English", "Spanish"],
273
+ ["assets/FastTalking-En.wav", "French", "English"],
274
+ ["assets/Sad-Es.wav", "English", "Spanish"],
275
+ ],
276
+ inputs=[input_audio, source_language, target_language],
277
+ outputs=[output_audio, output_text],
278
+ fn=run,
279
+ cache_examples=CACHE_EXAMPLES,
280
+ api_name=False,
281
+ )
282
+
283
+ btn.click(
284
+ fn=run,
285
+ inputs=[input_audio, source_language, target_language],
286
+ outputs=[output_audio, output_text],
287
+ api_name="run",
288
+ )
289
+
290
+ if __name__ == "__main__":
291
+ demo.queue(max_size=50).launch()
assets/Excited-Es.wav ADDED
Binary file (788 kB). View file
 
assets/FastTalking-En.wav ADDED
Binary file (788 kB). View file
 
assets/Sad-Es.wav ADDED
Binary file (788 kB). View file
 
assets/Whisper-Fr.wav ADDED
Binary file (788 kB). View file
 
assets/sample_input.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:982369687f05bf8fcd6923c4ffcccda0fcce92f44eceae5a9d00a431f07ea87b
3
+ size 10272
assets/sample_input_2.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a505a4641e3f5f0ddec9508832793aa20e63d2545530b66bc04a9bd19a742e6
3
+ size 30624
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchaudio
style.css ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: #fff;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }
utils.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import torch
2
+ # import torchaudio
3
+ # from fairseq2.assets import InProcAssetMetadataProvider, asset_store
4
+ # from fairseq2.data import Collater, SequenceData
5
+ # from fairseq2.data.audio import (
6
+ # AudioDecoder,
7
+ # WaveformToFbankConverter,
8
+ # WaveformToFbankOutput,
9
+ # )
10
+ # from fairseq2.generation import SequenceGeneratorOptions
11
+ # from fairseq2.memory import MemoryBlock
12
+ # from fairseq2.typing import DataType, Device
13
+ # from huggingface_hub import snapshot_download
14
+ # from seamless_communication.inference import BatchedSpeechOutput, Translator
15
+ # from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
16
+ # from seamless_communication.models.unity import (
17
+ # UnitTokenizer,
18
+ # load_gcmvn_stats,
19
+ # load_unity_text_tokenizer,
20
+ # load_unity_unit_tokenizer,
21
+ # )
22
+ # from torch.nn import Module
23
+
24
+ # class PretsselGenerator(Module):
25
+ # def __init__(
26
+ # self,
27
+ # pretssel_name_or_card: str,
28
+ # unit_tokenizer: UnitTokenizer,
29
+ # device: Device,
30
+ # dtype: DataType = torch.float16,
31
+ # ):
32
+ # super().__init__()
33
+ # # Load the model.
34
+ # if device == torch.device("cpu"):
35
+ # dtype = torch.float32
36
+
37
+
38
+ # self.device = device
39
+ # self.dtype = dtype
40
+
41
+ # self.pretssel_model = load_pretssel_vocoder_model(
42
+ # pretssel_name_or_card,
43
+ # device=device,
44
+ # dtype=dtype,
45
+ # )
46
+ # self.pretssel_model.eval()
47
+
48
+ # vocoder_model_card = asset_store.retrieve_card(pretssel_name_or_card)
49
+ # self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
50
+
51
+ # self.unit_tokenizer = unit_tokenizer
52
+ # self.unit_collate = Collater(pad_value=unit_tokenizer.vocab_info.pad_idx)
53
+ # self.duration_collate = Collater(pad_value=0)
54
+
55
+ # @torch.inference_mode()
56
+ # def predict(
57
+ # self,
58
+ # units: list[list[int]],
59
+ # tgt_lang: str,
60
+ # prosody_encoder_input: SequenceData,
61
+ # ) -> BatchedSpeechOutput:
62
+ # audio_wavs = []
63
+ # unit_eos_token = torch.tensor(
64
+ # [self.unit_tokenizer.vocab_info.eos_idx],
65
+ # device=self.device,
66
+ # )
67
+
68
+ # prosody_input_seqs = prosody_encoder_input["seqs"]
69
+ # prosody_input_lens = prosody_encoder_input["seq_lens"]
70
+
71
+ # for i, u in enumerate(units):
72
+ # unit = torch.tensor(u).to(unit_eos_token)
73
+
74
+ # # adjust the control symbols for the embedding
75
+ # unit += 4
76
+ # unit = torch.cat([unit, unit_eos_token], dim=0)
77
+
78
+ # unit, duration = torch.unique_consecutive(unit, return_counts=True)
79
+
80
+ # # adjust for the last eos token
81
+ # duration[-1] = 0
82
+
83
+ # duration *= 2
84
+
85
+ # prosody_input_seq = prosody_input_seqs[i][: prosody_input_lens[i]]
86
+
87
+ # audio_wav = self.pretssel_model(
88
+ # unit,
89
+ # tgt_lang,
90
+ # prosody_input_seq,
91
+ # durations=duration.unsqueeze(0),
92
+ # )
93
+
94
+ # audio_wavs.append(audio_wav)
95
+
96
+ # return BatchedSpeechOutput(
97
+ # units=units,
98
+ # audio_wavs=audio_wavs,
99
+ # sample_rate=self.output_sample_rate,
100
+ # )
101
+
102
+
103
+ LANGUAGE_CODE_TO_NAME = {
104
+ "afr": "Afrikaans",
105
+ "amh": "Amharic",
106
+ "arb": "Modern Standard Arabic",
107
+ "ary": "Moroccan Arabic",
108
+ "arz": "Egyptian Arabic",
109
+ "asm": "Assamese",
110
+ "ast": "Asturian",
111
+ "azj": "North Azerbaijani",
112
+ "bel": "Belarusian",
113
+ "ben": "Bengali",
114
+ "bos": "Bosnian",
115
+ "bul": "Bulgarian",
116
+ "cat": "Catalan",
117
+ "ceb": "Cebuano",
118
+ "ces": "Czech",
119
+ "ckb": "Central Kurdish",
120
+ "cmn": "Mandarin Chinese",
121
+ "cym": "Welsh",
122
+ "dan": "Danish",
123
+ "deu": "German",
124
+ "ell": "Greek",
125
+ "eng": "English",
126
+ "est": "Estonian",
127
+ "eus": "Basque",
128
+ "fin": "Finnish",
129
+ "fra": "French",
130
+ "gaz": "West Central Oromo",
131
+ "gle": "Irish",
132
+ "glg": "Galician",
133
+ "guj": "Gujarati",
134
+ "heb": "Hebrew",
135
+ "hin": "Hindi",
136
+ "hrv": "Croatian",
137
+ "hun": "Hungarian",
138
+ "hye": "Armenian",
139
+ "ibo": "Igbo",
140
+ "ind": "Indonesian",
141
+ "isl": "Icelandic",
142
+ "ita": "Italian",
143
+ "jav": "Javanese",
144
+ "jpn": "Japanese",
145
+ "kam": "Kamba",
146
+ "kan": "Kannada",
147
+ "kat": "Georgian",
148
+ "kaz": "Kazakh",
149
+ "kea": "Kabuverdianu",
150
+ "khk": "Halh Mongolian",
151
+ "khm": "Khmer",
152
+ "kir": "Kyrgyz",
153
+ "kor": "Korean",
154
+ "lao": "Lao",
155
+ "lit": "Lithuanian",
156
+ "ltz": "Luxembourgish",
157
+ "lug": "Ganda",
158
+ "luo": "Luo",
159
+ "lvs": "Standard Latvian",
160
+ "mai": "Maithili",
161
+ "mal": "Malayalam",
162
+ "mar": "Marathi",
163
+ "mkd": "Macedonian",
164
+ "mlt": "Maltese",
165
+ "mni": "Meitei",
166
+ "mya": "Burmese",
167
+ "nld": "Dutch",
168
+ "nno": "Norwegian Nynorsk",
169
+ "nob": "Norwegian Bokm\u00e5l",
170
+ "npi": "Nepali",
171
+ "nya": "Nyanja",
172
+ "oci": "Occitan",
173
+ "ory": "Odia",
174
+ "pan": "Punjabi",
175
+ "pbt": "Southern Pashto",
176
+ "pes": "Western Persian",
177
+ "pol": "Polish",
178
+ "por": "Portuguese",
179
+ "ron": "Romanian",
180
+ "rus": "Russian",
181
+ "slk": "Slovak",
182
+ "slv": "Slovenian",
183
+ "sna": "Shona",
184
+ "snd": "Sindhi",
185
+ "som": "Somali",
186
+ "spa": "Spanish",
187
+ "srp": "Serbian",
188
+ "swe": "Swedish",
189
+ "swh": "Swahili",
190
+ "tam": "Tamil",
191
+ "tel": "Telugu",
192
+ "tgk": "Tajik",
193
+ "tgl": "Tagalog",
194
+ "tha": "Thai",
195
+ "tur": "Turkish",
196
+ "ukr": "Ukrainian",
197
+ "urd": "Urdu",
198
+ "uzn": "Northern Uzbek",
199
+ "vie": "Vietnamese",
200
+ "xho": "Xhosa",
201
+ "yor": "Yoruba",
202
+ "yue": "Cantonese",
203
+ "zlm": "Colloquial Malay",
204
+ "zsm": "Standard Malay",
205
+ "zul": "Zulu",
206
+ }
whl/seamless_communication-1.0.0-py3-none-any.whl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1df10e0c85ee0ffbc9f2e1bf8896850a52c551383df0332a94d26d9d39770c85
3
+ size 201552