 #include "common.h"
 #include "log.h"

+#include <algorithm>
 #include <cmath>
+#include <cstring>
 #include <unordered_map>
-#include <algorithm>

 // the ring buffer works similarly to std::deque, but with a fixed capacity
 // TODO: deduplicate with llama-impl.h
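
As an aside, the fixed-capacity ring buffer this comment refers to behaves like a bounded `std::deque`: once full, a push overwrites the oldest element. A minimal sketch of the idea (illustrative only; the member names and exact API of the real `ring_buffer` are assumptions, not the actual implementation):

```cpp
#include <cstddef>
#include <vector>

// fixed-capacity ring buffer: push_back() overwrites the oldest element
// once the buffer is full, so size() never exceeds the capacity
template <typename T>
struct ring_buffer_sketch {
    explicit ring_buffer_sketch(size_t cap) : capacity(cap), data(cap) {}

    void push_back(const T & value) {
        if (sz == capacity) {
            first = (first + 1) % capacity; // drop the oldest element
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    void clear() { first = 0; pos = 0; sz = 0; }

    size_t size() const { return sz; }

    size_t capacity = 0;
    size_t first    = 0; // index of the oldest element
    size_t pos      = 0; // index of the next write
    size_t sz       = 0;

    std::vector<T> data;
};
```
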
@@ -112,6 +113,13 @@ struct common_sampler {

     llama_token_data_array cur_p;

+    void reset() {
+        prev.clear();
+
+        llama_sampler_reset(grmr);
+        llama_sampler_reset(chain);
+    }
+
     void set_logits(struct llama_context * ctx, int idx) {
         const auto * logits = llama_get_logits_ith(ctx, idx);

@@ -128,6 +136,12 @@ struct common_sampler {

         cur_p = { cur.data(), cur.size(), -1, false };
     }
+
+    common_time_meas tm() {
+        return common_time_meas(t_total_us, params.no_perf);
+    }
+
+    mutable int64_t t_total_us = 0;
 };

 std::string common_params_sampling::print() const {
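
The new `tm()` helper returns a scope guard that accumulates the elapsed wall time of the enclosing function into `t_total_us`, and is a no-op when `params.no_perf` is set. `common_time_meas` is defined elsewhere in `common`; assuming it mirrors the existing `time_meas` helper in `llama-impl.h`, its shape would be roughly this sketch:

```cpp
// assumed shape of common_time_meas (RAII scope timer): the destructor
// adds the elapsed microseconds to the referenced accumulator;
// t_start_us == -1 marks the disabled (no_perf) case
struct common_time_meas {
    common_time_meas(int64_t & t_acc, bool disable)
        : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}

    ~common_time_meas() {
        if (t_start_us >= 0) {
            t_acc += ggml_time_us() - t_start_us;
        }
    }

    const int64_t t_start_us;
    int64_t     & t_acc;
};
```

This is why the functions changed below only need a `const auto tm = gsmpl->tm();` at the top: the destructor fires on every return path and bills the whole call to `t_total_us`.
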
@@ -298,6 +312,8 @@ void common_sampler_free(struct common_sampler * gsmpl) {
 }

 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
+    const auto tm = gsmpl->tm();
+
     if (accept_grammar) {
         llama_sampler_accept(gsmpl->grmr, token);
     }
@@ -308,9 +324,7 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
 }

 void common_sampler_reset(struct common_sampler * gsmpl) {
-    llama_sampler_reset(gsmpl->grmr);
-
-    llama_sampler_reset(gsmpl->chain);
+    gsmpl->reset();
 }

 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
@@ -327,16 +341,54 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
 void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
     // TODO: measure grammar performance

+    const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;
+
+    llama_perf_sampler_data data_smpl;
+    llama_perf_context_data data_ctx;
+
+    memset(&data_smpl, 0, sizeof(data_smpl));
+    memset(&data_ctx, 0, sizeof(data_ctx));
+
     if (gsmpl) {
-        llama_perf_sampler_print(gsmpl->chain);
+        auto & data = data_smpl;
+
+        data = llama_perf_sampler(gsmpl->chain);
+
+        // note: the sampling time includes the samplers time + extra time spent in common/sampling
+        LOG_INF("%s:    sampling time = %10.2f ms\n", __func__, t_sampling_ms);
+        LOG_INF("%s:    samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
     }
+
     if (ctx) {
-        llama_perf_context_print(ctx);
+        auto & data = data_ctx;
+
+        data = llama_perf_context(ctx);
+
+        const double t_end_ms = 1e-3 * ggml_time_us();
+
+        const double t_total_ms = t_end_ms - data.t_start_ms;
+        const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
+        const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;
+
+        LOG_INF("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
+        LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+                __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
+        LOG_INF("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+                __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
+        LOG_INF("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+        LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
+        LOG_INF("%s:    graphs reused = %10d\n", __func__, data.n_reused);
+
         llama_memory_breakdown_print(ctx);
     }
 }

 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+    llama_synchronize(ctx);
+
+    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
+    const auto tm = gsmpl->tm();
+
     gsmpl->set_logits(ctx, idx);

     auto & grmr = gsmpl->grmr;
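
The "unaccounted time" reported above is plain bookkeeping: the wall time elapsed since the context started, minus whatever was attributed to sampling, prompt eval, and eval. A tiny standalone illustration with made-up numbers (all values hypothetical):

```cpp
#include <cstdio>

int main() {
    // hypothetical measurements, in ms
    const double t_total_ms    = 1000.0; // wall time since t_start_ms
    const double t_sampling_ms =   50.0; // accumulated by the tm() guards
    const double t_p_eval_ms   =  200.0; // prompt processing
    const double t_eval_ms     =  600.0; // token generation

    const double t_unacc_ms = t_total_ms - (t_sampling_ms + t_p_eval_ms + t_eval_ms);
    const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;

    // prints: unaccounted time = 150.00 ms / 15.0 %
    printf("unaccounted time = %.2f ms / %.1f %%\n", t_unacc_ms, t_unacc_pc);
    return 0;
}
```

Note also the `llama_synchronize(ctx)` added at the top of `common_sampler_sample`: as the in-diff comment says, the timing guard is taken only after synchronization, so in-flight asynchronous backend work is not misattributed to sampling.
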
@@ -428,6 +480,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
 // helpers

 llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
+    const auto tm = gsmpl->tm();
+
     auto * res = &gsmpl->cur_p;

     if (do_sort && !res->sorted) {
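
Taken together, each public entry point that does measurable work (`common_sampler_sample`, `common_sampler_accept`, `common_sampler_get_candidates`) now opens a `tm()` guard, so their combined wall time surfaces as "sampling time" in `common_perf_print`. A schematic caller loop, with the decode step elided and `n_predict`/`vocab` as assumed caller-side variables:

```cpp
// hypothetical generation loop: every common_sampler_* call below
// internally accumulates its wall time into gsmpl->t_total_us
for (int n_decoded = 0; n_decoded < n_predict; n_decoded++) {
    const llama_token id = common_sampler_sample(gsmpl, ctx, /*idx =*/ -1);

    common_sampler_accept(gsmpl, id, /*accept_grammar =*/ true);

    if (llama_vocab_is_eog(vocab, id)) {
        break;
    }

    // ... decode the accepted token to produce logits for the next step ...
}
```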