diff --git a/llms/export/mlxlm.cpp b/llms/export/mlxlm.cpp
index 6b837e3ec..b5aa9a246 100644
--- a/llms/export/mlxlm.cpp
+++ b/llms/export/mlxlm.cpp
@@ -13,20 +13,18 @@ namespace mx = mlx::core;
 #define time_now() std::chrono::high_resolution_clock::now()
 
 // Maybe compile
-std::function load_model(const std::string& path) {
+std::function load_model(const std::string &path) {
   return mx::compile(mx::import_function(path), /* shapeless = */ true);
 }
 
 // Maybe make tokenizer virtual
-BPETokenizer load_tokenizer(const std::string& path) {
+BPETokenizer load_tokenizer(const std::string &path) {
   return BPETokenizer(path);
 }
 
-void generate(
-    const std::function& model,
-    const BPETokenizer& tokenizer,
-    const std::string& prompt,
-    int max_tokens /* = 256 */) {
+void generate(const std::function &model,
+              const BPETokenizer &tokenizer, const std::string &prompt,
+              int max_tokens /* = 256 */) {
 
   auto prompt_tokens = tokenizer.encode(prompt);
   int prompt_size = prompt_tokens.size();
@@ -38,19 +36,22 @@ void generate(
   };
 
   // Helper to expand the cache and mask
-  auto expand = [](auto& args, auto& mask) {
+  auto expand = [](auto &args, auto &mask) {
     constexpr int cache_step_size = 256;
     int cache_size = args[1].shape(-2);
-    int new_size = cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
+    int new_size =
+        cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
     for (auto it = args.begin() + 1; it != args.end(); ++it) {
-      auto& x = *it;
+      auto &x = *it;
       auto shape = x.shape();
       shape[2] = new_size;
       auto new_x = mx::zeros(shape, x.dtype());
       shape[2] = cache_size;
-      *it = mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape));
+      *it =
+          mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape));
     }
-    mask = mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size});
+    mask =
+        mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size});
   };
 
   auto tic = time_now();
diff --git a/llms/export/mlxlm.h b/llms/export/mlxlm.h
index cf8f72217..0b1da6ec4 100644
--- a/llms/export/mlxlm.h
+++ b/llms/export/mlxlm.h
@@ -6,15 +6,12 @@
 
 namespace mx = mlx::core;
 
-std::function load_model(const std::string& path);
+std::function load_model(const std::string &path);
 
-BPETokenizer load_tokenizer(const std::string& path);
+BPETokenizer load_tokenizer(const std::string &path);
 
-struct GenerationResponse {
-};
+struct GenerationResponse {};
 
-void generate(
-    const std::function& model,
-    const BPETokenizer& tokenizer,
-    const std::string& prompt,
-    int max_tokens = 256);
+void generate(const std::function &model,
+              const BPETokenizer &tokenizer, const std::string &prompt,
+              int max_tokens = 256);
diff --git a/llms/export/tokenizer.cpp b/llms/export/tokenizer.cpp
index 3dac0008a..6ecd8b1ec 100644
--- a/llms/export/tokenizer.cpp
+++ b/llms/export/tokenizer.cpp
@@ -1,21 +1,22 @@
-#include
-#include
-#include
 #include
+#include
+#include
 #include
+#include
 
-#include "tokenizer.h"
 #include "third_party/unicode.h"
+#include "tokenizer.h"
 
 using json = nlohmann::json;
 
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 
-std::pair utf8_to_utf16(const std::string& s) {
+std::pair utf8_to_utf16(const std::string &s) {
   static std::string replace_str = std::string(1, 0xFF);
   static std::wstring replace_wstr = std::wstring(1, 0xFFFD);
-  std::wstring_convert> cvt(replace_str, replace_wstr);
+  std::wstring_convert> cvt(replace_str,
+                            replace_wstr);
   auto out = cvt.from_bytes(s);
   return {out, cvt.converted()};
 }
@@ -23,7 +24,8 @@ std::pair utf8_to_utf16(const std::string& s) {
 
 auto make_byte_decoder() {
   std::unordered_map byte_decoder;
-  std::vector limits = {0, '!', '~' + 1, L'¡', L'¬' + 1, L'®', L'ÿ' + 1};
+  std::vector limits = {0, '!', '~' + 1, L'¡',
+                        L'¬' + 1, L'®', L'ÿ' + 1};
   char n = 0;
   for (int i = 0; i < limits.size() - 1; ++i) {
     auto start = limits[i];
@@ -43,14 +45,14 @@ auto make_byte_decoder() {
 
 auto BPETokenizer::byte_decoder_ = make_byte_decoder();
 
-BPETokenizer::BPETokenizer(const std::string& path_) {
+BPETokenizer::BPETokenizer(const std::string &path_) {
   auto path = std::filesystem::path(path_);
   std::ifstream ifs(path / "tokenizer.json");
   auto tokenizer = json::parse(ifs);
   auto model = tokenizer["model"];
   token_to_id_ = model["vocab"];
   id_to_token_.resize(token_to_id_.size());
-  for (auto& [s, id] : token_to_id_) {
+  for (auto &[s, id] : token_to_id_) {
     if (id >= id_to_token_.size()) {
       id_to_token_.resize(id + 1);
     }
@@ -58,7 +60,7 @@ BPETokenizer::BPETokenizer(const std::string& path_) {
   }
   std::string type = model["type"];
   auto merges = model["merges"];
-  for (auto& s : merges) {
+  for (auto &s : merges) {
     if (s.is_string()) {
       merges_.emplace(s, merges_.size());
     } else {
@@ -69,7 +71,7 @@ BPETokenizer::BPETokenizer(const std::string& path_) {
   }
 
   auto added_tokens = tokenizer["added_tokens"];
-  for (auto& added_token : added_tokens) {
+  for (auto &added_token : added_tokens) {
     int id = added_token["id"];
     if (id >= id_to_token_.size()) {
       id_to_token_.resize(id + 1);
@@ -83,14 +85,17 @@ BPETokenizer::BPETokenizer(const std::string& path_) {
   }
 
   // Currently hardcoded to Llama3 BPE regex
-  pre_tokenizer_regex_ = {"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"};
+  pre_tokenizer_regex_ = {
+      "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}"
+      "\\p{N}]?\\p{L}+|\\p{N}{1,3}| "
+      "?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"};
 }
 
 
 std::vector BPETokenizer::encode(std::string text) const {
   auto segments = unicode_regex_split(text, pre_tokenizer_regex_);
 
-  auto one_step_merge = [this](std::string segment, std::vector& splits) {
+  auto one_step_merge = [this](std::string segment, std::vector &splits) {
     int merge_idx;
     int rank = INT32_MAX;
     for (int i = 0; i < splits.size() - 2; ++i) {
@@ -119,7 +124,8 @@ std::vector BPETokenizer::encode(std::string text) const {
       auto start = splits[i];
       auto mid = splits[i + 1];
       auto end = splits[i + 2];
-      if (segment.substr(start, mid - start) == merge_l && segment.substr(mid, end - mid) == merge_r) {
+      if (segment.substr(start, mid - start) == merge_l &&
+          segment.substr(mid, end - mid) == merge_r) {
         splits.erase(splits.begin() + i + 1);
         i -= 1;
       }
@@ -131,18 +137,19 @@ std::vector BPETokenizer::encode(std::string text) const {
   ids.push_back(bos_id_);
 
   // Initialize merges to integer list
-  auto merge_segment = [&ids, &one_step_merge, this](const std::string& segment) {
-
+  auto merge_segment = [&ids, &one_step_merge,
+                        this](const std::string &segment) {
     std::vector splits;
     for (int i = 0; i < segment.size(); ++i) {
       splits.push_back(i);
-      if (static_cast(segment[i]) > 128) {
+      if (static_cast(segment[i]) >= 128) {
        i++;
       }
     }
     splits.push_back(segment.size());
 
-    while (one_step_merge(segment, splits)) { };
+    while (one_step_merge(segment, splits)) {
+    };
     for (int i = 0; i < splits.size() - 1; ++i) {
       auto start = splits[i];
       auto end = splits[i + 1];
@@ -155,7 +162,7 @@ std::vector BPETokenizer::encode(std::string text) const {
     }
   };
 
-  for (auto& segment : segments) {
+  for (auto &segment : segments) {
     merge_segment(segment);
   }
   return ids;
@@ -171,7 +178,8 @@ std::string BPETokenizer::id_to_bytes(int id) const {
   return token;
 }
 
-std::pair BPETokenizer::try_decode(const std::vector& ids) const {
+std::pair
+BPETokenizer::try_decode(const std::vector &ids) const {
   std::string text;
   for (auto id : ids) {
     text += id_to_bytes(id);
@@ -182,7 +190,7 @@ std::pair BPETokenizer::try_decode(const std::vector& id
 
   return {text, complete};
 }
 
-std::string BPETokenizer::decode(const std::vector& ids) const {
+std::string BPETokenizer::decode(const std::vector &ids) const {
   return try_decode(ids).first;
 }
diff --git a/llms/export/tokenizer.h b/llms/export/tokenizer.h
index 5e1fcebf8..64dd2fe42 100644
--- a/llms/export/tokenizer.h
+++ b/llms/export/tokenizer.h
@@ -8,24 +8,24 @@
 
 /** BPE Tokenizer API */
 class BPETokenizer {
- public:
-  BPETokenizer(const std::string& path);
+public:
+  BPETokenizer(const std::string &path);
 
   /** Encode a string of text to token integer ids. */
   std::vector encode(std::string text) const;
 
   /** Try to decode the vector of ids to text. The text is truncated to
    * include only the fully decodable tokens. */
-  std::string decode(const std::vector& ids) const;
+  std::string decode(const std::vector &ids) const;
 
   /** Try to decode the vector of ids to text. The second return value
    * indicates if the decoding completed. The text is truncated to include
    * only the fully decodable tokens. */
-  std::pair try_decode(const std::vector& ids) const;
+  std::pair try_decode(const std::vector &ids) const;
 
   int eos_token_id() const;
 
- private:
+private:
   std::unordered_map token_to_id_;
   std::vector id_to_token_;
   std::unordered_map merges_;
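
For reference, a minimal sketch of how the exported API touched by this diff could be driven end to end. It is illustrative only: the file paths, the prompt, and the element type of the token-id vector are assumptions (the diff above shows the declarations with their template arguments stripped), not code taken from the repository.

// example.cpp -- illustrative sketch; paths and the std::vector<int> id type
// are assumptions, not verified against the repository.
#include <iostream>
#include <string>
#include <vector>

#include "mlxlm.h"

int main() {
  // Assumed path to a function graph exported from Python (load_model wraps
  // mx::import_function + mx::compile, per mlxlm.cpp above).
  auto model = load_model("model.mlxfn");
  // Assumed directory containing tokenizer.json (see the BPETokenizer ctor).
  auto tokenizer = load_tokenizer("tokenizer_dir");

  // Round-trip a string through the BPE tokenizer API declared in tokenizer.h.
  std::vector<int> ids = tokenizer.encode("hello world");
  std::cout << tokenizer.decode(ids) << std::endl;

  // Run generation with the default token budget from mlxlm.h.
  generate(model, tokenizer, "Write a haiku about autumn.",
           /* max_tokens = */ 256);
  return 0;
}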