|
1 | | -#define LLAMA_API_INTERNAL |
2 | 1 | #include "sampling.h" |
| 2 | + |
3 | 3 | #include <random> |
4 | 4 |
|
5 | | -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { |
| 5 | +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id) { |
6 | 6 | struct llama_sampling_context * result = new llama_sampling_context(); |
7 | 7 |
|
8 | 8 | result->params = params; |
| 9 | + result->seq_id = seq_id; |
| 10 | + result->ctx = ctx; |
9 | 11 | result->grammar = nullptr; |
10 | 12 |
|
11 | 13 | // if there is a grammar, parse it |
@@ -81,7 +83,7 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t s |
81 | 83 | if (seed == LLAMA_DEFAULT_SEED) { |
82 | 84 | seed = std::random_device{}(); |
83 | 85 | } |
84 | | - ctx->rng.seed(seed); |
| 86 | + llama_set_rng_seed_seq(ctx->ctx, seed, ctx->seq_id); |
85 | 87 | } |
86 | 88 |
|
87 | 89 | void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { |
@@ -271,10 +273,10 @@ static llama_token llama_sampling_sample_impl( |
271 | 273 | bool is_resampling) { |
272 | 274 | const llama_sampling_params & params = ctx_sampling->params; |
273 | 275 |
|
274 | | - const float temp = params.temp; |
275 | | - const int mirostat = params.mirostat; |
276 | | - const float mirostat_tau = params.mirostat_tau; |
277 | | - const float mirostat_eta = params.mirostat_eta; |
| 276 | + const float temp = params.temp; |
| 277 | + const int mirostat = params.mirostat; |
| 278 | + const float mirostat_tau = params.mirostat_tau; |
| 279 | + const float mirostat_eta = params.mirostat_eta; |
278 | 280 |
|
279 | 281 | std::vector<float> original_logits; |
280 | 282 | auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); |
@@ -304,7 +306,7 @@ static llama_token llama_sampling_sample_impl( |
304 | 306 |
|
305 | 307 | sampler_queue(ctx_main, params, cur_p, min_keep); |
306 | 308 |
|
307 | | - id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); |
| 309 | + id = llama_sample_token_seq(ctx_main, &cur_p, ctx_sampling->seq_id); |
308 | 310 |
|
309 | 311 | //{ |
310 | 312 | // const int n_top = 10; |
|
0 commit comments