
Conversation

@Satrat Satrat commented Dec 15, 2023

Initial implementation of alternating between oneshot and finetuning stages. This branch is based on two active PRs, which should be merged first:

Testing

test_multi_recipe.yaml

test_oneshot_stage:
  obcq_modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      block_size: 128
      sequential_update: False
      quantize: False
      percdamp: 0.01
      prunen: 0
      prunem: 0
      targets: [
        "re:model.layers.\\d+$"
      ]
      target_ids: ["attention_mask", "position_ids"]  
test_finetune_stage:
  pruning_modifiers:
    ConstantPruningModifier:
      targets: [
        "re:.*self_attn.q_proj",
        "re:.*self_attn.k_proj",
        "re:.*self_attn.v_proj",
        "re:.*self_attn.o_proj",
        "re:.*mlp.gate_proj",
        "re:.*mlp.up_proj"
      ]
      start: 0
test_second_oneshot_stage:
  obcq_modifiers:
    SparseGPTModifier:
      sparsity: 0.7
      block_size: 128
      sequential_update: False
      quantize: False
      percdamp: 0.01
      prunen: 0
      prunem: 0
      targets: [
        "re:model.layers.\\d+$"
      ]
      target_ids: ["attention_mask", "position_ids"]  
test_second_finetune_stage:
  pruning_modifiers:
    ConstantPruningModifier:
      targets: [
        "re:.*self_attn.q_proj",
        "re:.*self_attn.k_proj",
        "re:.*self_attn.v_proj",
        "re:.*self_attn.o_proj",
        "re:.*mlp.gate_proj",
        "re:.*mlp.up_proj"
      ]
      start: 0
test_quantization_oneshot_stage:
  obcq_modifiers:
    QuantizationModifier:
      ignore:
        - LlamaRotaryEmbedding
        - LlamaRMSNorm
        - SiLUActivation
        - model.layers.0.mlp.down_proj
        - model.layers.1.mlp.down_proj
        - model.layers.2.mlp.down_proj
        - model.layers.3.mlp.down_proj
        - model.layers.4.mlp.down_proj
        - model.layers.5.mlp.down_proj
      post_oneshot_calibration: False
      scheme_overrides:
        Embedding:
          input_activations: null
          weights:
            num_bits: 8
            symmetric: False
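
As a quick sanity check of the recipe layout before running the full flow, here is a minimal sketch (assuming only PyYAML is installed; not part of the tested flow itself) that loads test_multi_recipe.yaml and lists each stage with its modifier groups, making the alternating oneshot/finetune structure easy to eyeball:

import yaml

# Load the multi-stage recipe and print each stage's modifier groups and
# modifier names, e.g. "test_oneshot_stage: obcq_modifiers -> ['SparseGPTModifier']".
with open("test_multi_recipe.yaml") as recipe_file:
    recipe = yaml.safe_load(recipe_file)

for stage_name, modifier_groups in recipe.items():
    for group_name, modifiers in modifier_groups.items():
        print(f"{stage_name}: {group_name} -> {list(modifiers.keys())}")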

Test script:

def run():
    from sparseml.transformers.finetune.text_generation import run_general
    
    model = "Xenova/llama2.c-stories15M"
    dataset_name = "open_platypus"
    concatenate_data = False
    run_stages = True
    output_dir = "./output_oneshot"
    overwrite_output_dir = True
    recipe = "test_multi_recipe.yaml"
    splits = {
        "calibration": "train[:50%]",
        "train": "train[50%:]"
    }

    run_general(
        model_name_or_path=model,
        dataset_name=dataset_name,
        run_stages=run_stages,
        output_dir=output_dir,
        overwrite_output_dir=overwrite_output_dir,
        recipe=recipe,
        concatenate_data=concatenate_data,
        splits=splits
    )

if __name__ == "__main__":
    run()
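
To spot-check that the alternating stages leave the expected sparsity in the saved model, a rough verification sketch (assuming the final checkpoint written to ./output_oneshot loads with the standard Hugging Face transformers API; the layer filter below simply mirrors the targets in the recipe):

from transformers import AutoModelForCausalLM

# Load the model written by the run above and report the fraction of zero
# weights in the pruned projection layers. After the second oneshot stage
# these should sit near 70%, preserved through finetuning by the
# ConstantPruningModifier.
model = AutoModelForCausalLM.from_pretrained("./output_oneshot")

for name, param in model.named_parameters():
    if any(key in name for key in ("q_proj", "k_proj", "v_proj", "o_proj")):
        zero_fraction = (param == 0).float().mean().item()
        print(f"{name}: {zero_fraction:.2%} zeros")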

Known Issues / Shortcomings

  • FSDP hasn't been tested yet
  • Training checkpoints are overwritten during subsequent finetuning runs
  • No way to specify a different number of epochs for each finetune stage
  • No way to specify different dataset splits for different finetuning stages
  • Checkpoint loading between stages is not implemented
  • Output recipe doesn't indicate which stages have been run and which haven't
  • No unit or integration tests!

bfineran previously approved these changes Dec 28, 2023
@Satrat Satrat marked this pull request as ready for review January 2, 2024 21:31
@Satrat Satrat requested a review from bfineran January 4, 2024 23:31
@Satrat Satrat mentioned this pull request Jan 4, 2024

@rahul-tuli rahul-tuli left a comment

LGTM! Good tests

@Satrat Satrat merged commit f592037 into main Jan 9, 2024
@Satrat Satrat deleted the alternating_flow_pt2 branch January 9, 2024 14:27