File size: 1,335 Bytes
d5ee97c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
//
//  FastSpeech2.swift
//  HelloTensorFlowTTS
//
//  Created by 안창범 on 2021/03/09.
//

import Foundation
import TensorFlowLite

/// Wraps a TensorFlow Lite FastSpeech2 model and runs inference on token-id sequences.
///
/// The scalar controls `speakerId`, `f0Ratio` and `energyRatio` are read on every call to
/// `getMelSpectrogram(inputIds:speedRatio:)` and copied into the model's scalar inputs.
class FastSpeech2 {
    /// The TensorFlow Lite interpreter holding the loaded model.
    let interpreter: Interpreter

    /// Speaker identifier copied into model input 1 on each inference.
    var speakerId: Int32 = 0

    /// Ratio copied into model input 3 on each inference (presumably pitch/F0 scaling — TODO confirm).
    var f0Ratio: Float = 1

    /// Ratio copied into model input 4 on each inference (presumably energy scaling — TODO confirm).
    var energyRatio: Float = 1

    /// Loads the model stored at `url`.
    ///
    /// - Parameters:
    ///   - url: File URL of the TensorFlow Lite model.
    ///   - threadCount: Number of threads the interpreter may use. Defaults to 5, the value
    ///     this initializer previously hard-coded.
    /// - Throws: Any error raised by `Interpreter` while loading the model.
    init(url: URL, threadCount: Int = 5) throws {
        var options = Interpreter.Options()
        options.threadCount = threadCount
        interpreter = try Interpreter(modelPath: url.path, options: options)
    }

    /// Runs the model on `inputIds` and returns output tensor 1.
    ///
    /// Inputs written before invocation:
    /// 0 = `inputIds`, 1 = `speakerId`, 2 = `speedRatio`, 3 = `f0Ratio`, 4 = `energyRatio`.
    ///
    /// - Parameters:
    ///   - inputIds: Token ids; input 0 is resized to `[1, inputIds.count]` to hold them.
    ///   - speedRatio: Scalar copied into model input 2.
    /// - Returns: The model's output tensor at index 1 (presumably the mel spectrogram —
    ///   TODO confirm against the exported model's output order).
    /// - Throws: Any error raised by the interpreter while resizing, allocating, copying or invoking.
    func getMelSpectrogram(inputIds: [Int32], speedRatio: Float) throws -> Tensor {
        try interpreter.resizeInput(at: 0, to: [1, inputIds.count])
        // Tensors must be (re)allocated after a resize and before any input data is copied in.
        try interpreter.allocateTensors()

        try interpreter.copy(inputIds.withUnsafeBufferPointer(Data.init), toInputAt: 0)
        try copyScalar(speakerId, toInputAt: 1)
        try copyScalar(speedRatio, toInputAt: 2)
        try copyScalar(f0Ratio, toInputAt: 3)
        try copyScalar(energyRatio, toInputAt: 4)

        // systemUptime is monotonic, so the timing is not skewed by wall-clock adjustments.
        let start = ProcessInfo.processInfo.systemUptime
        try interpreter.invoke()
        print("fastspeech2: \(ProcessInfo.processInfo.systemUptime - start)s")

        return try interpreter.output(at: 1)
    }

    /// Copies the raw bytes of a single scalar value into the input tensor at `index`.
    ///
    /// The byte count is taken from the value's own type rather than a hard-coded size,
    /// and no inout pointer to a stored property is needed.
    private func copyScalar<T>(_ value: T, toInputAt index: Int) throws {
        let bytes = withUnsafeBytes(of: value) { Data($0) }
        try interpreter.copy(bytes, toInputAt: index)
    }
}