Fix SPM-converted FastTokenizer issue on non-ASCII chars (#778)
* Fix SPM-converted tokenizer issue on non-ASCII chars

* Remove pkg_resources usage from the Python tests
wenbingl authored Jul 31, 2024
1 parent e113ed3 commit b4ebfc9
Showing 4 changed files with 72 additions and 52 deletions.
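
A fast tokenizer converted from a SentencePiece model marks word boundaries with the metaspace character "▁" (U+2581) and, with byte fallback enabled, encodes characters missing from the vocabulary as byte pieces of the form <0xNN>. The decoder previously ran every vocabulary piece through the GPT-2-style byte_decoder_ table, which can garble non-ASCII output for such models; the change below teaches KernelBpeDecoder::Compute to detect the SentencePiece case and decode both kinds of piece directly.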
87 changes: 60 additions & 27 deletions operators/tokenizer/bpe_decoder.hpp
@@ -42,10 +42,10 @@ struct KernelBpeDecoder {
return status;
} else {
auto um = ParseId2String(byte_decoder);
std::transform(um.begin(), um.end(),
std::inserter(byte_decoder_, byte_decoder_.end()),
[](const auto& p) { return std::make_pair(static_cast<char32_t>(p.first),
ort_extensions::narrow<unsigned char>(std::stoul(p.second))); });
std::transform(um.begin(), um.end(), std::inserter(byte_decoder_, byte_decoder_.end()), [](const auto& p) {
return std::make_pair(static_cast<char32_t>(p.first),
ort_extensions::narrow<unsigned char>(std::stoul(p.second)));
});
}

std::string added_tokens;
@@ -59,8 +59,7 @@ struct KernelBpeDecoder {
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "all_special_ids", all_special_ids));
if (!all_special_ids.empty()) {
auto um = ParseId2String(all_special_ids);
std::transform(um.begin(), um.end(),
std::inserter(all_special_ids_, all_special_ids_.end()),
std::transform(um.begin(), um.end(), std::inserter(all_special_ids_, all_special_ids_.end()),
[](const auto& p) { return p.first; });
}

@@ -116,8 +115,29 @@ struct KernelBpeDecoder {
arr_vocab_.shrink_to_fit();
}

OrtxStatus Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
const std::string spm_underscore{"\xe2\x96\x81"};

static bool IsSpmByteWord(std::string_view word) {
return word.size() == 6 && word[0] == '<' && word[1] == '0' && word[2] == 'x' && word[5] == '>';
}

static std::string ReplaceAll(std::string_view s, const std::string& search, const std::string& replace) {
std::string result;
for (size_t pos = 0;; pos += search.length()) {
auto new_pos = s.find(search, pos);
if (new_pos == std::string::npos) {
result += s.substr(pos, s.size() - pos);
break;
}
result += s.substr(pos, new_pos - pos);
result += replace;
pos = new_pos;
}

return result;
}

OrtxStatus Compute(const ortc::Tensor<int64_t>& ids, ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};
@@ -126,6 +146,8 @@ struct KernelBpeDecoder {
std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
}

bool spm_mode = byte_decoder_.count(ustring(spm_underscore)[0]) > 0;

size_t seq_len = ids_dim.back();
size_t string_batch = ids.NumberOfElement() / seq_len;
std::vector<std::string> decoded_strings;
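
The hunks above add the SentencePiece plumbing to the shared decoder: spm_underscore holds the UTF-8 bytes of the metaspace marker "▁" (U+2581), IsSpmByteWord recognizes byte-fallback pieces such as <0x20>, and ReplaceAll performs a plain non-overlapping substring replacement used to turn "▁" back into a space. The new spm_mode flag is computed once per Compute call: if the byte-decoder table has an entry for the U+2581 code point, the vocabulary is treated as a SentencePiece conversion rather than GPT-2-style byte-level BPE. A rough usage sketch of the helpers, assuming the two functions above are in scope:

  // ReplaceAll swaps every occurrence of the UTF-8 sequence E2 96 81 for a space.
  std::string piece = "\xe2\x96\x81Hello\xe2\x96\x81world";
  std::string text = ReplaceAll(piece, "\xe2\x96\x81", " ");  // " Hello world"

  // Byte-fallback pieces are exactly six chars: '<', '0', 'x', two hex digits, '>'.
  bool is_byte = IsSpmByteWord("<0x20>");  // true; "<unk>" or "▁the" would be false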
@@ -148,24 +170,37 @@ struct KernelBpeDecoder {

if (added_tokens_.count(token)) {
const std::string& ws = added_tokens_.at(token);
decoded_token = (std::string)ws;
decoded_token.assign(ws);
} else if (static_cast<size_t>(token) < arr_vocab_.size()) {
const auto str = ustring(arr_vocab_[token]);
for (auto wchr : str) {
if (byte_decoder_.count(wchr) == 0) {
if (wchr <= char32_t(0xFF)) {
decoded_token.push_back(static_cast<char>(wchr));
continue;
}
if (skip_special_tokens_) {
continue;
} else {
decoded_token = unk_token_;
break;
const auto piece = arr_vocab_[token];
if (spm_mode) {
// sentencepiece case, which doesn't really have a byte decoder
if ((IsSpmByteWord(piece))) {
char buf[3] = {piece[3], piece[4], 0}; // something like <0x20>
char token = {static_cast<char>(strtol(buf, NULL, 16))};
decoded_token.push_back(token);
} else {
decoded_token.append(ReplaceAll(piece, spm_underscore, " "));
}
} else {
// the common bpe case
const auto str = ustring(piece);
for (auto wchr : str) {
if (byte_decoder_.count(wchr) == 0) {
if (wchr <= char32_t(0xFF)) {
decoded_token.push_back(static_cast<char>(wchr));
continue;
}
if (skip_special_tokens_) {
continue;
} else {
decoded_token = unk_token_;
break;
}
}
char uchr = byte_decoder_.at(wchr);
decoded_token.push_back(uchr);
}
char uchr = byte_decoder_.at(wchr);
decoded_token.push_back(uchr);
}
} else {
if (skip_special_tokens_) {
@@ -183,15 +218,13 @@
}
}

if (whitespace_token_ &&
f_special && (tok_idx > 0 && !f_special_last)) {
if (whitespace_token_ && f_special && (tok_idx > 0 && !f_special_last)) {
text.push_back(' ');
}

text.append(decoded_token);

if (whitespace_token_ &&
f_special && tok_idx != count - 1) {
if (whitespace_token_ && f_special && tok_idx != count - 1) {
text.push_back(' ');
}

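In SentencePiece mode the decode loop handles each piece in one of two ways: a byte-fallback piece contributes one raw byte, parsed from its two hex digits with strtol, while any other piece is appended after replacing "▁" with a space. A character missing from the vocabulary therefore arrives as a run of byte pieces; for example, if "给" (U+7ED9, used in the new test below) is not in the vocabulary, it shows up as <0xE7> <0xBB> <0x99>, and concatenating the decoded bytes restores its UTF-8 encoding. A minimal standalone sketch of that path (the helper name and harness are illustrative, not part of the commit):

#include <cstdlib>
#include <iostream>
#include <string>
#include <string_view>

// Same shape check as the IsSpmByteWord helper added above.
static bool IsByteFallback(std::string_view p) {
  return p.size() == 6 && p[0] == '<' && p[1] == '0' && p[2] == 'x' && p[5] == '>';
}

int main() {
  // UTF-8 bytes of U+7ED9 ("给"), emitted by the tokenizer as byte-fallback pieces.
  const char* pieces[] = {"<0xE7>", "<0xBB>", "<0x99>"};
  std::string text;
  for (std::string_view p : pieces) {
    if (IsByteFallback(p)) {
      char hex[3] = {p[3], p[4], 0};  // the two hex digits, e.g. "E7"
      text.push_back(static_cast<char>(std::strtol(hex, nullptr, 16)));
    }
  }
  std::cout << text << "\n";  // prints 给 on a UTF-8 terminal
  return 0;
}
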
19 changes: 0 additions & 19 deletions operators/tokenizer/bpe_streaming.hpp
@@ -53,25 +53,7 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
return {};
}

static std::string ReplaceAll(std::string_view s, const std::string& search, const std::string& replace) {
std::string result;
for (size_t pos = 0;; pos += search.length()) {
auto new_pos = s.find(search, pos);
if (new_pos == std::string::npos) {
result += s.substr(pos, s.size() - pos);
break;
}
result += s.substr(pos, new_pos - pos);
result += replace;
pos = new_pos;
}

return result;
}

static bool IsSpmByteWord(std::string_view word) {
return word.size() == 6 && word[0] == '<' && word[1] == '0' && word[2] == 'x' && word[5] == '>';
}

OrtxStatus Id2Token(extTokenId_t id,
std::string& token,
@@ -119,7 +101,6 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
}

OrtxStatus SpmId2Token(extTokenId_t id, std::string& token, bool& f_special_last) const {
const char spm_underscore[] = "\xe2\x96\x81";

std::string piece = id < arr_vocab_.size() ? arr_vocab_[id] : "";
bool f_special = false;
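With the SentencePiece helpers promoted into KernelBpeDecoder, the streaming decoder drops its private copies: BpeStreamingDecoder inherits ReplaceAll, IsSpmByteWord, and the spm_underscore constant from the base class, so SpmId2Token and the batch Compute path now share a single definition.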
5 changes: 0 additions & 5 deletions test/test_cliptok.py
@@ -1,7 +1,6 @@
import unittest
import numpy as np
import onnxruntime as _ort
import pkg_resources

from pathlib import Path
from onnx import helper, onnx_pb as onnx_proto
@@ -150,8 +149,4 @@ def test_optional_outputs(self):


if __name__ == "__main__":
try:
dist = pkg_resources.get_distribution('ftfy')
except pkg_resources.DistributionNotFound:
raise Exception("WARNING: ftfy is not installed - it is required for parity between CLIPTokenizer and CLIPTokenizerFast.")
unittest.main()
13 changes: 12 additions & 1 deletion test/test_fast_tokenizer.py
@@ -21,14 +21,25 @@ def test_llama_tokenizer(self):
np.testing.assert_array_equal(ids[0], actual_ids[0])

def test_mistral(self):
tokenizer = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-v0.2", use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(
"mistral-community/Mistral-7B-v0.2", use_fast=True)
text = "\nOnce upon a time, I was really into monochromatic makeup looks. I have a lot of coppery and bronze "
ids = tokenizer.encode(text, return_tensors="np")

ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
actual_ids, *_ = ort_inference(ort_tok, [text])
np.testing.assert_array_equal(ids[0], actual_ids[0])

def test_phi_3_mini(self):
tokenizer = AutoTokenizer.from_pretrained(
"microsoft/Phi-3-mini-128k-instruct", use_fast=True)
text = "what are you? \n 给 weiss ich, über was los ist \n"
ids = tokenizer.encode(text, return_tensors="np")

ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
actual_ids, *_ = ort_inference(ort_tok, [text])
np.testing.assert_array_equal(ids[0], actual_ids[0][1:])
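
The new case exercises a mixed ASCII/Chinese/German string. Note that the HuggingFace ids are compared against actual_ids[0][1:], i.e. with the first token of the exported tokenizer's output dropped; presumably that leading token is a special token (such as BOS) emitted by only one of the two pipelines.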


if __name__ == '__main__':
unittest.main()
