File size: 3,922 Bytes
10e72d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import requests
import tarfile
import gdown
import shutil
import subprocess
from pathlib import Path
from config import BASE_DIR

def run_command(command):
    """Run *command*, echo its captured output, and report success.

    Args:
        command: Either a command line (``str``/``bytes``), which is run
            through the shell exactly as before, or a sequence of program
            arguments, which is executed directly with ``shell=False`` so the
            arguments need no shell quoting or escaping.

    Returns:
        True if the command exited with status 0; False if it exited
        non-zero or could not be started. Errors are printed, not raised.
    """
    # Strings keep the historical shell behaviour; argument lists skip the
    # shell entirely, which avoids quoting/injection problems.
    use_shell = isinstance(command, (str, bytes))
    display = command if use_shell else " ".join(map(str, command))
    print(f"Running: {display}")
    try:
        result = subprocess.run(command, shell=use_shell, check=True, text=True,
                                capture_output=True)
    except subprocess.CalledProcessError as e:
        print(f"Error executing command: {e}")
        print(f"STDERR: {e.stderr}")
        return False
    except OSError as e:
        # Raised when the process cannot be started at all, e.g. the program
        # in an argument list does not exist.
        print(f"Error executing command: {e}")
        return False
    print(result.stdout)
    if result.stderr:
        print(f"STDERR: {result.stderr}")
    return True


def download_file(url, destination, timeout=60):
    """Stream *url* over HTTP(S) into the file *destination*.

    The body is first written to a sibling ``<name>.part`` file and moved onto
    *destination* only after the whole response has been received. An
    interrupted or failed download therefore never leaves a truncated file
    at *destination*. This matters because the setup code in this module
    skips a download whenever the destination file already exists.

    Args:
        url: Address to fetch.
        destination: Target file path (``str`` or path-like).
        timeout: Seconds to wait for the connection and between received
            bytes (passed to ``requests.get``); ``None`` waits forever.

    Returns:
        True on success, False if the request or the write failed (the error
        is printed, not raised).
    """
    print(f"Downloading {url} to {destination}")
    destination = Path(destination)
    partial = destination.with_name(destination.name + ".part")
    try:
        # Using the response as a context manager releases the streamed
        # connection even if writing fails midway.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: destination is either absent or complete.
        os.replace(partial, destination)
        return True
    except Exception as e:  # best-effort helper: report and signal via return value
        print(f"Error downloading file: {e}")
        try:
            partial.unlink()
        except OSError:
            pass  # nothing was written, or it is already gone
        return False


def _download_if_missing(url, destination, label):
    """Download *url* to *destination* unless that file already exists.

    Returns True when the file is present afterwards. On failure, any file
    left at *destination* is removed so the next run retries instead of
    trusting a truncated download, and a warning naming *label* is printed.
    """
    destination = Path(destination)
    if destination.exists():
        return True
    destination.parent.mkdir(parents=True, exist_ok=True)
    if download_file(url, destination):
        return True
    try:
        destination.unlink()
    except FileNotFoundError:
        pass  # nothing was written
    except OSError as exc:
        print(f"Warning: could not remove partial file {destination}: {exc}")
    print(f"Warning: Could not download {label} from {url}")
    return False


def _safe_extractall(tar, path):
    """Extract the open archive *tar* into *path*, refusing members that escape it."""
    if hasattr(tarfile, "data_filter"):
        # Python 3.12+ and security backports: the "data" filter rejects
        # absolute paths, "..", links pointing outside *path*, device files, etc.
        tar.extractall(path=path, filter="data")
        return
    # Older interpreters: at least reject member names resolving outside *path*
    # (link targets are not checked on this fallback path).
    root = os.path.realpath(path)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(root, member.name))
        if os.path.commonpath([root, target]) != root:
            raise tarfile.TarError(
                f"Refusing to extract {member.name!r} outside {root}")
    tar.extractall(path=path)


