Feature/use huggingface compatible pretokenizer #38

Merged: 11 commits, Jan 27, 2022
26 changes: 19 additions & 7 deletions pretraining/bert/train_wordpiece_tokenizer.py
@@ -16,8 +16,11 @@
import os
from glob import glob

from sudachitra.pretokenizer import JapaneseBertWordPieceTokenizer, SudachipyPreTokenizer
from sudachitra import get_split_mode
from sudachitra.pretokenizer import JapaneseBertWordPieceTokenizer
from sudachitra.word_formatter import WordFormTypes
from sudachitra.pretokenizer import pretokenizer_handler
from sudachipy import Dictionary


def main():
@@ -36,14 +39,17 @@ def main():
limit_alphabet=args.limit_alphabet
)

wp_tokenizer = JapaneseBertWordPieceTokenizer(do_lower_case=args.do_lower_case, do_nfkc=args.do_nfkc)
wp_tokenizer = JapaneseBertWordPieceTokenizer(do_strip=args.do_strip,
do_lower_case=args.do_lower_case,
do_nfkc=args.do_nfkc,
disable_parallel=args.disable_parallel)

sudachi_pre_tokenizer = SudachipyPreTokenizer(
split_mode=args.split_mode,
dict_type=args.dict_type,
word_form_type=args.word_form_type
sudachi_dict = Dictionary(dict=args.dict_type)
sudachi_pre_tokenizer = sudachi_dict.pre_tokenizer(
mode=get_split_mode(args.split_mode),
handler=pretokenizer_handler(sudachi_dict, word_form_type=args.word_form_type)
)
wp_tokenizer.set_pre_tokenizer(sudachi_pre_tokenizer)
wp_tokenizer.pre_tokenizer = sudachi_pre_tokenizer

wp_tokenizer.train(files, **settings)

@@ -62,6 +68,8 @@ def get_args():
help='Input directory containing files to train tokenizer.')

# Normalizers
parser.add_argument('--do_strip', action='store_true', default=False,
help='Removes all whitespace characters on both sides of the input.')
parser.add_argument('--do_lower_case', action='store_true', default=False,
help='Converts all uppercase characters to lowercase.')
parser.add_argument('--do_nfkc', action='store_true', default=False,
Expand All @@ -84,6 +92,10 @@ def get_args():
choices=WordFormTypes, type=WordFormTypes,
help='Word form type for each morpheme.')

# Wordpiece
parser.add_argument('--disable_parallel', action='store_true', default=False,
help='Disable parallel tokenization.')

# Output
parser.add_argument('-o', '--output_dir',
help='The output directory where the vocabulary and config file will be saved.')
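For context, a minimal sketch of how the rewritten script wires things together after this change: the SudachiPy Dictionary now produces the HuggingFace-compatible pre-tokenizer directly, and pretokenizer_handler supplies the word-form conversion. The dictionary type, corpus path, and vocab size below are illustrative placeholders, not values taken from the patch.

```python
from sudachipy import Dictionary

from sudachitra import get_split_mode
from sudachitra.pretokenizer import JapaneseBertWordPieceTokenizer, pretokenizer_handler

# WordPiece tokenizer with the new normalizer and parallelism options.
wp_tokenizer = JapaneseBertWordPieceTokenizer(
    do_strip=True,            # new: adds tokenizers.normalizers.Strip to the chain
    do_lower_case=False,
    do_nfkc=True,
    disable_parallel=False,   # new: TOKENIZERS_PARALLELISM is only disabled when True
)

# The Sudachi dictionary builds the HuggingFace-compatible pre-tokenizer itself.
sudachi_dict = Dictionary(dict="core")  # assumes the core dictionary package is installed
wp_tokenizer.pre_tokenizer = sudachi_dict.pre_tokenizer(
    mode=get_split_mode("C"),
    handler=pretokenizer_handler(sudachi_dict, word_form_type="surface"),
)

# Training proceeds as in the script; the corpus path and vocab size are placeholders.
wp_tokenizer.train(["corpus/train.txt"], vocab_size=32000)
```

Compared to the previous SudachipyPreTokenizer wrapper, no PreTokenizer.custom call is needed: the object returned by Dictionary.pre_tokenizer can be assigned as-is.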
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
logzero~=1.7.0
progressbar2~=3.53.1
pytextspan>=0.5.4
tokenizers>=0.10.3
transformers>=4.6.1
sudachipy>=0.6.2
1 change: 1 addition & 0 deletions setup.py
@@ -29,6 +29,7 @@
install_requires=[
"logzero~=1.7.0",
"progressbar2~=3.53.1",
"pytextspan>=0.5.4",
"tokenizers>=0.10.3",
"transformers>=4.6.1",
"sudachipy>=0.6.2",
2 changes: 1 addition & 1 deletion sudachitra/__init__.py
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .sudachipy_word_tokenizer import SudachipyWordTokenizer
from .sudachipy_word_tokenizer import SudachipyWordTokenizer, get_split_mode
from .tokenization_bert_sudachipy import BertSudachipyTokenizer
from .tokenization_electra_sudachipy import ElectraSudachipyTokenizer
7 changes: 5 additions & 2 deletions sudachitra/input_string_normalizer.py
@@ -12,17 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from tokenizers.normalizers import Lowercase, NFKC, Sequence
from tokenizers.normalizers import Lowercase, NFKC, Sequence, Strip


class InputStringNormalizer(object):
def __init__(self, do_lower_case=False, do_nfkc=False):
def __init__(self, do_strip=False, do_lower_case=False, do_nfkc=False):
self.do_strip: bool = do_strip
self.do_lower_case: bool = do_lower_case
self.do_nfkc: bool = do_nfkc
self._normalizer: Sequence = self._init_normalizer()

def _init_normalizer(self) -> Sequence:
normalizers = []
if self.do_strip:
normalizers.append(Strip())
if self.do_lower_case:
normalizers.append(Lowercase())
if self.do_nfkc:
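A small sketch of what the new do_strip flag adds to the normalizer chain, using the normalizer property that JapaneseBertWordPieceTokenizer consumes; the sample string and the expected output are illustrative.

```python
from sudachitra.input_string_normalizer import InputStringNormalizer

# Normalizers are appended in the order Strip -> Lowercase -> NFKC.
norm = InputStringNormalizer(do_strip=True, do_lower_case=True, do_nfkc=True)
print(norm.normalizer.normalize_str("  Ｓｕｄａｃｈｉ Tokenizer  "))
# expected along the lines of "sudachi tokenizer": surrounding whitespace stripped,
# characters lowercased, and full-width forms folded by NFKC.
```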
2 changes: 1 addition & 1 deletion sudachitra/pretokenizer/__init__.py
@@ -14,4 +14,4 @@

from .japanese_bert_wordpiece_tokenizer import JapaneseBertWordPieceTokenizer
from .pos_substitution_tokenizer import PartOfSpeechSubstitutionTokenizer
from .sudachipy_pretokenizer import CustomPreTokenizer, SudachipyPreTokenizer
from .sudachipy_pretokenizer import pretokenizer_handler
31 changes: 14 additions & 17 deletions sudachitra/pretokenizer/japanese_bert_wordpiece_tokenizer.py
@@ -23,13 +23,9 @@
from tokenizers.implementations import BertWordPieceTokenizer
from tokenizers.implementations.base_tokenizer import BaseTokenizer

from .sudachipy_pretokenizer import CustomPreTokenizer
from ..input_string_normalizer import InputStringNormalizer


CPT = TypeVar('CPT', bound=CustomPreTokenizer)


class JapaneseBertWordPieceTokenizer(BaseTokenizer):
def __init__(
self,
@@ -41,6 +37,8 @@ def __init__(
mask_token: Union[str, AddedToken] = "[MASK]",
do_lower_case: bool = False,
do_nfkc: bool = False,
do_strip: bool = False,
disable_parallel: bool = False,
wordpieces_prefix: str = "##",
):
if vocab is not None:
@@ -60,7 +58,7 @@ def __init__(
if tokenizer.token_to_id(str(mask_token)) is not None:
tokenizer.add_special_tokens([str(mask_token)])

_normalizer = InputStringNormalizer(do_lower_case=do_lower_case, do_nfkc=do_nfkc)
_normalizer = InputStringNormalizer(do_strip=do_strip, do_lower_case=do_lower_case, do_nfkc=do_nfkc)
tokenizer.normalizer = _normalizer.normalizer
tokenizer.pre_tokenizer = BertPreTokenizer()

@@ -84,9 +82,11 @@ def __init__(
"cls_token": cls_token,
"pad_token": pad_token,
"mask_token": mask_token,
"do_strip": do_strip,
"do_lower_case": do_lower_case,
"do_nfkc": do_nfkc,
"wordpieces_prefix": wordpieces_prefix,
"disable_parallel": disable_parallel,
}

super().__init__(tokenizer, parameters)
@@ -114,17 +114,20 @@ def train(
wordpieces_prefix: str = "##",
):
""" Train the model using the given files """
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if self._parameters["disable_parallel"]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger.info("Parameters for training")
logger.info("\tvocab_size: {}".format(vocab_size))
logger.info("\tdo_strip: {}".format(self._parameters["do_strip"]))
logger.info("\tdo_lower_case: {}".format(self._parameters["do_lower_case"]))
logger.info("\tdo_nfkc: {}".format(self._parameters["do_nfkc"]))
logger.info("\tvocab_size: {}".format(vocab_size))
logger.info("\tmin_frequency: {}".format(min_frequency))
logger.info("\tlimit_alphabet: {}".format(limit_alphabet))
logger.info("\tinitial_alphabet: {}".format(",".join(initial_alphabet)))
logger.info("\tspecial_tokens: {}".format(",".join(special_tokens)))
logger.info("\twordpieces_prefix: {}".format(wordpieces_prefix))
logger.info("\tdisable_parallel: {}".format(self._parameters["disable_parallel"]))

trainer = trainers.WordPieceTrainer(
vocab_size=vocab_size,
@@ -162,17 +165,20 @@ def train_from_iterator(
wordpieces_prefix: str = "##",
):
""" Train the model using the given iterator """
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if self._parameters["disable_parallel"]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger.info("Parameters for training")
logger.info("\tvocab_size: {}".format(vocab_size))
logger.info("\tdo_strip: {}".format(self._parameters["do_strip"]))
logger.info("\tdo_lower_case: {}".format(self._parameters["do_lower_case"]))
logger.info("\tdo_nfkc: {}".format(self._parameters["do_nfkc"]))
logger.info("\tmin_frequency: {}".format(min_frequency))
logger.info("\tlimit_alphabet: {}".format(limit_alphabet))
logger.info("\tinitial_alphabet: {}".format(",".join(initial_alphabet)))
logger.info("\tspecial_tokens: {}".format(",".join(special_tokens)))
logger.info("\twordpieces_prefix: {}".format(wordpieces_prefix))
logger.info("\tdisable_parallel: {}".format(self._parameters["disable_parallel"]))

trainer = trainers.WordPieceTrainer(
vocab_size=vocab_size,
@@ -187,15 +193,6 @@ def train_from_iterator(

logger.info("#Vocab: {}".format(self.get_vocab_size()))

def set_pre_tokenizer(self, custom_pre_tokenizer: CPT):
"""
Sets the custom tokenizer as pre-tokenizer.

Args:
custom_pre_tokenizer (CPT): Custom tokenizer that implements `custom_split`.
"""
self.pre_tokenizer = PreTokenizer.custom(custom_pre_tokenizer)

def save(self, output_tokenizer_path: str, pretty: bool = False):
"""
Saves a config file of the tokenizer.
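Parallelism handling also changes: TOKENIZERS_PARALLELISM is now only forced off when the new disable_parallel flag is set. A hedged sketch of opting back into the old behaviour; the corpus path and vocab size are placeholders.

```python
from sudachitra.pretokenizer import JapaneseBertWordPieceTokenizer

# Opt back into the previous behaviour of switching tokenizer parallelism off.
wp_tokenizer = JapaneseBertWordPieceTokenizer(disable_parallel=True)

# Placeholder corpus; train_from_iterator accepts an iterator of text lines.
lines = (line.rstrip("\n") for line in open("corpus/train.txt", encoding="utf-8"))
wp_tokenizer.train_from_iterator(lines, vocab_size=32000)

wp_tokenizer.save("wordpiece.json", pretty=True)
```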
117 changes: 39 additions & 78 deletions sudachitra/pretokenizer/sudachipy_pretokenizer.py
@@ -13,96 +13,57 @@
# limitations under the License.

import textspan
from tokenizers import NormalizedString, PreTokenizedString
from typing import List, Optional
from sudachipy import Dictionary, MorphemeList
from tokenizers import NormalizedString
from typing import Callable, List, Optional

from .. import SudachipyWordTokenizer
from ..word_formatter import word_formatter, WordFormTypes


class CustomPreTokenizer:
def custom_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
"""
Tokenizes the input string and returns list of tokens.
def split_normalized_string(normalized_string: NormalizedString, tokens: List[str]) -> List[NormalizedString]:
"""
Splits normalized_string by tokens.

Please override this function with your custom tokenizer.
Example. https://github.com/huggingface/tokenizers/blob/b24a2fc/bindings/python/examples/custom_components.py
Args:
normalized_string (NormalizedString): Input string.
tokens (List[str]): List of surface words in a sentence.

Args:
i (int): Index.
normalized_string (NormalizedString): Input string.
Returns:
List[NormalizedString]: List of normalized_strings.
"""
token_spans = textspan.get_original_spans(tokens, str(normalized_string))
return [normalized_string[start:end] for token_span in token_spans for start, end in token_span]

Returns:
List[NormalizedString]: List of normalized_strings.
"""
raise NotImplementedError()

def pre_tokenize(self, pretok: PreTokenizedString):
pretok.split(self.custom_split)
def pretokenizer_handler(sudachi_dict: Dictionary, word_form_type: Optional[str] = 'surface')\
-> Callable[[int, NormalizedString, MorphemeList], List[NormalizedString]]:
"""
A handler for Dictionary.pre_tokenizer that transforms a MorphemeList into a list of tokens.

@staticmethod
def split_normalized_string(normalized_string: NormalizedString, tokens: List[str]) -> List[NormalizedString]:
"""
Splits normalized_string by tokens.
Returns a handler to convert a morpheme to the specified word form.

Args:
normalized_string (NormalizedString): Input string.
tokens (List[str]): List of surface words in a sentence.
Args:
sudachi_dict (Dictionary):
Sudachi dictionary.
word_form_type (:obj:`str`, `optional`, defaults to :obj:`"surface"`):
Word form type for each morpheme.
The values defined in WordFormTypes can be specified.

Returns:
List[NormalizedString]: List of normalized_strings.
"""
token_spans = textspan.get_original_spans(tokens, str(normalized_string).strip())
return [normalized_string[start:end] for token_span in token_spans for start, end in token_span]
Returns:
Callable[[int, NormalizedString, MorphemeList], List[NormalizedString]]:
A handler for Dictionary.pre_tokenizer that transforms a MorphemeList into a list of tokens.
https://worksapplications.github.io/sudachi.rs/python/api/sudachipy.html#sudachipy.Dictionary.pre_tokenizer
"""
_word_formatter = word_formatter(word_form_type, sudachi_dict) if word_form_type != WordFormTypes.SURFACE else None


class SudachipyPreTokenizer(SudachipyWordTokenizer, CustomPreTokenizer):
def __init__(
self,
split_mode: Optional[str] = "C",
dict_type: Optional[str] = "core",
word_form_type: Optional[str] = "surface",
**kwargs
):
"""
Constructs a SudachipyPreTokenizer.

Args:
split_mode (:obj:`str`, `optional`, defaults to :obj:`"C"`):
The mode of splitting.
"A", "B", or "C" can be specified.
dict_type (:obj:`str`, `optional`, defaults to :obj:`"core"`):
Sudachi dictionary type to be used for tokenization.
"small", "core", or "full" can be specified.
word_form_type (:obj:`str`, `optional`, defaults to :obj:`"surface"`):
Word form type for each morpheme.
The values defined in WordFormTypes can be specified.
**kwargs:
Sudachi dictionary parameters.
"""
SudachipyWordTokenizer.__init__(self, split_mode=split_mode, dict_type=dict_type, **kwargs)
self.word_form_type = word_form_type
self.word_formatter = (word_formatter(self.word_form_type, self.sudachi_dict)
if self.word_form_type != WordFormTypes.SURFACE else None)

def custom_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
"""
Tokenizes with SudachiPy and returns its tokens.

Args:
i (int): Index.
normalized_string (NormalizedString): Input string.

Returns:
List[NormalizedString]: List of normalized_strings.
"""
morphs = super().tokenize(str(normalized_string).strip())
tokens = [m.surface() for m in morphs if m.surface() != ""]
normalized_strings = self.split_normalized_string(normalized_string, tokens)
def _handler(index: int, original: NormalizedString, morphemes: MorphemeList) -> List[NormalizedString]:
tokens = [m.surface() for m in morphemes if m.surface() != '']
normalized_strings = split_normalized_string(original, tokens)
if len(tokens) != len(normalized_strings):
raise ValueError(len(morphs), len(tokens), len(normalized_strings), tokens, normalized_strings)

if self.word_form_type != WordFormTypes.SURFACE:
_ = [ns.replace(ns.normalized, self.word_formatter(m)) for ns, m in zip(normalized_strings, morphs)]
raise ValueError(len(morphemes), len(tokens), len(normalized_strings), tokens, morphemes, normalized_strings)
if word_form_type != WordFormTypes.SURFACE:
_ = [ns.replace(ns.normalized, _word_formatter(m)) for ns, m in zip(normalized_strings, morphemes)]

return normalized_strings

return _handler
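For the handler in isolation, a usage sketch built on the pre_tokenize_str method of the returned pre-tokenizer; the sentence and the expected split are illustrative and depend on the installed dictionary version.

```python
from sudachipy import Dictionary

from sudachitra import get_split_mode
from sudachitra.pretokenizer import pretokenizer_handler

sudachi_dict = Dictionary(dict="core")  # assumes the core dictionary is installed
pre_tokenizer = sudachi_dict.pre_tokenizer(
    mode=get_split_mode("A"),
    handler=pretokenizer_handler(sudachi_dict, word_form_type="normalized"),
)

# pre_tokenize_str returns (token, (start, end)) pairs over the original text.
print(pre_tokenizer.pre_tokenize_str("外国人参政権"))
# Split mode A should yield short units such as 外国 / 人 / 参政 / 権, each paired with
# its character offsets; with word_form_type="normalized" the token text carries the
# normalized writing wherever it differs from the surface form.
```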
35 changes: 26 additions & 9 deletions sudachitra/sudachipy_word_tokenizer.py
@@ -19,6 +19,31 @@
from sudachipy import SplitMode


def get_split_mode(split_mode: str) -> SplitMode:
"""
Returns the specified SplitMode.
"A", "B", or "C" can be specified.

Args:
split_mode (str): The mode of splitting.

Returns:
SplitMode: Unit to split text.

Raises:
ValueError: If `split_mode` is not defined in SudachiPy.
"""
split_mode = split_mode.upper()
if split_mode == "C":
return SplitMode.C
elif split_mode == "B":
return SplitMode.B
elif split_mode == "A":
return SplitMode.A
else:
raise ValueError("Invalid `split_mode`: " + split_mode)


class SudachipyWordTokenizer:
"""Runs tokenization with SudachiPy."""

@@ -44,15 +69,7 @@ def __init__(
Sudachi dictionary type to be used for tokenization.
"small", "core", or "full" can be specified.
"""
split_mode = split_mode.upper()
if split_mode == "C":
self.split_mode = SplitMode.C
elif split_mode == "B":
self.split_mode = SplitMode.B
elif split_mode == "A":
self.split_mode = SplitMode.A
else:
raise ValueError("Invalid `split_mode`: " + split_mode)
self.split_mode = get_split_mode(split_mode)

self.sudachi_dict = Dictionary(config_path=config_path, resource_dir=resource_dir, dict=dict_type)
self.sudachi = self.sudachi_dict.create(self.split_mode)
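get_split_mode is now a module-level helper shared by SudachipyWordTokenizer and the training script; a quick sketch of its contract, with the printed values only as expectations:

```python
from sudachitra import get_split_mode

print(get_split_mode("a"))  # SplitMode.A -- the lookup is case-insensitive
print(get_split_mode("C"))  # SplitMode.C

try:
    get_split_mode("D")
except ValueError as error:
    print(error)  # anything other than A, B, or C raises ValueError
```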