
First Conv2D layer in model cannot be pruned with structured sparsity (sparsity_m_by_n) #1182

@anasvaf

Description

Bug Description
When applying structured pruning with sparsity_m_by_n parameter to a model containing multiple Conv2D layers, the first Conv2D layer consistently fails to be pruned while identical subsequent Conv2D layers are pruned correctly.

Environment
TensorFlow version: 2.19 (the model was originally trained with Keras 3; for pruning it is recreated with tf_keras in v2-compatibility mode, and only the pre-trained weights are loaded)
TensorFlow Model Optimization version: 0.8.0
Python version: 3.10
Operating System: Linux

Expected Behavior
All Conv2D layers with identical pruning configuration should be pruned equally, achieving the target sparsity with the specified structured pattern (e.g., 2:4 sparsity).

Actual Behavior
First Conv2D layer: 0% sparsity (no pruning applied)
Subsequent Conv2D layers: 50% sparsity with correct 2:4 structured pattern
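For context, a 2:4 structured pattern requires at least two zeros in every block of four weights. A small NumPy helper (hypothetical, for verification only; it assumes blocks run over four consecutive weights in flattened order, which may differ from TFMOT's internal block layout) can confirm whether a kernel satisfies the pattern:

```python
import numpy as np

def satisfies_2_4_pattern(kernel):
    """Return True if every complete block of 4 consecutive weights
    (in flattened order) contains at least 2 zeros."""
    flat = kernel.flatten()
    n_blocks = len(flat) // 4
    blocks = flat[:n_blocks * 4].reshape(n_blocks, 4)
    return bool(((blocks == 0).sum(axis=1) >= 2).all())

# Two zeros in each group of 4 -> pattern holds
print(satisfies_2_4_pattern(np.array([0.0, 0.5, 0.0, -0.3, 1.2, 0.0, 0.0, 0.7])))  # True
# Dense kernel -> pattern violated
print(satisfies_2_4_pattern(np.ones(8)))  # False
```

Running this on both stripped kernels makes the asymmetry concrete: the second layer's kernel passes, the first layer's does not.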

Example Code

import tensorflow as tf
import tf_keras as keras
import tensorflow_model_optimization as tfmot
import numpy as np

print(f"TensorFlow version: {tf.__version__}")
print(f"TFMOT version: {tfmot.__version__}")

def reproduce_first_layer_pruning_bug():
    """Demonstrates that the first Conv2D layer in a model cannot be pruned."""
    
    # Create a simple model with two identical Conv2D layers
    inputs = keras.layers.Input(shape=(None, 128, 2))
    x = keras.layers.Conv2D(16, (1, 8), strides=(1, 2), padding='same', name='conv2d_first')(inputs)
    x = keras.layers.Conv2D(16, (1, 8), strides=(1, 2), padding='same', name='conv2d_second')(x)
    model = keras.Model(inputs=inputs, outputs=x)
    
    # Apply identical pruning to both layers
    def apply_pruning(layer):
        if isinstance(layer, keras.layers.Conv2D):
            return tfmot.sparsity.keras.prune_low_magnitude(
                layer,
                pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(0.5, 0),
                sparsity_m_by_n=(2, 4)
            )
        return layer
    
    pruned_model = keras.models.clone_model(model, clone_function=apply_pruning)
    pruned_model.compile(optimizer='adam', loss='mse')
    
    # Train with dummy data (width: 128 -> 64 -> 32 after two stride-(1, 2) convs)
    dummy_x = np.random.random((32, 100, 128, 2))
    dummy_y = np.random.random((32, 100, 32, 16))
    
    pruned_model.fit(dummy_x, dummy_y, epochs=3, 
                     callbacks=[tfmot.sparsity.keras.UpdatePruningStep()], 
                     verbose=0)
    
    # Check sparsity of both layers
    stripped_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
    
    first_layer = stripped_model.get_layer('conv2d_first')
    second_layer = stripped_model.get_layer('conv2d_second')
    
    first_sparsity = (first_layer.get_weights()[0] == 0).sum() / first_layer.get_weights()[0].size
    second_sparsity = (second_layer.get_weights()[0] == 0).sum() / second_layer.get_weights()[0].size
    
    print(f"First Conv2D layer sparsity: {first_sparsity:.2%}")
    print(f"Second Conv2D layer sparsity: {second_sparsity:.2%}")
    print(f"Bug reproduced: {first_sparsity == 0 and second_sparsity > 0}")
    
    return first_sparsity == 0 and second_sparsity > 0

if __name__ == "__main__":
    bug_reproduced = reproduce_first_layer_pruning_bug()
    if bug_reproduced:
        print("Bug is confirmed: First Conv2D layer cannot be pruned")
    else:
        print("No bug detected")

Additional Investigation

  1. The issue persists even when the first layer is tested in complete isolation (single-layer model)
  2. Different pruning schedules don't resolve the issue (ConstantSparsity, PolynomialDecay)
  3. The layer weights are normal (no zeros, good distribution, unique values)
  4. The pruning wrapper is correctly applied (visible in model summary and layer config)
  5. Manual weight copying doesn't help
  6. Recreating the layer with identical configuration doesn't help
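As a reference point when comparing masks during debugging, 2:4 magnitude pruning keeps the two largest-magnitude weights per block of four. The sketch below (plain NumPy, again assuming blocks of four consecutive weights in flattened order rather than TFMOT's exact block layout) computes the mask one would expect the wrapper to produce:

```python
import numpy as np

def reference_2_4_mask(kernel):
    """Expected 2:4 mask: keep the 2 largest-magnitude weights in
    each block of 4 consecutive weights (flattened order)."""
    flat = np.abs(kernel.flatten())
    assert flat.size % 4 == 0, "kernel size must be a multiple of 4"
    blocks = flat.reshape(-1, 4)
    # Indices (within each block) of the two smallest-magnitude weights
    smallest = np.argsort(blocks, axis=1)[:, :2]
    mask = np.ones_like(blocks)
    np.put_along_axis(mask, smallest, 0.0, axis=1)
    return mask.reshape(kernel.shape)

kernel = np.random.randn(1, 8, 2, 16)  # same shape as the conv kernels above
mask = reference_2_4_mask(kernel)
print((mask == 0).mean())  # exactly 0.5: two of every four weights are zeroed
```

Comparing the stripped first-layer kernel against `kernel * reference_2_4_mask(kernel)` could help distinguish whether the wrapper's mask was never computed or computed but never applied.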

Debugging Output Example

First layer config shows correct pruning wrapper:
'pruning_schedule': {'class_name': 'ConstantSparsity', 'config': {'target_sparsity': 0.5, 'begin_step': 0, 'end_step': -1, 'frequency': 100}}

Training monitoring shows:
Epoch 0 END: First Conv2D sparsity = 0.00%, Second Conv2D sparsity = 50.00%
Epoch 1 END: First Conv2D sparsity = 0.00%, Second Conv2D sparsity = 50.00%

Potential Root Cause
This may be related to how the pruning mechanism handles the first layer in the model graph, possibly in the pruning callback (UpdatePruningStep) or in the wrapper's mask-update logic.

Dummy Workaround

def apply_manual_2_4_pruning_to_first_layer(model, layer_name='conv2d'):
    layer = model.get_layer(layer_name)
    weights = layer.get_weights()
    kernel = weights[0]
    
    flat_weights = kernel.flatten()
    # Apply 2:4 structured pruning: keep the 2 largest-magnitude weights
    # in each group of 4 consecutive weights
    for i in range(0, len(flat_weights) - 3, 4):
        group = flat_weights[i:i+4]
        indices = np.argsort(np.abs(group))
        # argsort returns indices relative to the group, so offset by i
        flat_weights[i + indices[0]] = 0  # zero the smallest
        flat_weights[i + indices[1]] = 0  # zero the second smallest
    
    weights[0] = flat_weights.reshape(kernel.shape)
    layer.set_weights(weights)
    
    sparsity = (weights[0] == 0).sum() / weights[0].size
    print(f"Applied manual 2:4 pruning to {layer_name}. Sparsity: {sparsity:.2%}")

This manual workaround achieves the intended ~50% sparsity on the first layer.

Any help here is greatly appreciated.
