diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp index 2a0fe73..fa748f5 100644 --- a/youtokentome/cpp/bpe.cpp +++ b/youtokentome/cpp/bpe.cpp @@ -865,7 +865,7 @@ void rename_tokens(ska::flat_hash_map &char2id, } BPEState learn_bpe_from_string(string &text_utf8, int n_tokens, - const string &output_file, + StreamWriter &output, BpeConfig bpe_config) { vector threads; assert(bpe_config.n_threads >= 1 || bpe_config.n_threads == -1); @@ -1294,8 +1294,8 @@ BPEState learn_bpe_from_string(string &text_utf8, int n_tokens, rename_tokens(char2id, rules, bpe_config.special_tokens, n_tokens); BPEState bpe_state = {char2id, rules, bpe_config.special_tokens}; - bpe_state.dump(output_file); - std::cerr << "model saved to: " << output_file << std::endl; + bpe_state.dump(output); + std::cerr << "model saved to: " << output.name() << std::endl; return bpe_state; } @@ -1450,7 +1450,8 @@ void train_bpe(const string &input_path, const string &model_path, std::cerr << "reading file..." << std::endl; auto data = fast_read_file_utf8(input_path); std::cerr << "learning bpe..." << std::endl; - learn_bpe_from_string(data, vocab_size, model_path, bpe_config); + auto fout = StreamWriter.open(model_path); + learn_bpe_from_string(data, vocab_size, fout, bpe_config); } DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8, diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h index dff9326..bd5ebc6 100644 --- a/youtokentome/cpp/bpe.h +++ b/youtokentome/cpp/bpe.h @@ -32,7 +32,7 @@ class BaseEncoder { explicit BaseEncoder(BPEState bpe_state, int _n_threads); - explicit BaseEncoder(const std::string& model_path, int n_threads); + explicit BaseEncoder(const StreamReader& model_path, int n_threads); void fill_from_state(); diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp index 901e6ef..a662c6c 100644 --- a/youtokentome/cpp/utils.cpp +++ b/youtokentome/cpp/utils.cpp @@ -2,20 +2,112 @@ #include #include #include +#include #include #include + namespace vkcom { using std::string; using std::vector; -void SpecialTokens::dump(std::ofstream &fout) { - fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id - << std::endl; +class FileWriter : public StreamWriter { + public: + FileWriter(const std::string &file_name) { + this->file_name = file_name; + this->fout = std::ofstream(file_name, std::ios::out | std::ios::binary); + if (fout.fail()) { + std::cerr << "Can't open file: " << file_name << std::endl; + assert(false); + } + } + + virtual int write(const char *buffer, int size) override { + return fout.write(buffer, size); + } + + virtual std::string name() const noexcept override { + return file_name; + } + + private: + std::string file_name; + std::ofstream fout; +}; + +class FileReader : public StreamReader { + public: + FileReader(const std::string &file_name) { + this->file_name = file_name; + this->fin = std::ifstream(file_name, std::ios::in | std::ios::binary); + if (fin.fail()) { + std::cerr << "Can't open file: " << file_name << std::endl; + assert(false); + } + } + + virtual int read(const char *buffer, int size) override { + return fin.read(buffer, size); + } + + virtual std::string name() const noexcept override { + return file_name; + } + + private: + std::string file_name; + std::ifstream fin; +}; + +StreamWriter StreamWriter::open(const std::string &file_name) { + return FileWriter(file_name); +} + +StreamReader StreamReader::open(const std::string &file_name) { + return FileReader(file_name); +} + +template::value, int>::type = 0> +T bin_to_int(const char *val) { + uint32_t ret = static_cast(val[0]); + ret |= static_cast(static_cast(val[1])) << 8; + ret |= static_cast(static_cast(val[2])) << 16; + ret |= static_cast(static_cast(val[3])) << 24; + return ret; +} + +template::value, int>::type = 0> +std::unique_ptr int_to_bin(T val) { + auto u32 = static_cast(val); + std::unique_ptr ret(new char[4]); + ret[0] = u32 & 0xFF; + ret[1] = (u32 >> 8) & 0xFF; + ret[2] = (u32 >> 16) & 0xFF; + ret[3] = (u32 >> 24); // no need for & 0xFF + return std::move(ret); +} + +void SpecialTokens::dump(StreamWriter &fout) { + std::unique_ptr unk_id_ptr(int_to_bin(unk_id)), + pad_id_ptr(int_to_bin(pad_id)), + bos_id_ptr(int_to_bin(bos_id)), + eos_id_ptr(int_to_bin(eos_id)); + fout.write(unk_id_ptr.get(), 4); + fout.write(pad_id_ptr.get(), 4); + fout.write(bos_id_ptr.get(), 4); + fout.write(eos_id_ptr.get(), 4); } -void SpecialTokens::load(std::ifstream &fin) { - fin >> unk_id >> pad_id >> bos_id >> eos_id; +void SpecialTokens::load(StreamReader &fin) { + char unk_id_bs[4], pad_id_bs[4], bos_id_bs[4], eos_id_bs[4]; + fin.read(unk_id_bs, 4); + fin.read(pad_id_bs, 4); + fin.read(bos_id_bs, 4); + fin.read(eos_id_bs, 4); + this->unk_id = bin_to_int(unk_id_bs); + this->pad_id = bin_to_int(pad_id_bs); + this->bos_id = bin_to_int(bos_id_bs); + this->eos_id = bin_to_int(eos_id_bs); } uint32_t SpecialTokens::max_id() const { @@ -49,48 +141,69 @@ bool BPE_Rule::operator==(const BPE_Rule &other) const { BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {} -void BPEState::dump(const string &file_name) { - std::ofstream fout(file_name, std::ios::out); - if (fout.fail()) { - std::cerr << "Can't open file: " << file_name << std::endl; - assert(false); +void BPEState::dump(StreamWriter &fout) { + std::unique_ptr char2id_ptr(int_to_bin(char2id.size())), + rules_ptr(int_to_bin(rules.size())); + fout.write(char2id_ptr.get(), 4); + fout.write(rules_ptr.get(), 4); + for (auto &s : char2id) { + std::unique_ptr first_ptr(int_to_bin(s.first)), + second_ptr(int_to_bin(s.second)); + fout.write(first_ptr.get(), 4); + fout.write(second_ptr.get(), 4); } - fout << char2id.size() << " " << rules.size() << std::endl; - for (auto s : char2id) { - fout << s.first << " " << s.second << std::endl; + for (auto &rule : rules) { + std::unique_ptr rule_ptr(int_to_bin(rule.x)); + fout.write(rule_ptr.get(), 4); } - - for (auto rule : rules) { - fout << rule.x << " " << rule.y << " " << rule.z << std::endl; + for (auto &rule : rules) { + std::unique_ptr rule_ptr(int_to_bin(rule.y)); + fout.write(rule_ptr.get(), 4); + } + for (auto &rule : rules) { + std::unique_ptr rule_ptr(int_to_bin(rule.z)); + fout.write(rule_ptr.get(), 4); } special_tokens.dump(fout); - fout.close(); } -void BPEState::load(const string &file_name) { +void BPEState::load(StreamReader &fin) { char2id.clear(); rules.clear(); - std::ifstream fin(file_name, std::ios::in); - if (fin.fail()) { - std::cerr << "Error. Can not open file with model: " << file_name - << std::endl; - exit(EXIT_FAILURE); - } - int n, m; - fin >> n >> m; + char n_bs[4], m_bs[4]; + fin.read(n_bs, 4); + fin.read(m_bs, 4); + auto n = bin_to_int(n_bs); + auto m = bin_to_int(m_bs); for (int i = 0; i < n; i++) { - uint32_t inner_id; - uint32_t utf32_id; - fin >> inner_id >> utf32_id; + char inner_id_bs[4], utf32_id_bs[4]; + fin.read(inner_id_bs, 4); + fin.read(utf32_id_bs, 4); + auto inner_id = bin_to_int(inner_id_bs); + auto utf32_id = bin_to_int(utf32_id_bs); char2id[inner_id] = utf32_id; } + std::vector> rules_xyz(m); + for (int j = 0; j < 3; j++) { + for (int i = 0; i < m; i++) { + char val[4]; + fin.read(val, 4); + uint32_t *element; + switch (j) { + case 0: + element = &std::get<0>(rules_xyz[i]); + case 1: + element = &std::get<1>(rules_xyz[i]); + case 2: + element = &std::get<2>(rules_xyz[i]); + } + *element = bin_to_int(val); + } + } for (int i = 0; i < m; i++) { - uint32_t x, y, z; - fin >> x >> y >> z; - rules.emplace_back(x, y, z); + rules.emplace_back(std::get<0>(rules_xyz[i]), std::get<1>(rules_xyz[i]), std::get<2>(rules_xyz[i])); } special_tokens.load(fin); - fin.close(); } BpeConfig::BpeConfig(double _character_coverage, int _n_threads, diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h index d45346c..745fb7f 100644 --- a/youtokentome/cpp/utils.h +++ b/youtokentome/cpp/utils.h @@ -8,6 +8,22 @@ namespace vkcom { const uint32_t SPACE_TOKEN = 9601; +struct StreamWriter { + virtual int write(const char *buffer, int size) = 0; + virtual std::string name() const noexcept = 0; + virtual ~StreamWriter() = default; + + static StreamWriter open(const std::string &file_name); +}; + +struct StreamReader { + virtual int read(const char *buffer, int size) = 0; + virtual std::string name() const noexcept = 0; + virtual ~StreamReader() = default; + + static StreamReader open(const std::string &file_name); +}; + struct BPE_Rule { // x + y -> z uint32_t x{0}; @@ -31,9 +47,9 @@ struct SpecialTokens { SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id); - void dump(std::ofstream &fout); + void dump(StreamWriter &fout); - void load(std::ifstream &fin); + void load(StreamReader &fin); uint32_t max_id() const; @@ -58,9 +74,9 @@ struct BPEState { std::vector rules; SpecialTokens special_tokens; - void dump(const std::string &file_name); + void dump(StreamWriter &fout); - void load(const std::string &file_name); + void load(StreamReader &fin); }; struct DecodeResult { diff --git a/youtokentome/youtokentome.py b/youtokentome/youtokentome.py index e02c84a..a9df041 100644 --- a/youtokentome/youtokentome.py +++ b/youtokentome/youtokentome.py @@ -1,22 +1,32 @@ from enum import Enum -from typing import List, Union +from functools import wraps +from typing import BinaryIO, List, Optional, Union import _youtokentome_cython + class OutputType(Enum): ID = 1 SUBWORD = 2 + class BPE: - def __init__(self, model: str, n_threads: int = -1): - self.bpe_cython = _youtokentome_cython.BPE( - model_path=model, n_threads=n_threads - ) + def __init__(self, model: Union[str, BinaryIO], n_threads: int = -1): + own_obj = isinstance(model, str) + if own_obj: + model = open(model, "rb") + try: + self.bpe_cython = _youtokentome_cython.BPE( + model_fobj=model, n_threads=n_threads + ) + finally: + if own_obj: + model.close() @staticmethod def train( data: str, - model: str, + model: Optional[Union[str, BinaryIO]], vocab_size: int, coverage: float = 1.0, n_threads: int = -1, @@ -25,17 +35,24 @@ def train( bos_id: int = 2, eos_id: int = 3, ) -> "BPE": - _youtokentome_cython.BPE.train( - data=data, - model=model, - vocab_size=vocab_size, - n_threads=n_threads, - coverage=coverage, - pad_id=pad_id, - unk_id=unk_id, - bos_id=bos_id, - eos_id=eos_id, - ) + own_obj = isinstance(model, str) + if own_obj: + model = open(model, "wb") + try: + _youtokentome_cython.BPE.train( + data=data, + model=model, + vocab_size=vocab_size, + n_threads=n_threads, + coverage=coverage, + pad_id=pad_id, + unk_id=unk_id, + bos_id=bos_id, + eos_id=eos_id, + ) + finally: + if own_obj: + model.close() return BPE(model=model, n_threads=n_threads) @@ -61,6 +78,22 @@ def encode( reverse=reverse, ) + def save(self, where: Union[str, BinaryIO]): + """ + Write the model to FS or any writeable file object. + + :param where: FS path or writeable file object. + :return: None + """ + own_obj = isinstance(where, str) + if own_obj: + where = open(where, "wb") + try: + self.bpe_cython.save(where=where) + finally: + if own_obj: + where.close() + def vocab_size(self) -> int: return self.bpe_cython.vocab_size()