Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions TTS/.models.json
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@
}
}
},
"zh":{
"zh-CN":{
"baker":{
"tacotron2-DDC-GST":{
"model_file": "1RR9rZdV_FMm8yvtCHALtUbJf1nxbUiAw",
"config_file": "1daY1JHGXEozJ-MGYLiWEUmzEwEvM5xpz",
"stats_file": "1vl9c-D3dW_E7pdhNpDFQLX-giJc0jOtV",
"model_file": "1SYpv7V__QYDjKXa_vJmNXo1CSkcoZovy",
"config_file": "14BIvfJXnFHi3jcxYNX40__TR6RwJOZqi",
"stats_file": "1ECRlXybT6rAWp269CkhjUPwcZ10CkcqD",
"commit": ""
}
}
Expand Down
18 changes: 16 additions & 2 deletions TTS/server/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!flask/bin/python
import argparse
import json
import os
import sys
import io
Expand All @@ -9,6 +10,7 @@
from TTS.utils.synthesizer import Synthesizer
from TTS.utils.manage import ModelManager
from TTS.utils.io import load_config
from TTS.utils.generic_utils import style_wav_uri_to_dict


def create_argparser():
Expand Down Expand Up @@ -75,14 +77,22 @@ def convert_boolean(x):
if not args.vocoder_config and os.path.isfile(vocoder_config_file):
args.vocoder_config = vocoder_config_file


synthesizer = Synthesizer(args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda)

use_speaker_embedding = synthesizer.tts_config.get("use_external_speaker_embedding_file", False)
use_gst = synthesizer.tts_config.get("use_gst", False)
app = Flask(__name__)


@app.route('/')
def index():
    """Render the landing page.

    Passes the flags that ``index.html`` checks to decide which optional
    controls to show: the details link (``show_details``), the speaker
    key input (``use_speaker_embedding``) and the style-wav input
    (``use_gst``).
    """
    return render_template(
        'index.html',
        show_details=args.show_details,
        use_speaker_embedding=use_speaker_embedding,
        use_gst=use_gst,
    )

@app.route('/details')
def details():
Expand All @@ -102,8 +112,12 @@ def details():
@app.route('/api/tts', methods=['GET'])
def tts():
text = request.args.get('text')
speaker_json_key = request.args.get('speaker', "")
style_wav = request.args.get('style-wav', "")

style_wav = style_wav_uri_to_dict(style_wav)
print(" > Model input: {}".format(text))
wavs = synthesizer.tts(text)
wavs = synthesizer.tts(text, speaker_json_key=speaker_json_key, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
return send_file(out, mimetype='audio/wav')
Expand Down
25 changes: 21 additions & 4 deletions TTS/server/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@

<ul class="list-unstyled">
</ul>
{%if use_speaker_embedding%}
<input id="speaker-json-key" placeholder="speaker json key.." size=45 type="text" name="speaker-json-key">
{%endif%}

{%if use_gst%}
<input value="{'0': 0.1}" id="style-wav" placeholder="style wav (dict or path to wav).." size=45 type="text" name="style-wav">
{%endif%}

<input id="text" placeholder="Type here..." size=45 type="text" name="text">
<button id="speak-button" name="speak">Speak</button><br/><br/>
{%if show_details%}
Expand All @@ -73,15 +81,24 @@

<!-- Bootstrap core JavaScript -->
<script>
/**
 * Return the current value of the element matched by `textId`,
 * or an empty string when no element matches the selector.
 */
function getTextValue(textId) {
    const element = q(textId)
    return element ? element.value : ""
}
// Shorthand for document.querySelector.
function q(selector) {
    return document.querySelector(selector)
}
q('#text').focus()
function do_tts(e) {
text = q('#text').value
const text = q('#text').value
const speakerJsonKey = getTextValue('#speaker-json-key')
const styleWav = getTextValue('#style-wav')
if (text) {
q('#message').textContent = 'Synthesizing...'
q('#speak-button').disabled = true
q('#audio').hidden = true
synthesize(text)
synthesize(text, speakerJsonKey, styleWav)
}
e.preventDefault()
return false
Expand All @@ -92,8 +109,8 @@
do_tts(e)
}
})
function synthesize(text) {
fetch('/api/tts?text=' + encodeURIComponent(text), {cache: 'no-cache'})
function synthesize(text, speakerJsonKey="", styleWav="") {
fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker=${encodeURIComponent(speakerJsonKey)}&style-wav=${encodeURIComponent(styleWav)}` , {cache: 'no-cache'})
.then(function(res) {
if (!res.ok) throw Error(res.statusText)
return res.blob()
Expand Down
6 changes: 4 additions & 2 deletions TTS/tts/models/glow_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from TTS.tts.layers.glow_tts.decoder import Decoder
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
from TTS.tts.models.tts_abstract import TTSAbstract


class GlowTts(nn.Module):
class GlowTts(TTSAbstract):
"""Glow TTS models from https://arxiv.org/abs/2005.11129

Args:
Expand Down Expand Up @@ -179,7 +180,8 @@ def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur

@torch.no_grad()
def inference(self, x, x_lengths, g=None):
def inference(self, x, x_lengths, g=None, *args, **kwargs): # pylint: disable=unused-argument,keyword-arg-before-vararg

if g is not None:
if self.external_speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1)
Expand Down
51 changes: 26 additions & 25 deletions TTS/tts/models/speedy_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from TTS.tts.layers.speedy_speech.encoder import Encoder, PositionalEncoding
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.monotonic_align import generate_path
from TTS.tts.models.tts_abstract import TTSAbstract


class SpeedySpeech(nn.Module):
class SpeedySpeech(TTSAbstract):
"""Speedy Speech model
https://arxiv.org/abs/2008.03802

Expand Down Expand Up @@ -36,29 +37,29 @@ class SpeedySpeech(nn.Module):
# pylint: disable=dangerous-default-value

def __init__(
self,
num_chars,
out_channels,
hidden_channels,
positional_encoding=True,
length_scale=1,
encoder_type='residual_conv_bn',
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
decoder_type='residual_conv_bn',
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17
},
num_speakers=0,
external_c=False,
c_in_channels=0):
self,
num_chars,
out_channels,
hidden_channels,
positional_encoding=True,
length_scale=1,
encoder_type='residual_conv_bn',
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
decoder_type='residual_conv_bn',
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17
},
num_speakers=0,
external_c=False,
c_in_channels=0):

super().__init__()
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
Expand Down Expand Up @@ -174,7 +175,7 @@ def forward(self, x, x_lengths, y_lengths, dr, g=None): # pylint: disable=unuse
o_de, attn= self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
return o_de, o_dr_log.squeeze(1), attn

def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument
def inference(self, x, x_lengths, g=None, *args, **kwargs): # pylint: disable=unused-argument,keyword-arg-before-vararg
"""
Shapes:
x: [B, T_max]
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/tacotron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_
return decoder_outputs, postnet_outputs, alignments, stop_tokens

@torch.no_grad()
def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
def inference(self, text, *args, speaker_ids=None, style_mel=None, speaker_embeddings=None, **kwargs):
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)

Expand Down
17 changes: 5 additions & 12 deletions TTS/tts/models/tacotron_abstract.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import copy
from abc import ABC, abstractmethod

import torch
from torch import nn

from TTS.tts.utils.generic_utils import sequence_mask
from TTS.utils.io import AttrDict
from TTS.tts.models.tts_abstract import TTSAbstract


class TacotronAbstract(ABC, nn.Module):
class TacotronAbstract(TTSAbstract):
def __init__(self,
num_chars,
num_speakers,
Expand Down Expand Up @@ -71,6 +70,7 @@ def __init__(self,
self.encoder = None
self.decoder = None
self.postnet = None


# multispeaker
if self.speaker_embedding_dim is None:
Expand Down Expand Up @@ -113,15 +113,8 @@ def _init_coarse_decoder(self):
# CORE FUNCTIONS
#############################

@abstractmethod
def forward(self):
pass

@abstractmethod
def inference(self):
pass

def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, config: AttrDict, checkpoint_path: str, eval: bool = False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
self.decoder.set_r(state['r'])
Expand Down
26 changes: 26 additions & 0 deletions TTS/tts/models/tts_abstract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from TTS.utils.io import AttrDict
from torch import nn
from abc import ABC, abstractmethod


class TTSAbstract(ABC, nn.Module):
    """Abstract base for TTS models (tacotron, speedy_speech, glow_tts ...).

    Inheritance:
        ABC: Abstract Base Class
        nn.Module: pytorch nn.Module

    Subclasses must override ``forward``, ``inference`` and
    ``load_checkpoint``; a subclass that leaves any of them abstract
    cannot be instantiated.
    """

    @abstractmethod
    def forward(self):
        """Forward pass; must be provided by each concrete model."""

    @abstractmethod
    def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
        """Inference entry point; must be provided by each concrete model."""

    @abstractmethod
    def load_checkpoint(self, config: AttrDict, checkpoint_path: str, eval: bool = False):
        """Checkpoint loading hook; must be provided by each concrete model."""


13 changes: 8 additions & 5 deletions TTS/tts/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import Counter

from TTS.utils.generic_utils import check_argument
from TTS.tts.models.tts_abstract import TTSAbstract


def split_dataset(items):
Expand Down Expand Up @@ -44,12 +45,12 @@ def to_camel(text):
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)


def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None) -> TTSAbstract:
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
MyModel = getattr(MyModel, to_camel(c.model))
if c.model.lower() in "tacotron":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
Expand All @@ -76,7 +77,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
ddc_r=c.ddc_r,
speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "tacotron2":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio['num_mels'],
Expand All @@ -102,7 +103,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
ddc_r=c.ddc_r,
speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "glow_tts":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
hidden_channels_enc=c['hidden_channels_encoder'],
hidden_channels_dec=c['hidden_channels_decoder'],
hidden_channels_dp=c['hidden_channels_duration_predictor'],
Expand All @@ -123,7 +124,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
mean_only=True,
external_speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "speedy_speech":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio['num_mels'],
hidden_channels=c['hidden_channels'],
positional_encoding=c['positional_encoding'],
Expand All @@ -132,6 +133,8 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
decoder_type=c['decoder_type'],
decoder_params=c['decoder_params'],
c_in_channels=0)
else:
return BaseException("Model type is not allowed : ", c.model.lower())
return model

def is_tacotron(c):
Expand Down
23 changes: 13 additions & 10 deletions TTS/tts/utils/speakers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,25 @@
import json


def make_speakers_json_path(out_path: str) -> str:
    """Return the conventional ``speakers.json`` location inside ``out_path``."""
    return os.path.join(out_path, "speakers.json")


def load_speaker_mapping(out_path: str) -> dict:
    """Load the speaker mapping if it is already present.

    Args:
        out_path: Either a path ending in ``.json`` (used as-is) or a
            directory, in which case the conventional ``speakers.json``
            inside it is used.

    Returns:
        The parsed JSON content, or an empty dict when the file does not
        exist (a notice is printed in that case).
    """
    speakers_json = {}
    if os.path.splitext(out_path)[1] == '.json':
        json_file = out_path
    else:
        json_file = make_speakers_json_path(out_path)

    if os.path.isfile(json_file):
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(json_file, encoding="utf-8") as f:
            speakers_json = json.load(f)
    else:
        print(f"speaker json file was not found in path '{out_path}'")
    return speakers_json

def save_speaker_mapping(out_path, speaker_mapping):
"""Saves speaker mapping if not yet present."""
Expand Down
Loading