
Commit 46f78c1

KSGulin and bfineran authored
Upgrade to transformers release V4.23.1 (#62)
* Update trainer and model flows to accommodate sparseml
* Disable FP16 on QAT start (#12)
  * Override LRScheduler when using LRModifiers
  * Disable FP16 on QAT start
  * Keep wrapped scaler object for training after disabling
* Using QATMatMul in DistilBERT model class (#41)
* Removed double quantization of output of context layer (#45)
* Fix DataParallel validation forward signatures (#47)
  * Fix: DataParallel validation forward signatures
  * Update: generalize forward_fn selection
* Best model after epoch (#46)
* Fix scaler check for non-FP16 mode in trainer (#38)
* MobileBERT QAT (#55)
  * Remove duplicate quantization of vocabulary
* Enable a QATWrapper for non-parameterized matmuls in BERT self attention (#9)
* Utils and auxiliary changes
* Update Zoo stub loading for SparseZoo 1.1 refactor (#54)
* Add flag to signal NM integration is active (#32)
* Add recipe_name to file names
* Fix errors introduced in manual cherry-pick upgrade

Co-authored-by: Benjamin Fineran <[email protected]>
1 parent 53c407d · commit 46f78c1

File tree: 6 files changed, +133 -18 lines changed


src/transformers/hf_argparser.py

Lines changed: 43 additions & 4 deletions
@@ -21,9 +21,16 @@
 from inspect import isclass
 from pathlib import Path
 from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
-
+import os
 import yaml
 
+from sparsezoo import Model
+
+from .utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
 
 DataClass = NewType("DataClass", Any)
 DataClassType = NewType("DataClassType", Any)
@@ -229,12 +236,17 @@ def parse_args_into_dataclasses(
             # additional namespace.
             outputs.append(namespace)
         if return_remaining_strings:
-            return (*outputs, remaining_args)
+            return tuple(
+                *[_download_dataclass_zoo_stub_files(output) for output in outputs],
+                remaining_args,
+            )
         else:
             if remaining_args:
                 raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
 
-            return (*outputs,)
+            return tuple(
+                [_download_dataclass_zoo_stub_files(output) for output in outputs]
+            )
 
     def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
         """
@@ -262,7 +274,9 @@ def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tu
             outputs.append(obj)
         if not allow_extra_keys and unused_keys:
             raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
-        return tuple(outputs)
+        return tuple(
+            [_download_dataclass_zoo_stub_files(output) for output in outputs]
+        )
 
     def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
         """
@@ -305,3 +319,28 @@ def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tup
         """
         outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
         return tuple(outputs)
+
+def _download_dataclass_zoo_stub_files(data_class: DataClass):
+    for name, val in data_class.__dict__.items():
+        if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"):
+            continue
+
+        logger.info(f"Downloading framework files for SparseZoo stub: {val}")
+
+        zoo_model = Model(val)
+        framework_file_paths = [file.path for file in zoo_model.training.default.files]
+        assert framework_file_paths, "Unable to download any framework files for SparseZoo stub {val}"
+        framework_file_names = [os.path.basename(path) for path in framework_file_paths]
+        if "pytorch_model.bin" not in framework_file_names or ("config.json" not in framework_file_names):
+            raise RuntimeError(
+                "Unable to find 'pytorch_model.bin' and 'config.json' in framework "
+                f"files downloaded from {val}. Found {framework_file_names}. Check "
+                "if the given stub is for a transformers repo model"
+            )
+        framework_dir_path = Path(framework_file_paths[0]).parent.absolute()
+
+        logger.info(f"Overwriting argument {name} to downloaded {framework_dir_path}")
+
+        data_class.__dict__[name] = str(framework_dir_path)
+
+    return data_class
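The net effect of this helper: any string field on a parsed dataclass whose value starts with "zoo:" (except fields whose name contains "recipe") is resolved to a local directory of downloaded framework files before the dataclass is returned. A minimal sketch of the intended flow follows; the ModelArguments dataclass and the stub string are illustrative placeholders, not part of this commit.

# Illustrative only: ModelArguments and the "zoo:" stub below are placeholders.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default=None)
    recipe: Optional[str] = field(default=None)  # "recipe" fields are intentionally skipped


parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--model_name_or_path", "zoo:example/placeholder/stub"]
)
# After parsing, model_name_or_path points at the local directory that holds the
# downloaded pytorch_model.bin and config.json for the stub.
print(model_args.model_name_or_path)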

src/transformers/models/distilbert/modeling_distilbert.py

Lines changed: 39 additions & 3 deletions
@@ -89,6 +89,38 @@ def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
     out.detach_()
 
 
+class QATAttentionScores(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+class QATContextLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "num_outputs": 0,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
 class Embeddings(nn.Module):
     def __init__(self, config: PretrainedConfig):
         super().__init__()
@@ -150,6 +182,11 @@ def __init__(self, config: PretrainedConfig):
 
         self.pruned_heads: Set[int] = set()
 
+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATAttentionScores()
+        self.context_layer_matmul = QATContextLayer()
+
     def prune_heads(self, heads: List[int]):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
@@ -207,7 +244,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
 
         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
-        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
+        scores = self.attention_scores_matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
         scores = scores.masked_fill(
             mask, torch.tensor(torch.finfo(scores.dtype).min)
@@ -220,7 +257,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
         if head_mask is not None:
             weights = weights * head_mask
 
-        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
+        context = self.context_layer_matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
         context = unshape(context)  # (bs, q_length, dim)
         context = self.out_lin(context)  # (bs, q_length, dim)
 
@@ -645,7 +682,6 @@ def forward(
             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
        dlbrt_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
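Until a SparseML QuantizationModifier wraps them (via the wrap_qat / qat_wrapper_kwargs attributes), the two new modules are plain torch.matmul ops, so the attention math is unchanged. A quick illustrative check, assuming the patched modeling_distilbert module is importable and using arbitrary example shapes:

# Illustrative sanity check; tensor shapes are arbitrary example values.
import torch

from transformers.models.distilbert.modeling_distilbert import QATAttentionScores

attention_scores_matmul = QATAttentionScores()
q = torch.randn(2, 12, 16, 64)  # (bs, n_heads, q_length, dim_per_head)
k = torch.randn(2, 12, 16, 64)  # (bs, n_heads, k_length, dim_per_head)

scores = attention_scores_matmul(q, k.transpose(2, 3))
assert torch.allclose(scores, torch.matmul(q, k.transpose(2, 3)))

# Metadata read by the SparseML QATWrapper once quantization is applied
assert attention_scores_matmul.wrap_qat
assert attention_scores_matmul.qat_wrapper_kwargs["num_inputs"] == 2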

src/transformers/models/mobilebert/modeling_mobilebert.py

Lines changed: 18 additions & 1 deletion
@@ -170,6 +170,23 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
 
 NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}
 
+class QATEmbeddingTransformation(nn.Module):
+    def __init__(self, embedded_input_size, hidden_size):
+        super().__init__()
+
+        # Behaves like normal Linear module unless a SparseML QuantizationModifier
+        # is initialized.
+        # When initialized, does not quantize inputs.
+        # Only weights are quantized (inputs come quantized from embeddings)
+        self.linear = nn.Linear(embedded_input_size, hidden_size)
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 0,
+            "num_outputs": 1,
+        }
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
 
 class MobileBertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""
@@ -186,7 +203,7 @@ def __init__(self, config):
 
         embed_dim_multiplier = 3 if self.trigram_input else 1
         embedded_input_size = self.embedding_size * embed_dim_multiplier
-        self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
+        self.embedding_transformation = QATEmbeddingTransformation(embedded_input_size, config.hidden_size)
 
         self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
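As with the DistilBERT wrappers, QATEmbeddingTransformation is functionally just the wrapped nn.Linear until quantization is applied; "num_inputs": 0 tells the wrapper to skip input fake-quantization because the embeddings already emit quantized values. An illustrative check with example dimensions (not taken from this commit):

# Illustrative only; 384 = 128 * 3 (trigram input) and 512 are example sizes.
import torch

from transformers.models.mobilebert.modeling_mobilebert import QATEmbeddingTransformation

embedding_transformation = QATEmbeddingTransformation(embedded_input_size=384, hidden_size=512)
hidden = embedding_transformation(torch.randn(4, 128, 384))

assert hidden.shape == (4, 128, 512)
assert embedding_transformation.qat_wrapper_kwargs == {"num_inputs": 0, "num_outputs": 1}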

src/transformers/trainer.py

Lines changed: 29 additions & 9 deletions
@@ -1687,6 +1687,10 @@ def _inner_training_loop(
                     _ = list(train_dataloader.sampler)
 
         for epoch in range(epochs_trained, num_train_epochs):
+            if self.use_cuda_amp and hasattr(self, "qat_active") and callable(self.qat_active) and self.qat_active(epoch):
+                logger.info("entering QAT phase, disabling FP16 training")
+                self.scaler._enabled = False
+
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
             elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
@@ -2167,7 +2171,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
            torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
 
        # Determine the new best metric / best model checkpoint
-        if metrics is not None and self.args.metric_for_best_model is not None:
+        if (
+            metrics is not None
+            and self.args.metric_for_best_model is not None
+            and self.args.best_model_after_epoch is not None
+            and self.state.epoch > self.args.best_model_after_epoch
+        ):
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
@@ -2421,14 +2430,14 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s
 
         return inputs
 
-    def compute_loss_context_manager(self):
+    def compute_loss_context_manager(self, enabled):
         """
         A helper wrapper to group together context managers.
         """
         return ContextManagers(
             [
                 self.torchdynamo_smart_context_manager(),
-                self.autocast_smart_context_manager(),
+                self.autocast_smart_context_manager(enabled=enabled),
             ]
         )
 
@@ -2438,7 +2447,7 @@ def torchdynamo_smart_context_manager(self):
         """
         return self.ctx_manager_torchdynamo
 
-    def autocast_smart_context_manager(self):
+    def autocast_smart_context_manager(self, enabled):
         """
         A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
         arguments, depending on the situation.
@@ -2448,10 +2457,10 @@ def autocast_smart_context_manager(self):
                 ctx_manager = (
                     torch.cpu.amp.autocast(dtype=self.amp_dtype)
                     if self.use_cpu_amp
-                    else torch.cuda.amp.autocast(dtype=self.amp_dtype)
+                    else torch.cuda.amp.autocast(dtype=self.amp_dtype, enabled=enabled)
                 )
             else:
-                ctx_manager = torch.cuda.amp.autocast()
+                ctx_manager = torch.cuda.amp.autocast(enabled=enabled)
         else:
             ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
 
@@ -2482,7 +2491,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)
 
-        with self.compute_loss_context_manager():
+        with self.compute_loss_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
            loss = self.compute_loss(model, inputs)
 
        if self.args.n_gpu > 1:
@@ -2939,7 +2948,14 @@ def evaluation_loop(
 
         observed_num_examples = 0
         # Main evaluation loop
+        module_forward_fn = model.module.forward if isinstance(model, nn.DataParallel) else model.forward
         for step, inputs in enumerate(dataloader):
+            inputs = {
+                k: inputs[k]
+                for k in inputs
+                if k in list(inspect.signature(module_forward_fn).parameters.keys())
+            }
+
             # Update the observed num examples
             observed_batch_size = find_batch_size(inputs)
             if observed_batch_size is not None:
@@ -3191,7 +3207,9 @@ def prediction_step(
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels:
-                    with self.compute_loss_context_manager():
+                    with self.compute_loss_context_manager(
+                        enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
+                    ):
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()
 
@@ -3201,7 +3219,9 @@ def prediction_step(
                        logits = outputs[1:]
                else:
                    loss = None
-                    with self.compute_loss_context_manager():
+                    with self.compute_loss_context_manager(
+                        enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
+                    ):
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
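The new epoch-level check relies on a qat_active(epoch) method that this diff does not define; it is expected to come from the SparseML integration that subclasses Trainer. A hypothetical sketch of such a hook is below; the class name and the recipe-derived start epoch are assumptions, not part of this commit.

# Hypothetical sketch only: the real hook lives in the SparseML integration and
# derives the QAT start point from a sparsification recipe.
from typing import Optional

from transformers import Trainer


class QATAwareTrainer(Trainer):
    def __init__(self, *args, qat_start_epoch: Optional[int] = None, **kwargs):
        super().__init__(*args, **kwargs)
        # Assumed value; in practice this would come from a recipe's QuantizationModifier.
        self._qat_start_epoch = qat_start_epoch

    def qat_active(self, epoch: int) -> bool:
        # Called once per epoch by the patched _inner_training_loop; when it returns
        # True under CUDA AMP, the GradScaler is disabled for the rest of training.
        return self._qat_start_epoch is not None and epoch >= self._qat_start_epoch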

src/transformers/trainer_seq2seq.py

Lines changed: 1 addition & 1 deletion
@@ -208,7 +208,7 @@ def prediction_step(
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
 
         with torch.no_grad():
-            with self.compute_loss_context_manager():
+            with self.compute_loss_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
                 outputs = model(**inputs)
             if has_labels:
                 if self.label_smoother is not None:

src/transformers/utils/import_utils.py

Lines changed: 3 additions & 0 deletions
@@ -1016,6 +1016,9 @@ class _LazyModule(ModuleType):
     Module class that surfaces all objects but only performs associated imports when the objects are requested.
     """
 
+    # flag to signal NM integration is active
+    NM_INTEGRATED = True
+
     # Very heavily inspired by optuna.integration._IntegrationModule
     # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
     def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
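Because the top-level transformers module is replaced by a _LazyModule instance, the new class attribute is visible on the installed package, which gives downstream tooling a cheap way to detect this fork. An illustrative check (not part of this commit):

# Illustrative detection snippet; stock transformers has no such attribute,
# so getattr falls back to False there.
import transformers

if getattr(transformers, "NM_INTEGRATED", False):
    print("Neural Magic integrated transformers detected")
else:
    print("stock transformers installation")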

0 commit comments
