@@ -857,21 +857,23 @@ struct common_init_result common_init_from_params(common_params & params) {
         return iparams;
     }
 
+    const llama_vocab * vocab = llama_get_vocab(model);
+
     if (params.reranking) {
         bool ok = true;
 
-        if (llama_token_bos(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__);
+        if (llama_token_bos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
             ok = false;
         }
 
-        if (llama_token_eos(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__);
+        if (llama_token_eos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
             ok = false;
         }
 
-        if (llama_token_sep(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__);
+        if (llama_token_sep(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
             ok = false;
         }
 
@@ -941,14 +943,14 @@ struct common_init_result common_init_from_params(common_params & params) {
         common_lora_adapters_apply(lctx, params.lora_adapters);
     }
 
-    if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
-        LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+    if (params.sampling.ignore_eos && llama_token_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
         params.sampling.ignore_eos = false;
     }
 
     if (params.sampling.ignore_eos) {
-        for (llama_token i = 0; i < llama_n_vocab(model); i++) {
-            if (llama_token_is_eog(model, i)) {
+        for (llama_token i = 0; i < llama_n_vocab(vocab); i++) {
+            if (llama_token_is_eog(vocab, i)) {
                 LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
                 params.sampling.logit_bias.push_back({i, -INFINITY});
             }
@@ -969,8 +971,9 @@ struct common_init_result common_init_from_params(common_params & params) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
         std::vector<llama_token> tmp;
-        llama_token bos = llama_token_bos(model);
-        llama_token eos = llama_token_eos(model);
+        llama_token bos = llama_token_bos(vocab);
+        llama_token eos = llama_token_eos(vocab);
+
         // some models (e.g. T5) don't have a BOS token
         if (bos != LLAMA_TOKEN_NULL) {
             tmp.push_back(bos);
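The hunks above stop querying special tokens through the model handle and instead fetch a vocab handle once via llama_get_vocab(). A minimal caller-side sketch of that pattern, assuming a llama_model * loaded elsewhere (the helper and variable names here are illustrative, not part of the change):

```cpp
#include "llama.h"

// Sketch only: query special tokens through the vocab rather than the model.
// `model` is assumed to be a valid llama_model * loaded beforehand.
static bool has_bos_and_eos(const llama_model * model) {
    const llama_vocab * vocab = llama_get_vocab(model);

    // special tokens a model may lack come back as LLAMA_TOKEN_NULL, which is
    // what the reranking and --ignore-eos checks above rely on
    const bool has_bos = llama_token_bos(vocab) != LLAMA_TOKEN_NULL;
    const bool has_eos = llama_token_eos(vocab) != LLAMA_TOKEN_NULL;

    return has_bos && has_eos;
}
```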
@@ -1559,21 +1562,23 @@ std::vector<llama_token> common_tokenize(
     const std::string & text,
     bool add_special,
     bool parse_special) {
-    return common_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_get_vocab(model);
+    return common_tokenize(vocab, text, add_special, parse_special);
 }
 
 std::vector<llama_token> common_tokenize(
-    const struct llama_model * model,
+    const struct llama_vocab * vocab,
     const std::string & text,
     bool add_special,
     bool parse_special) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + 2 * add_special;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -1582,12 +1587,18 @@ std::vector<llama_token> common_tokenize(
 }
 
 std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_get_vocab(model);
+    return common_token_to_piece(vocab, token, special);
+}
+
+std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
     std::string piece;
     piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
     if (n_chars < 0) {
         piece.resize(-n_chars);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
         GGML_ASSERT(check == -n_chars);
     }
     else {
@@ -1597,13 +1608,19 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
     return piece;
 }
 
-std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_get_vocab(model);
+    return common_detokenize(vocab, tokens, special);
+}
+
+std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t) tokens.size(), &text[0], (int32_t) text.size(), false, special);
+    int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t) tokens.size(), &text[0], (int32_t) text.size(), false, special);
     if (n_chars < 0) {
         text.resize(-n_chars);
-        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t) tokens.size(), &text[0], (int32_t) text.size(), false, special);
+        n_chars = llama_detokenize(vocab, tokens.data(), (int32_t) tokens.size(), &text[0], (int32_t) text.size(), false, special);
         GGML_ASSERT(n_chars <= (int32_t) text.size()); // whitespace trimming is performed after per-token detokenization
     }
 
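The tokenizer helpers above now exist in two flavors: the context-based overloads, which resolve the vocab internally, and new vocab-based overloads. A hedged usage sketch, assuming the new overloads are declared in common.h alongside the existing ones (the prompt text and the round_trip helper are illustrative):

```cpp
#include "common.h"
#include "llama.h"

#include <string>
#include <vector>

// Sketch only: round-trip a string through the vocab-based overloads.
// `ctx` is assumed to be a valid llama_context * created elsewhere.
static std::string round_trip(llama_context * ctx, const std::string & text) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_get_vocab(model);

    // tokenize against the vocab directly (the ctx overload above performs the same lookup)
    std::vector<llama_token> toks = common_tokenize(vocab, text, /*add_special=*/true, /*parse_special=*/false);

    // and turn the tokens back into text the same way
    return common_detokenize(vocab, toks, /*special=*/false);
}
```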
@@ -1631,7 +1648,7 @@ std::string common_get_builtin_chat_template(const struct llama_model * model) {
 
 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
@@ -1642,35 +1659,34 @@ std::string common_chat_apply_template(const struct llama_model * model,
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
-    for (auto & msg : msgs) {
+    for (const auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
         if (ptr_tmpl != nullptr) {
             // if the custom "tmpl" is not supported, we throw an error
             // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
             throw std::runtime_error("this custom template is not supported");
-        } else {
-            // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-            fallback = true;
         }
+
+        // If the built-in template is not supported, we default to chatml
+        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        fallback = true;
     }
 
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
         res = llama_chat_apply_template(
-            fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
             chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
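With the model argument dropped from llama_chat_apply_template(), formatting is driven entirely by the template string, and the built-in template is obtained explicitly via llama_model_chat_template(model) when the caller passes an empty one. A rough sketch of the resulting call sequence, assuming (as common_chat_verify_template() above suggests) that a null buffer with zero length returns the required output size; the message contents and the format_prompt helper are placeholders:

```cpp
#include "llama.h"

#include <string>
#include <vector>

// Sketch only: format a prompt with the model-free llama_chat_apply_template().
// `model` is assumed to be a valid llama_model *; message contents are placeholders.
static std::string format_prompt(const llama_model * model) {
    llama_chat_message chat[] = {
        {"user", "Hello!"},
    };

    // built-in template, mirroring the empty-template path above
    const char * tmpl = llama_model_chat_template(model);

    // first pass: query the required length (negative means the template is unsupported)
    int32_t res = llama_chat_apply_template(tmpl, chat, 1, /*add_ass=*/true, nullptr, 0);
    if (res < 0) {
        tmpl = "chatml"; // same fallback that common_chat_apply_template() uses
        res  = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
    }

    // second pass: render into a buffer of the reported size
    std::vector<char> buf(res);
    res = llama_chat_apply_template(tmpl, chat, 1, true, buf.data(), buf.size());

    return std::string(buf.data(), res);
}
```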