 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sampling_params;
-    params.seed = 42;
-    params.n_threads = 4;
-    sparams.repeat_last_n = 64;
+
     params.prompt = "The quick brown fox";
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -25,56 +22,49 @@ int main(int argc, char ** argv) {
     }
 
     auto n_past = 0;
-    auto last_n_tokens_data = std::vector<llama_token>(sparams.repeat_last_n, 0);
+
+    std::string result0;
+    std::string result1;
 
     // init
     llama_model * model;
     llama_context * ctx;
 
-    std::tie(model, ctx) = llama_init_from_gpt_params( params );
-    if (model == nullptr) {
-        return 1;
-    }
-    if (ctx == nullptr) {
-        llama_free_model(model);
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (model == nullptr || ctx == nullptr) {
+        fprintf(stderr, "%s : failed to init\n", __func__);
         return 1;
     }
+
+    // tokenize prompt
     auto tokens = llama_tokenize(ctx, params.prompt, true);
-    auto n_prompt_tokens = tokens.size();
-    if (n_prompt_tokens < 1) {
-        fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
-        llama_free(ctx);
-        llama_free_model(model);
-        return 1;
-    }
 
     // evaluate prompt
-    llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0));
+    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));
+    n_past += tokens.size();
 
-    last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
-    n_past += n_prompt_tokens;
-
-    const size_t state_size = llama_get_state_size(ctx);
-    uint8_t * state_mem = new uint8_t[state_size];
-
-    // Save state (rng, logits, embedding and kv_cache) to file
+    // save state (rng, logits, embedding and kv_cache) to file
     {
-        FILE *fp_write = fopen("dump_state.bin", "wb");
-        llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
-        fwrite(state_mem, 1, state_size, fp_write);
-        fclose(fp_write);
+        std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
+
+        {
+            FILE *fp_write = fopen("dump_state.bin", "wb");
+            llama_copy_state_data(ctx, state_mem.data()); // could also copy directly to memory mapped file
+            fwrite(state_mem.data(), 1, state_mem.size(), fp_write);
+            fclose(fp_write);
+        }
     }
 
     // save state (last tokens)
-    const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
     const auto n_past_saved = n_past;
 
     // first run
-    printf("\n%s", params.prompt.c_str());
+    printf("\nfirst run: %s", params.prompt.c_str());
 
     for (auto i = 0; i < params.n_predict; i++) {
         auto * logits = llama_get_logits(ctx);
         auto n_vocab = llama_n_vocab(model);
+
         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);
         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -83,9 +73,10 @@ int main(int argc, char ** argv) {
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_piece(ctx, next_token);
-        last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
+        result0 += next_token_str;
+
         if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx);
@@ -103,32 +94,28 @@ int main(int argc, char ** argv) {
     // make new context
     auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
 
-    // Load state (rng, logits, embedding and kv_cache) from file
+    printf("\nsecond run: %s", params.prompt.c_str());
+
+    // load state (rng, logits, embedding and kv_cache) from file
     {
-        FILE *fp_read = fopen("dump_state.bin", "rb");
-        if (state_size != llama_get_state_size(ctx2)) {
-            fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
-            llama_free(ctx2);
-            llama_free_model(model);
-            return 1;
-        }
+        std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
 
-        const size_t ret = fread(state_mem, 1, state_size, fp_read);
-        if (ret != state_size) {
+        FILE * fp_read = fopen("dump_state.bin", "rb");
+
+        const size_t ret = fread(state_mem.data(), 1, state_mem.size(), fp_read);
+        if (ret != state_mem.size()) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
             return 1;
         }
 
-        llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file
+        llama_set_state_data(ctx2, state_mem.data());
+
         fclose(fp_read);
     }
 
-    delete[] state_mem;
-
     // restore state (last tokens)
-    last_n_tokens_data = last_n_tokens_data_saved;
     n_past = n_past_saved;
 
     // second run
@@ -143,10 +130,11 @@ int main(int argc, char ** argv) {
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_piece(ctx2, next_token);
-        last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
+        result1 += next_token_str;
+
+        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -155,10 +143,17 @@ int main(int argc, char ** argv) {
         n_past += 1;
     }
 
-    printf("\n\n");
+    printf("\n");
 
     llama_free(ctx2);
     llama_free_model(model);
 
+    if (result0 != result1) {
+        fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
+        return 1;
+    }
+
+    fprintf(stderr, "\n%s : success\n", __func__);
+
     return 0;
 }
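
For readers skimming the diff: after this change the state handling reduces to a simple serialize/restore round-trip over a byte buffer. The sketch below condenses that pattern using the same API calls as the example (`llama_get_state_size`, `llama_copy_state_data`, `llama_set_state_data`); the helper name `roundtrip_state` is illustrative only, and the two contexts are assumed to be created from the same model, with file I/O and error handling omitted.

```cpp
#include <vector>
#include "llama.h"

// Copy the full state (rng, logits, embedding, kv_cache) of one context
// into another context created from the same model.
static void roundtrip_state(llama_context * ctx, llama_context * ctx2) {
    // serialize the state of ctx into a byte buffer
    std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
    llama_copy_state_data(ctx, state_mem.data());

    // ...the bytes could be written to and read back from disk here,
    // as the example does with dump_state.bin...

    // restore the saved state into the second context
    llama_set_state_data(ctx2, state_mem.data());
}
```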