Fix SPM-converted FastTokenizer issue on non-ASCII chars #778

Merged · 3 commits · Jul 31, 2024
87 changes: 60 additions & 27 deletions operators/tokenizer/bpe_decoder.hpp
@@ -42,10 +42,10 @@ struct KernelBpeDecoder {
return status;
} else {
auto um = ParseId2String(byte_decoder);
std::transform(um.begin(), um.end(),
std::inserter(byte_decoder_, byte_decoder_.end()),
[](const auto& p) { return std::make_pair(static_cast<char32_t>(p.first),
ort_extensions::narrow<unsigned char>(std::stoul(p.second))); });
std::transform(um.begin(), um.end(), std::inserter(byte_decoder_, byte_decoder_.end()), [](const auto& p) {
return std::make_pair(static_cast<char32_t>(p.first),
ort_extensions::narrow<unsigned char>(std::stoul(p.second)));
});
}

std::string added_tokens;
@@ -59,8 +59,7 @@ struct KernelBpeDecoder {
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "all_special_ids", all_special_ids));
if (!all_special_ids.empty()) {
auto um = ParseId2String(all_special_ids);
std::transform(um.begin(), um.end(),
std::inserter(all_special_ids_, all_special_ids_.end()),
std::transform(um.begin(), um.end(), std::inserter(all_special_ids_, all_special_ids_.end()),
[](const auto& p) { return p.first; });
}

@@ -116,8 +115,29 @@ struct KernelBpeDecoder {
arr_vocab_.shrink_to_fit();
}

OrtxStatus Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
const std::string spm_underscore{"\xe2\x96\x81"};

static bool IsSpmByteWord(std::string_view word) {
return word.size() == 6 && word[0] == '<' && word[1] == '0' && word[2] == 'x' && word[5] == '>';
}

static std::string ReplaceAll(std::string_view s, const std::string& search, const std::string& replace) {
std::string result;
for (size_t pos = 0;; pos += search.length()) {
auto new_pos = s.find(search, pos);
if (new_pos == std::string::npos) {
result += s.substr(pos, s.size() - pos);
break;
}
result += s.substr(pos, new_pos - pos);
result += replace;
pos = new_pos;
}

return result;
}

OrtxStatus Compute(const ortc::Tensor<int64_t>& ids, ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};
@@ -126,6 +146,8 @@ struct KernelBpeDecoder {
std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
}

bool spm_mode = byte_decoder_.count(ustring(spm_underscore)[0]) > 0;  // heuristic: an SPM-converted vocab maps "▁" (U+2581) in its byte decoder

size_t seq_len = ids_dim.back();
size_t string_batch = ids.NumberOfElement() / seq_len;
std::vector<std::string> decoded_strings;
@@ -148,24 +170,37 @@ struct KernelBpeDecoder {

if (added_tokens_.count(token)) {
const std::string& ws = added_tokens_.at(token);
decoded_token = (std::string)ws;
decoded_token.assign(ws);
} else if (static_cast<size_t>(token) < arr_vocab_.size()) {
const auto str = ustring(arr_vocab_[token]);
for (auto wchr : str) {
if (byte_decoder_.count(wchr) == 0) {
if (wchr <= char32_t(0xFF)) {
decoded_token.push_back(static_cast<char>(wchr));
continue;
}
if (skip_special_tokens_) {
continue;
} else {
decoded_token = unk_token_;
break;
const auto piece = arr_vocab_[token];
if (spm_mode) {
// sentencepiece case, which doesn't really have a byte decoder
if ((IsSpmByteWord(piece))) {
char buf[3] = {piece[3], piece[4], 0}; // something like <0x20>
char token = {static_cast<char>(strtol(buf, NULL, 16))};
decoded_token.push_back(token);
} else {
decoded_token.append(ReplaceAll(piece, spm_underscore, " "));
}
} else {
// the common bpe case
const auto str = ustring(piece);
for (auto wchr : str) {
if (byte_decoder_.count(wchr) == 0) {
if (wchr <= char32_t(0xFF)) {
decoded_token.push_back(static_cast<char>(wchr));
continue;
}
if (skip_special_tokens_) {
continue;
} else {
decoded_token = unk_token_;
break;
}
}
char uchr = byte_decoder_.at(wchr);
decoded_token.push_back(uchr);
}
char uchr = byte_decoder_.at(wchr);
decoded_token.push_back(uchr);
}
} else {
if (skip_special_tokens_) {
@@ -183,15 +218,13 @@
}
}

if (whitespace_token_ &&
f_special && (tok_idx > 0 && !f_special_last)) {
if (whitespace_token_ && f_special && (tok_idx > 0 && !f_special_last)) {
text.push_back(' ');
}

text.append(decoded_token);

if (whitespace_token_ &&
f_special && tok_idx != count - 1) {
if (whitespace_token_ && f_special && tok_idx != count - 1) {
text.push_back(' ');
}

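To make the new SPM decode path easier to follow outside the diff, here is a minimal, self-contained C++17 sketch. It is an illustration only, not code from the PR: the two helpers are copied from bpe_decoder.hpp above, while DecodePiece is a hypothetical wrapper standing in for the new spm_mode branch of Compute. A byte-fallback piece such as <0xE4> is mapped back to its raw byte, and every other piece has the SentencePiece meta-symbol "▁" (U+2581, UTF-8 "\xe2\x96\x81") replaced with a space, which is why non-ASCII text, whose UTF-8 bytes surface as <0xNN> pieces, now round-trips correctly.

// sketch_spm_decode.cpp -- illustrative only, not part of the PR
#include <cstdlib>
#include <iostream>
#include <string>
#include <string_view>

static const std::string spm_underscore{"\xe2\x96\x81"};  // UTF-8 for U+2581 "▁"

// byte-fallback pieces look exactly like "<0xNN>"
static bool IsSpmByteWord(std::string_view word) {
  return word.size() == 6 && word[0] == '<' && word[1] == '0' && word[2] == 'x' && word[5] == '>';
}

// same replace-all helper the PR hoists into KernelBpeDecoder
static std::string ReplaceAll(std::string_view s, const std::string& search, const std::string& replace) {
  std::string result;
  for (size_t pos = 0;; pos += search.length()) {
    auto new_pos = s.find(search, pos);
    if (new_pos == std::string_view::npos) {
      result += s.substr(pos, s.size() - pos);
      break;
    }
    result += s.substr(pos, new_pos - pos);
    result += replace;
    pos = new_pos;
  }
  return result;
}

// hypothetical wrapper: decodes one vocab piece the way the new spm_mode branch does
static std::string DecodePiece(std::string_view piece) {
  if (IsSpmByteWord(piece)) {
    char buf[3] = {piece[3], piece[4], 0};  // e.g. "E4" out of "<0xE4>"
    return std::string(1, static_cast<char>(std::strtol(buf, nullptr, 16)));
  }
  return ReplaceAll(piece, spm_underscore, " ");  // "▁" marks a word boundary
}

int main() {
  std::cout << DecodePiece("\xe2\x96\x81Hello") << '\n';                    // " Hello"
  std::cout << (int)(unsigned char)DecodePiece("<0xE4>")[0] << '\n';        // 228
  return 0;
}

Compiled as-is, the sketch prints " Hello" followed by 228 (0xE4), one lead byte of a multi-byte UTF-8 sequence.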
19 changes: 0 additions & 19 deletions operators/tokenizer/bpe_streaming.hpp
@@ -53,25 +53,7 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
return {};
}

static std::string ReplaceAll(std::string_view s, const std::string& search, const std::string& replace) {
std::string result;
for (size_t pos = 0;; pos += search.length()) {
auto new_pos = s.find(search, pos);
if (new_pos == std::string::npos) {
result += s.substr(pos, s.size() - pos);
break;
}
result += s.substr(pos, new_pos - pos);
result += replace;
pos = new_pos;
}

return result;
}

static bool IsSpmByteWord(std::string_view word) {
return word.size() == 6 && word[0] == '<' && word[1] == '0' && word[2] == 'x' && word[5] == '>';
}

OrtxStatus Id2Token(extTokenId_t id,
std::string& token,
@@ -119,7 +101,6 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
}

OrtxStatus SpmId2Token(extTokenId_t id, std::string& token, bool& f_special_last) const {
const char spm_underscore[] = "\xe2\x96\x81";

std::string piece = id < arr_vocab_.size() ? arr_vocab_[id] : "";
bool f_special = false;
5 changes: 0 additions & 5 deletions test/test_cliptok.py
@@ -1,7 +1,6 @@
import unittest
import numpy as np
import onnxruntime as _ort
import pkg_resources

from pathlib import Path
from onnx import helper, onnx_pb as onnx_proto
@@ -150,8 +149,4 @@ def test_optional_outputs(self):


if __name__ == "__main__":
try:
dist = pkg_resources.get_distribution('ftfy')
except pkg_resources.DistributionNotFound:
raise Exception("WARNING: ftfy is not installed - it is required for parity between CLIPTokenizer and CLIPTokenizerFast.")
unittest.main()
13 changes: 12 additions & 1 deletion test/test_fast_tokenizer.py
@@ -21,14 +21,25 @@ def test_llama_tokenizer(self):
np.testing.assert_array_equal(ids[0], actual_ids[0])

def test_mistral(self):
tokenizer = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-v0.2", use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(
"mistral-community/Mistral-7B-v0.2", use_fast=True)
text = "\nOnce upon a time, I was really into monochromatic makeup looks. I have a lot of coppery and bronze "
ids = tokenizer.encode(text, return_tensors="np")

ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
actual_ids, *_ = ort_inference(ort_tok, [text])
np.testing.assert_array_equal(ids[0], actual_ids[0])

def test_phi_3_mini(self):
tokenizer = AutoTokenizer.from_pretrained(
"microsoft/Phi-3-mini-128k-instruct", use_fast=True)
text = "what are you? \n 给 weiss ich, über was los ist \n"
ids = tokenizer.encode(text, return_tensors="np")

ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
actual_ids, *_ = ort_inference(ort_tok, [text])
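# the ORT output carries one extra leading id (presumably the BOS token), hence the [1:] below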
np.testing.assert_array_equal(ids[0], actual_ids[0][1:])


if __name__ == '__main__':
unittest.main()