Bug Description
When applying structured pruning with sparsity_m_by_n parameter to a model containing multiple Conv2D layers, the first Conv2D layer consistently fails to be pruned while identical subsequent Conv2D layers are pruned correctly.
Environment
TensorFlow version: 2.19; the model was originally trained with Keras 3, but is recreated here with tf_keras (Keras 2 compatibility mode), loading only the pre-trained weights
TensorFlow Model Optimization version: 0.8.0
Python version: 3.10
Operating System: Linux
Expected Behavior
All Conv2D layers with identical pruning configuration should be pruned equally, achieving the target sparsity with the specified structured pattern (e.g., 2:4 sparsity).
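For reference, "2:4 structured sparsity" means that in every group of 4 consecutive weights, at least 2 are exactly zero. A minimal numpy sketch of that check (the helper name is my own, not part of TFMOT):

```python
import numpy as np

def satisfies_2_4_sparsity(kernel):
    """Return True if every group of 4 consecutive weights has >= 2 zeros."""
    flat = kernel.flatten()
    usable = (flat.size // 4) * 4  # ignore a trailing partial group
    groups = flat[:usable].reshape(-1, 4)
    zeros_per_group = (groups == 0).sum(axis=1)
    return bool((zeros_per_group >= 2).all())

# A 2:4-pruned tensor passes; a dense one does not.
pruned = np.array([0.0, 0.9, 0.0, -1.2, 0.7, 0.0, 0.0, 0.3])
dense = np.array([0.1, 0.9, 0.4, -1.2, 0.7, 0.5, 0.2, 0.3])
print(satisfies_2_4_sparsity(pruned))  # True
print(satisfies_2_4_sparsity(dense))   # False
```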
Actual Behavior
First Conv2D layer: 0% sparsity (no pruning applied)
Subsequent Conv2D layers: 50% sparsity with correct 2:4 structured pattern
Example Code
import tensorflow as tf
import tf_keras as keras
import tensorflow_model_optimization as tfmot
import numpy as np

print(f"TensorFlow version: {tf.__version__}")
print(f"TFMOT version: {tfmot.__version__}")

def reproduce_first_layer_pruning_bug():
    """Demonstrates that the first Conv2D layer in a model cannot be pruned."""
    # Create a simple model with two identical Conv2D layers
    inputs = keras.layers.Input(shape=(None, 128, 2))
    x = keras.layers.Conv2D(16, (1, 8), strides=(1, 2), padding='same', name='conv2d_first')(inputs)
    x = keras.layers.Conv2D(16, (1, 8), strides=(1, 2), padding='same', name='conv2d_second')(x)
    model = keras.Model(inputs=inputs, outputs=x)

    # Apply identical pruning to both layers
    def apply_pruning(layer):
        if isinstance(layer, keras.layers.Conv2D):
            return tfmot.sparsity.keras.prune_low_magnitude(
                layer,
                pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(0.5, 0),
                sparsity_m_by_n=(2, 4),
            )
        return layer

    pruned_model = keras.models.clone_model(model, clone_function=apply_pruning)
    pruned_model.compile(optimizer='adam', loss='mse')

    # Train with dummy data
    dummy_x = np.random.random((32, 100, 128, 2))
    dummy_y = np.random.random((32, 100, 25, 16))
    pruned_model.fit(dummy_x, dummy_y, epochs=3,
                     callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
                     verbose=0)

    # Check sparsity of both layers
    stripped_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
    first_kernel = stripped_model.get_layer('conv2d_first').get_weights()[0]
    second_kernel = stripped_model.get_layer('conv2d_second').get_weights()[0]
    first_sparsity = (first_kernel == 0).sum() / first_kernel.size
    second_sparsity = (second_kernel == 0).sum() / second_kernel.size
    print(f"First Conv2D layer sparsity: {first_sparsity:.2%}")
    print(f"Second Conv2D layer sparsity: {second_sparsity:.2%}")
    print(f"Bug reproduced: {first_sparsity == 0 and second_sparsity > 0}")
    return first_sparsity == 0 and second_sparsity > 0

if __name__ == "__main__":
    if reproduce_first_layer_pruning_bug():
        print("Bug is confirmed: First Conv2D layer cannot be pruned")
    else:
        print("No bug detected")
Additional Investigation
- The issue persists even when the first layer is tested in complete isolation (single-layer model)
- Different pruning schedules don't resolve the issue (ConstantSparsity, PolynomialDecay)
- The layer weights are normal (no zeros, good distribution, unique values)
- The pruning wrapper is correctly applied (visible in model summary and layer config)
- Manual weight copying doesn't help
- Recreating the layer with identical configuration doesn't help
Debugging Output Example
First layer config shows correct pruning wrapper:
'pruning_schedule': {'class_name': 'ConstantSparsity', 'config': {'target_sparsity': 0.5, 'begin_step': 0, 'end_step': -1, 'frequency': 100}}
Training monitoring shows:
Epoch 0 END: First Conv2D sparsity = 0.00%, Second Conv2D sparsity = 50.00%
Epoch 1 END: First Conv2D sparsity = 0.00%, Second Conv2D sparsity = 50.00%
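The per-epoch numbers above came from an epoch-end sparsity monitor. Below is a framework-independent sketch of its logic (the class and all names are mine; in the actual run this would subclass tf_keras.callbacks.Callback and pull kernels via model.get_layer(name).get_weights()[0]):

```python
import numpy as np

def kernel_sparsity(kernel):
    """Fraction of exactly-zero entries in a weight tensor."""
    return float((kernel == 0).sum()) / kernel.size

class SparsityMonitor:
    """Logs per-layer sparsity at the end of each epoch.

    Kernels are passed in directly here so the logic is testable without
    TensorFlow; the real callback would read them from the model instead.
    """
    def __init__(self, named_kernels):
        self.named_kernels = named_kernels  # {layer_name: np.ndarray}

    def on_epoch_end(self, epoch):
        for name, kernel in self.named_kernels.items():
            print(f"Epoch {epoch} END: {name} sparsity = {kernel_sparsity(kernel):.2%}")

# One dense kernel and one half-zero kernel, mirroring the output above:
monitor = SparsityMonitor({
    "conv2d_first": np.ones((2, 4)),
    "conv2d_second": np.array([[0.0, 1.0], [0.0, 2.0]]),
})
monitor.on_epoch_end(0)
```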
Potential Root Cause
This may be related to how the pruning mechanism handles the first layer in the model graph, possibly in the pruning callback or in the wrapper's update logic.
Dummy Workaround
def apply_manual_2_4_pruning_to_first_layer(model, layer_name='conv2d'):
    layer = model.get_layer(layer_name)
    weights = layer.get_weights()
    kernel = weights[0]
    flat_weights = kernel.flatten()
    # Apply 2:4 structured pruning: keep the 2 largest-magnitude weights in each group of 4
    for i in range(0, len(flat_weights) - 3, 4):
        group = flat_weights[i:i+4]
        indices = np.argsort(np.abs(group))
        flat_weights[i + indices[0]] = 0  # Zero the smallest (offset by i into the flat array)
        flat_weights[i + indices[1]] = 0  # Zero the second smallest
    weights[0] = flat_weights.reshape(kernel.shape)
    layer.set_weights(weights)
    sparsity = (weights[0] == 0).sum() / weights[0].size
    print(f"Applied manual 2:4 pruning to {layer_name}. Sparsity: {sparsity:.2%}")
This manual workaround achieves a sparsity of ~50%, as intended.
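For completeness, the same group-of-4 magnitude selection can be done without a Python loop. This vectorized variant is my own sketch (it assumes the kernel size is a multiple of 4) and produces the same result as the loop above:

```python
import numpy as np

def prune_2_4(kernel):
    """Zero the 2 smallest-magnitude weights in each group of 4 consecutive weights."""
    flat = kernel.flatten()
    assert flat.size % 4 == 0, "sketch assumes kernel size is a multiple of 4"
    groups = flat.reshape(-1, 4)
    # Rank weights within each group by magnitude; zero the 2 smallest per group.
    order = np.argsort(np.abs(groups), axis=1)
    mask = np.ones_like(groups)
    np.put_along_axis(mask, order[:, :2], 0.0, axis=1)
    return (groups * mask).reshape(kernel.shape)

kernel = np.array([[0.1, -0.9, 0.4, 1.2],
                   [0.7, -0.05, 0.2, 0.3]])
pruned = prune_2_4(kernel)
sparsity = (pruned == 0).mean()
print(f"Sparsity: {sparsity:.2%}")  # 50.00%
```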
Any help here is greatly appreciated.