def _install_hifigan_vocoder(base_dir, model_dir):
    """Fetch the VCTK HiFi-GAN archive from Google Drive and install it.

    The checkpoint is moved to ``model_dir / "vocoder_HiFiGAN.pkl"`` and its
    ``config.yml`` to ``model_dir / "config.yml"``. The downloaded archive and
    the extraction directory are removed whether or not this succeeds.
    Returns True on success, False (after printing a warning) on any error.
    """
    archive = base_dir / "train_nodev_all_vctk_hifigan.v1.tar.gz"
    extracted_dir = base_dir / "train_nodev_all_vctk_hifigan.v1"
    url = "https://drive.google.com/uc?id=1oVOC4Vf0DYLdDp4r7GChfgj7Xh5xd0ex"

    try:
        print("Downloading HiFiGAN model from Google Drive...")
        gdown.download(url, str(archive), quiet=False)

        print(f"Extracting {archive}...")
        with tarfile.open(archive, 'r:gz') as tar:
            _safe_extractall(tar, base_dir)

        shutil.move(extracted_dir / "checkpoint-2500000steps.pkl",
                    model_dir / "vocoder_HiFiGAN.pkl")
        shutil.move(extracted_dir / "config.yml", model_dir / "config.yml")
    except Exception as e:  # gdown, tarfile and shutil raise many unrelated types
        print(f"Warning: Could not download vocoder: {e}")
        return False
    finally:
        # Don't leave a multi-hundred-MB tarball or half-extracted tree behind.
        shutil.rmtree(extracted_dir, ignore_errors=True)
        try:
            os.remove(archive)
        except OSError:
            pass  # download never produced the file, or it is already gone
    return True


def setup_environment():
    """Download and lay out the model files needed for Persian TTS.

    Under ``BASE_DIR`` this creates ``saved_models/final_models`` and
    ``results``. It then ensures the following exist, downloading whatever
    is missing:

    * the speaker encoder, cached at ``pmt2/saved_models/default/encoder.pt``
      and copied into ``final_models``;
    * the HiFi-GAN vocoder (``vocoder_HiFiGAN.pkl`` plus ``config.yml``);
    * the Tacotron2 synthesizer (``synthesizer.pt``);
    * a reference ``sample.wav``.

    Returns:
        True when every artifact is in place, False at the first failure
        (a warning is printed).
    """
    print("Setting up the environment for Persian TTS...")

    base_dir = Path(BASE_DIR)
    model_dir = base_dir / "saved_models" / "final_models"
    results_dir = base_dir / "results"
    model_dir.mkdir(parents=True, exist_ok=True)
    results_dir.mkdir(parents=True, exist_ok=True)

    encoder_url = "https://github.com/MahtaFetrat/Persian-MultiSpeaker-Tacotron2/raw/refs/heads/master/saved_models/default/encoder.pt"
    synth_url = "https://huggingface.co/MahtaFetrat/Persian-Tacotron2-on-ManaTTS/resolve/main/synthesizer.pt"
    sample_url = "https://huggingface.co/MahtaFetrat/Persian-Tacotron2-on-ManaTTS/resolve/main/sample.wav"

    # The encoder is cached inside the pmt2 checkout and copied on every run.
    encoder_file = base_dir / "pmt2" / "saved_models" / "default" / "encoder.pt"
    if not _download_if_missing(encoder_url, encoder_file, "encoder model"):
        return False
    shutil.copy(encoder_file, model_dir / "encoder.pt")

    if not (model_dir / "vocoder_HiFiGAN.pkl").exists():
        if not _install_hifigan_vocoder(base_dir, model_dir):
            return False

    if not _download_if_missing(synth_url, model_dir / "synthesizer.pt",
                                "synthesizer model"):
        return False

    if not _download_if_missing(sample_url, base_dir / "sample.wav",
                                "sample audio"):
        return False

    print("Setup complete!")
    return True