Merge pull request #17 from xingchensong/Mddct-pre-commit
pre-commit for clean and tidy code
xingchensong authored Dec 19, 2024
2 parents a15a6e5 + 21b863e commit c14c2aa
Showing 8 changed files with 232 additions and 145 deletions.
14 changes: 14 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,14 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: trailing-whitespace
+        exclude: 's3tokenizer/assets/.*'
+  - repo: https://github.com/pre-commit/mirrors-yapf
+    rev: 'v0.32.0'
+    hooks:
+      - id: yapf
+  - repo: https://github.com/pycqa/flake8
+    rev: '3.8.2'
+    hooks:
+      - id: flake8
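To activate these hooks in a local clone, the standard pre-commit workflow (not part of this diff) is:

    pip install pre-commit
    pre-commit install          # run the hooks automatically on each git commit
    pre-commit run --all-files  # one-off check of the whole tree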
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
+pre-commit
 numpy
 torch
 onnx
77 changes: 41 additions & 36 deletions s3tokenizer/__init__.py
@@ -12,39 +12,40 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Modified from https://github.com/openai/whisper/blob/main/whisper/__init__.py
+"""Modified from
+https://github.com/openai/whisper/blob/main/whisper/__init__.py
 """


 import hashlib
 import io
 import os
 import urllib
 import warnings
-from typing import List, Optional, Union
+from typing import List, Union

 import torch
 from tqdm import tqdm

 from .model import S3Tokenizer
-from .utils import (
-    onnx2torch,
-    make_non_pad_mask,
-    mask_to_bias,
-    log_mel_spectrogram,
-    load_audio,
-    padding
-)
-
+from .utils import (load_audio, log_mel_spectrogram, make_non_pad_mask,
+                    mask_to_bias, onnx2torch, padding)
+
 __all__ = [
     'load_audio', 'log_mel_spectrogram', 'make_non_pad_mask', 'mask_to_bias',
     'onnx2torch', 'padding'
 ]
 _MODELS = {
-    "speech_tokenizer_v1": "https://www.modelscope.cn/models/iic/cosyvoice-300m/resolve/master/speech_tokenizer_v1.onnx",
-    "speech_tokenizer_v1_25hz": "https://www.modelscope.cn/models/iic/CosyVoice-300M-25Hz/resolve/master/speech_tokenizer_v1.onnx",
+    "speech_tokenizer_v1":
+    "https://www.modelscope.cn/models/iic/cosyvoice-300m/"
+    "resolve/master/speech_tokenizer_v1.onnx",
+    "speech_tokenizer_v1_25hz":
+    "https://www.modelscope.cn/models/iic/CosyVoice-300M-25Hz/"
+    "resolve/master/speech_tokenizer_v1.onnx",
 }

 _SHA256S = {
-    "speech_tokenizer_v1": "23b5a723ed9143aebfd9ffda14ac4c21231f31c35ef837b6a13bb9e5488abb1e",
-    "speech_tokenizer_v1_25hz": "56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486",
+    "speech_tokenizer_v1":
+    "23b5a723ed9143aebfd9ffda14ac4c21231f31c35ef837b6a13bb9e5488abb1e",
+    "speech_tokenizer_v1_25hz":
+    "56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486",
 }
@@ -56,7 +57,8 @@ def _download(name: str, root: str) -> Union[bytes, str]:
     download_target = os.path.join(root, f"{name}.onnx")

     if os.path.exists(download_target) and not os.path.isfile(download_target):
-        raise RuntimeError(f"{download_target} exists and is not a regular file")
+        raise RuntimeError(
+            f"{download_target} exists and is not a regular file")

     if os.path.isfile(download_target):
         with open(download_target, "rb") as f:
@@ -65,17 +67,18 @@ def _download(name: str, root: str) -> Union[bytes, str]:
             return download_target
         else:
             warnings.warn(
-                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
-            )
+                f"{download_target} exists, but the SHA256 checksum does not"
+                " match; re-downloading the file")

-    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+    with urllib.request.urlopen(url) as source, open(download_target,
+                                                     "wb") as output:
         with tqdm(
-            total=int(source.info().get("Content-Length")),
-            ncols=80,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-            desc="Downloading onnx checkpoint",
+                total=int(source.info().get("Content-Length")),
+                ncols=80,
+                unit="iB",
+                unit_scale=True,
+                unit_divisor=1024,
+                desc="Downloading onnx checkpoint",
         ) as loop:
             while True:
                 buffer = source.read(8192)
@@ -88,8 +91,8 @@ def _download(name: str, root: str) -> Union[bytes, str]:
     model_bytes = open(download_target, "rb").read()
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
         raise RuntimeError(
-            "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
-        )
+            "Model has been downloaded but the SHA256 checksum does not not"
+            " match. Please retry loading the model.")

     return download_target
@@ -109,10 +112,12 @@ def load_model(
     Parameters
     ----------
     name : str
-        one of the official model names listed by `s3tokenizer.available_models()`, or
-        path to a model checkpoint containing the model dimensions and the model state_dict.
+        one of the official model names listed by
+        `s3tokenizer.available_models()`, or path to a model checkpoint
+        containing the model dimensions and the model state_dict.
     download_root: str
-        path to download the model files; by default, it uses "~/.cache/s3tokenizer"
+        path to download the model files; by default,
+        it uses "~/.cache/s3tokenizer"

     Returns
     -------
@@ -122,16 +127,16 @@ def load_model(

     if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
-        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "s3tokenizer")
+        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default),
+                                     "s3tokenizer")

     if name in _MODELS:
         checkpoint_file = _download(name, download_root)
     elif os.path.isfile(name):
         checkpoint_file = name
     else:
         raise RuntimeError(
-            f"Model {name} not found; available models = {available_models()}"
-        )
+            f"Model {name} not found; available models = {available_models()}")

     model = S3Tokenizer(name)
     model.init_from_onnx(checkpoint_file)
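Taken together, the exports above support an offline tokenization flow along these lines. This is a minimal sketch; the `quantize` call and its `(codes, codes_lens)` return are assumptions inferred from how cli.py consumes the model below, not something this diff defines:

    import torch

    import s3tokenizer

    # First use downloads the onnx checkpoint and verifies its SHA256.
    tokenizer = s3tokenizer.load_model("speech_tokenizer_v1_25hz")

    wavs = ["utt1.wav", "utt2.wav"]  # hypothetical 16 kHz audio paths
    mels = [s3tokenizer.log_mel_spectrogram(s3tokenizer.load_audio(p))
            for p in wavs]
    mels, mels_lens = s3tokenizer.padding(mels)  # batch to a common length

    with torch.no_grad():
        codes, codes_lens = tokenizer.quantize(mels, mels_lens)  # assumed API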
73 changes: 55 additions & 18 deletions s3tokenizer/cli.py
@@ -30,19 +30,20 @@
 """


-import os
-import json
 import argparse
+import json
+import os

 import torch
 import torch.distributed as dist
-from torch.utils.data import Dataset, DataLoader, DistributedSampler
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
 from tqdm import tqdm

 import s3tokenizer


 class AudioDataset(Dataset):
+
     def __init__(self, wav_scp):
         self.data = []
         self.keys = []
@@ -61,7 +62,9 @@ def __getitem__(self, idx):
         key = self.keys[idx]
         audio = s3tokenizer.load_audio(file_path)
         if audio.shape[0] / 16000 > 30:
-            print(f'do not support extract speech token for audio longer than 30s, file_path: {file_path}')
+            print(
+                f'do not support extract speech token for audio longer than 30s, file_path: {file_path}'  # noqa
+            )
             mel = torch.zeros(128, 0)
         else:
             mel = s3tokenizer.log_mel_spectrogram(audio)
@@ -88,13 +91,37 @@ def init_distributed():

 def get_args():
     parser = argparse.ArgumentParser(description='extract speech code')
-    parser.add_argument('--model', required=True, type=str, choices=["speech_tokenizer_v1", "speech_tokenizer_v1_25hz"], help='model version')
-    parser.add_argument('--wav_scp', required=True, type=str, help='each line contains `wav_name wav_path`')
-    parser.add_argument('--device', required=True, type=str, choices=["cuda", "cpu"], help='device for inference')
-    parser.add_argument('--output_dir', required=True, type=str, help='dir to save result')
-    parser.add_argument('--batch_size', required=True, type=int, help='batch size (per-device) for inference')
-    parser.add_argument('--num_workers', type=int, default=4, help='workers for dataloader')
-    parser.add_argument('--prefetch', type=int, default=5, help='prefetch for dataloader')
+    parser.add_argument(
+        '--model',
+        required=True,
+        type=str,
+        choices=["speech_tokenizer_v1", "speech_tokenizer_v1_25hz"],
+        help='model version')
+    parser.add_argument('--wav_scp',
+                        required=True,
+                        type=str,
+                        help='each line contains `wav_name wav_path`')
+    parser.add_argument('--device',
+                        required=True,
+                        type=str,
+                        choices=["cuda", "cpu"],
+                        help='device for inference')
+    parser.add_argument('--output_dir',
+                        required=True,
+                        type=str,
+                        help='dir to save result')
+    parser.add_argument('--batch_size',
+                        required=True,
+                        type=int,
+                        help='batch size (per-device) for inference')
+    parser.add_argument('--num_workers',
+                        type=int,
+                        default=4,
+                        help='workers for dataloader')
+    parser.add_argument('--prefetch',
+                        type=int,
+                        default=5,
+                        help='prefetch for dataloader')
     args = parser.parse_args()
     return args
@@ -114,14 +141,21 @@ def main():
     dataset = AudioDataset(args.wav_scp)

     if args.device == "cuda":
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
-        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
+        model = torch.nn.parallel.DistributedDataParallel(
+            model, device_ids=[local_rank])
+        sampler = DistributedSampler(dataset,
+                                     num_replicas=world_size,
+                                     rank=rank)
     else:
         sampler = None

-    dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler,
-                            shuffle=False, num_workers=args.num_workers,
-                            prefetch_factor=args.prefetch, collate_fn=collate_fn)
+    dataloader = DataLoader(dataset,
+                            batch_size=args.batch_size,
+                            sampler=sampler,
+                            shuffle=False,
+                            num_workers=args.num_workers,
+                            prefetch_factor=args.prefetch,
+                            collate_fn=collate_fn)

     total_steps = len(dataset)

@@ -134,7 +168,10 @@
         for i, k in enumerate(keys):
             code = codes[i, :codes_lens[i].item()].tolist()
             writer.write(
-                json.dumps({"key": k, "code": code}, ensure_ascii=False) + "\n")
+                json.dumps({
+                    "key": k,
+                    "code": code
+                }, ensure_ascii=False) + "\n")
         if rank == 0:
             progress_bar.update(world_size * len(keys))
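A typical multi-GPU run of this CLI might look like the following; the torchrun launcher and the script path are assumptions (only the flags are defined in this diff), and each line of wav.scp holds `wav_name wav_path` as described by the --wav_scp help string:

    torchrun --nproc_per_node=2 s3tokenizer/cli.py \
        --model speech_tokenizer_v1_25hz \
        --wav_scp wav.scp \
        --device cuda \
        --output_dir ./codes \
        --batch_size 32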