@@ -28,7 +28,7 @@ static std::string trim(const std::string & str) {
 }
 
 static std::string k_system = R"(
-Transcript of a dialog, where the User interacts with an Assistant.
+Transcript of a never ending dialog, where the User interacts with an Assistant.
 The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
 
 User: Hello, what is the temperature outside?
@@ -59,6 +59,9 @@ struct client {
 
     llama_token sampled;
 
+    int64_t t_start_prompt;
+    int64_t t_start_gen;
+
     int32_t n_prompt  = 0;
     int32_t n_decoded = 0;
     int32_t i_batch   = -1;
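
The two new fields record, in microseconds, when a client's prompt began processing and when its first token was generated, so that per-client prompt-processing (PP) and text-generation (TG) speeds can be derived later in the patch. A minimal sketch of that conversion, assuming timestamps come from ggml_time_us(); the helper name tokens_per_second is hypothetical:

    // Hypothetical helper: converts a token count and an elapsed interval in
    // microseconds (as returned by ggml_time_us()) into tokens per second.
    static double tokens_per_second(int32_t n_tokens, int64_t t_us) {
        return t_us > 0 ? 1e6 * (double) n_tokens / (double) t_us : 0.0;
    }
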
@@ -133,33 +136,47 @@ int main(int argc, char ** argv) {
 
         for (auto & client : clients) {
             if (client.seq_id == -1) {
-                client.seq_id = g_seq_id;
-                client.input = k_prompts[rand() % k_prompts.size()];
-                client.prompt = k_system + client.input + "\nAssistant:";
-                client.response = "";
-                std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
-
-                std::vector<llama_token> prompt_tokens;
-                prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
-
-                for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-                    batch_token.push_back(prompt_tokens[i]);
-                    batch_pos.push_back(i);
-                    batch_seq_id.push_back(client.seq_id);
-                    batch_clients.push_back(&client);
+                continue;
+            }
+
+            batch_token.push_back(client.sampled);
+            batch_pos.push_back(client.n_decoded);
+            batch_seq_id.push_back(client.seq_id);
+            batch_clients.push_back(&client);
+            client.n_decoded += 1;
+            client.i_batch = batch_token.size() - 1;
+        }
+
+        if (batch_token.empty()) {
+            // all sequences have ended - clear the entire KV cache
+            llama_kv_cache_rm_tokens(ctx, -1, -1);
+
+            for (auto & client : clients) {
+                if (client.seq_id == -1) {
+                    client.seq_id = g_seq_id;
+                    client.t_start_prompt = ggml_time_us();
+                    client.t_start_gen   = 0;
+
+                    client.input = k_prompts[rand() % k_prompts.size()];
+                    client.prompt = k_system + client.input + "\nAssistant:";
+                    client.response = "";
+                    std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
+
+                    std::vector<llama_token> prompt_tokens;
+                    prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
+
+                    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
+                        batch_token.push_back(prompt_tokens[i]);
+                        batch_pos.push_back(i);
+                        batch_seq_id.push_back(client.seq_id);
+                        batch_clients.push_back(&client);
+                    }
+                    client.n_prompt  = prompt_tokens.size();
+                    client.n_decoded = prompt_tokens.size();
+                    client.i_batch   = batch_token.size() - 1;
+
+                    g_seq_id += 1;
                 }
-                client.n_prompt = prompt_tokens.size();
-                client.n_decoded = prompt_tokens.size();
-                client.i_batch = batch_token.size() - 1;
-
-                g_seq_id += 1;
-            } else {
-                batch_token.push_back(client.sampled);
-                batch_pos.push_back(client.n_decoded);
-                batch_seq_id.push_back(client.seq_id);
-                batch_clients.push_back(&client);
-                client.n_decoded += 1;
-                client.i_batch = batch_token.size() - 1;
             }
         }
 
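
The scheduling is now two-phase: sequences that are still generating are batched first, and only once every slot is idle does the loop wipe the KV cache and launch a fresh wave of prompts in one shared batch. Going by the in-code comment, -1 appears to act as a wildcard in llama_kv_cache_rm_tokens, so passing it for both arguments removes every token of every sequence; a hedged sketch of that reading (the wrapper name kv_cache_clear_all is hypothetical):

    // Hypothetical wrapper, assuming the llama_kv_cache_rm_tokens signature of
    // this revision: -1 for both arguments is taken to mean "all tokens",
    // i.e. a full KV cache reset before the next wave of prompts.
    static void kv_cache_clear_all(llama_context * ctx) {
        llama_kv_cache_rm_tokens(ctx, -1, -1);
    }

Resetting wholesale keeps cache positions starting from zero for every new sequence and sidesteps per-sequence eviction entirely in this early version of the API.
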
@@ -188,6 +205,10 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
 
+            if (client.t_start_gen == 0) {
+                client.t_start_gen = ggml_time_us();
+            }
+
             // remember which tokens were sampled - used for repetition penalties during sampling
             client.last_tokens.erase(client.last_tokens.begin());
             client.last_tokens.push_back(id);
@@ -199,7 +220,10 @@ int main(int argc, char ** argv) {
             // printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
             //        client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
 
-            if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) {
+            if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
+                client.response.find("User:") != std::string::npos ||
+                client.response.find('\n') != std::string::npos) {
+                // basic reverse prompt
                 const size_t pos = client.response.find("User:");
                 if (pos != std::string::npos) {
                     client.response = client.response.substr(0, pos);
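
Generation for a client now stops on end-of-sequence, on exhausting the token budget, on the reverse prompt "User:", or on a raw newline, which keeps replies to a single line; the response is then truncated at the reverse prompt in case it was already emitted. A standalone sketch of the same truncation, runnable in isolation:

    #include <cstdio>
    #include <string>

    int main() {
        std::string response = " The temperature is 25 degrees.\nUser: thanks";

        // basic reverse prompt: cut the reply at the first occurrence of "User:"
        const size_t pos = response.find("User:");
        if (pos != std::string::npos) {
            response = response.substr(0, pos);
        }

        printf("%s", response.c_str()); // prints the reply up to, not including, "User:"
        return 0;
    }
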
@@ -211,13 +235,18 @@ int main(int argc, char ** argv) {
 
                 n_tokens_total += client.n_decoded - client.n_prompt;
 
-                printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput:    %s\nResponse: %s\n\n",
+                printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f, AVG %5.2f \033[0m: \n\nInput:    %s\nResponse: %s\n\n",
                         client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
-                        (double) n_tokens_total / (t_main_end - t_main_start) * 1e6,
-                        client.input.c_str(), ::trim(client.response).c_str());
+                        (double) (client.n_prompt                   ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
+                        (double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen          ) * 1e6,
+                        (double) (client.n_decoded                  ) / (t_main_end - client.t_start_prompt       ) * 1e6,
+                        ::trim(client.input).c_str(),
+                        ::trim(client.response).c_str());
 
                 client.seq_id = -1;
             }
+
+            client.i_batch = -1;
         }
     }
 
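
Each rate is plain arithmetic on the timestamps: a token count divided by an elapsed interval in microseconds, scaled by 1e6 to tokens per second. A worked example with made-up timings (48 prompt tokens processed in 0.5 s, then 120 tokens generated over the next 6.0 s):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t t_start_prompt = 0;        // us
        const int64_t t_start_gen    = 500000;   // us, set at the first sampled token
        const int64_t t_main_end     = 6500000;  // us
        const int32_t n_prompt  = 48;
        const int32_t n_decoded = 168;           // prompt + generated tokens

        printf("PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s\n",
               (double) (n_prompt)             / (t_start_gen - t_start_prompt) * 1e6,  // 96.00
               (double) (n_decoded - n_prompt) / (t_main_end  - t_start_gen)    * 1e6,  // 20.00
               (double) (n_decoded)            / (t_main_end  - t_start_prompt) * 1e6); // 25.85
        return 0;
    }
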