Skip to content

Commit 6c353dc

Browse files
committed
clean up useless code
1 parent a1cf66e commit 6c353dc

File tree

1 file changed

+10
-104
lines changed

1 file changed

+10
-104
lines changed

llama.cpp

Lines changed: 10 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,6 @@ static bool llama_kv_cache_init(
12211221
return false;
12221222
}
12231223

1224-
fprintf(stderr, "n_embed: %d n_layer: %d n_ctx: %d n_elements: %d\n", n_embd, n_layer, n_ctx, n_elements);
12251224
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
12261225
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
12271226
ggml_set_name(cache.k, "cache_k");
@@ -3447,18 +3446,12 @@ static struct ggml_cgraph * llm_build_starcoder(
34473446
const int64_t n_layer = hparams.n_layer;
34483447
const int64_t n_ctx = hparams.n_ctx;
34493448
const int64_t n_head = hparams.n_head;
3450-
const int64_t n_head_kv = hparams.n_head_kv;
34513449
const int64_t n_embd_head = hparams.n_embd_head();
3452-
const int64_t n_embd_gqa = hparams.n_embd_gqa();
34533450

34543451
GGML_ASSERT(n_embd_head == hparams.n_rot);
34553452

3456-
const float freq_base = hparams.rope_freq_base;
3457-
const float freq_scale = hparams.rope_freq_scale;
34583453
const float norm_eps = hparams.f_norm_eps;
34593454

3460-
const int n_gpu_layers = model.n_gpu_layers;
3461-
34623455
auto & buf_compute = lctx.buf_compute;
34633456

34643457
struct ggml_init_params params = {
@@ -3517,56 +3510,18 @@ static struct ggml_cgraph * llm_build_starcoder(
35173510

35183511
inpL = ggml_add(ctx0, token, position);
35193512

3520-
const int i_gpu_start = n_layer - n_gpu_layers;
3521-
(void) i_gpu_start;
3522-
3523-
// offload functions set the tensor output backend to GPU
3524-
// tensors are GPU-accelerated if any input or the output has been offloaded
3525-
//
3526-
// with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
3527-
// in that case ggml_cuda_assign_buffers has no effect
3528-
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
3529-
offload_func_t offload_func_kq = llama_nop;
3530-
offload_func_t offload_func_v = llama_nop;
3531-
3532-
#ifdef GGML_USE_CUBLAS
3533-
if (n_gpu_layers > n_layer) {
3534-
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
3535-
}
3536-
if (n_gpu_layers > n_layer + 1) {
3537-
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
3538-
}
3539-
if (n_gpu_layers > n_layer + 2) {
3540-
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
3541-
}
3542-
#endif // GGML_USE_CUBLAS
3543-
3544-
#define PRINT_SHAPE(x) fprintf(stderr, "%d %s: (%s)\n", __LINE__, #x, llama_format_tensor_shape(x).c_str())
35453513
for (int il = 0; il < n_layer; ++il) {
3546-
offload_func_t offload_func = llama_nop;
3547-
3548-
#ifdef GGML_USE_CUBLAS
3549-
if (il >= i_gpu_start) {
3550-
offload_func = ggml_cuda_assign_buffers_no_alloc;
3551-
}
3552-
#endif // GGML_USE_CUBLAS
3553-
35543514
{
35553515
// Norm
35563516
cur = ggml_norm(ctx0, inpL, norm_eps);
3557-
35583517
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
35593518

35603519
}
35613520

3562-
{
3563-
// Compute QKV
3564-
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
3565-
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
3566-
}
3567-
35683521
{
35693522
// Self Attention
3523+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv);
3524+
35703525
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
35713526
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
35723527
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
@@ -3580,39 +3535,23 @@ static struct ggml_cgraph * llm_build_starcoder(
35803535
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
35813536
}
35823537

3583-
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
3584-
// [64, N, 12]
35853538
struct ggml_tensor * Q =
35863539
ggml_permute(ctx0,
35873540
ggml_cpy(ctx0,
35883541
Qcur,
35893542
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
35903543
0, 2, 1, 3);
35913544

3592-
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
3593-
// [64, n_past + N, 12]
35943545
struct ggml_tensor * K =
35953546
ggml_permute(ctx0,
35963547
ggml_reshape_3d(ctx0,
35973548
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
35983549
n_embd/n_head, n_head, n_past + N),
35993550
0, 2, 1, 3); //TODO: need to be tiled
36003551

3601-
// GG: flash attention
3602-
//struct ggml_tensor * V =
3603-
// ggml_cpy(ctx0,
3604-
// ggml_permute(ctx0,
3605-
// ggml_reshape_3d(ctx0,
3606-
// ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
3607-
// n_embd/n_head, n_head, n_past + N),
3608-
// 1, 2, 0, 3),
3609-
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
3610-
3611-
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
3612-
36133552
// K * Q
36143553
// [n_past + N, N, 12]
3615-
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); //TODO: check if it broadcasts
3554+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
36163555

36173556
// KQ_scaled = KQ / sqrt(n_embd/n_head)
36183557
// [n_past + N, N, 12]
@@ -3649,18 +3588,13 @@ static struct ggml_cgraph * llm_build_starcoder(
36493588
// [64, 12, N]
36503589
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
36513590

3652-
// cur = KQV_merged.contiguous().view(n_embd, N)
3653-
// [768, N]
36543591
cur = ggml_cpy(ctx0,
36553592
KQV_merged,
36563593
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
36573594
}
36583595

36593596
// Projection
3660-
{
3661-
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
3662-
cur = ggml_add(ctx0, cur, model.layers[il].bo);
3663-
}
3597+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo);
36643598

36653599
// add the input
36663600
cur = ggml_add(ctx0, cur, inpL);
@@ -3678,54 +3612,26 @@ static struct ggml_cgraph * llm_build_starcoder(
36783612
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
36793613
}
36803614

3681-
// fully connected
3682-
// [3072, 768] - model.layers[il].c_mlp_fc_w
3683-
// [3072, 1] - model.layers[il].c_mlp_fc_b
3684-
// [ 768, N] - cur (in)
3685-
// [3072, N] - cur (out)
3686-
//
3687-
// cur = fc_w*cur + fc_b
3688-
// [3072, N]
3689-
cur = ggml_mul_mat(ctx0,
3690-
model.layers[il].w3,
3691-
cur);
3692-
3693-
cur = ggml_add(ctx0, cur, model.layers[il].b3);
3615+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
36943616

36953617
// GELU activation
3696-
// [3072, N]
36973618
cur = ggml_gelu(ctx0, cur);
36983619

36993620
// projection
3700-
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
3701-
// [ 768, 1] - model.layers[il].c_mlp_proj_b
3702-
// [3072, N] - cur (in)
3703-
// [ 768, N] - cur (out)
3704-
//
3705-
// cur = proj_w*cur + proj_b
3706-
// [768, N]
3707-
cur = ggml_mul_mat(ctx0,
3708-
model.layers[il].w2,
3709-
cur);
3710-
3711-
cur = ggml_add(ctx0, cur, model.layers[il].b2);
3621+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
37123622
}
37133623

37143624
inpL = ggml_add(ctx0, cur, inpFF);
37153625
}
37163626

37173627
// norm
37183628
{
3719-
// [ 768, N]
3720-
inpL = ggml_norm(ctx0, inpL, norm_eps);
3721-
3722-
// inpL = ln_f_g*inpL + ln_f_b
3723-
// [ 768, N]
3724-
inpL = ggml_add(ctx0, ggml_mul(ctx0, inpL, model.output_norm), model.output_norm_b);
3629+
cur = ggml_norm(ctx0, inpL, norm_eps);
3630+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
37253631
}
3726-
ggml_set_name(inpL, "result_norm");
3632+
ggml_set_name(cur, "result_norm");
37273633

3728-
cur = ggml_mul_mat(ctx0, model.output, inpL);
3634+
cur = ggml_mul_mat(ctx0, model.output, cur);
37293635
ggml_set_name(cur, "result_output");
37303636

37313637
ggml_build_forward_expand(gf, cur);

0 commit comments

Comments
 (0)