Does DeepSpeed's Pipeline-Parallelism optimizer support skip connections? #932

@RoyMahlab

Description

In your example you convert AlexNet into a list of layers:

def join_layers(vision_model):
    layers = [
        *vision_model.features,
        vision_model.avgpool,
        lambda x: torch.flatten(x, 1),
        *vision_model.classifier,
    ]
    return layers

which is later passed to PipelineModule:

net = AlexNet(num_classes=10)
net = PipelineModule(layers=join_layers(net),
                     loss_fn=torch.nn.CrossEntropyLoss(),
                     num_stages=args.pipeline_parallel_size,
                     partition_method=part,
                     activation_checkpoint_interval=0)

This seems to override the forward method you defined in your AlexNet module, which makes me wonder whether it is possible to have skip connections in my module while using DeepSpeed's Pipeline-Parallelism optimizer.
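For context, a pattern that is sometimes used as a workaround (this is a sketch of the general idea, not an official DeepSpeed API): since each entry in the layer list only sees the output of the previous entry, a skip connection can be expressed by threading the skip tensor through the intermediate stages as part of a tuple, with small wrapper modules (the names `SaveSkip`, `Middle`, and `AddSkip` below are made up for illustration) packing and unpacking it.

```python
import torch
import torch.nn as nn

class SaveSkip(nn.Module):
    """Duplicate the activation so a later layer can add it back."""
    def forward(self, x):
        # pass both the activation and its skip copy down the pipeline
        return x, x

class Middle(nn.Module):
    """An ordinary layer that carries the skip tensor along untouched."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, inputs):
        x, skip = inputs
        return self.linear(x), skip  # keep the skip in the tuple

class AddSkip(nn.Module):
    """Close the skip connection by summing."""
    def forward(self, inputs):
        x, skip = inputs
        return x + skip

# This flat list has the same shape as the join_layers() output above,
# so it could in principle be handed to PipelineModule(layers=...).
layers = [SaveSkip(), Middle(8), AddSkip()]
```

The cost of this approach is that the skip tensor is communicated between every pair of stages it crosses, so long-range skips add inter-stage traffic.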

Many thanks!
