@@ -20,6 +20,7 @@ import (
2020 "context"
2121 "fmt"
2222 "net/url"
23+ "os/exec"
2324 "strconv"
2425 "strings"
2526 "time"
@@ -718,7 +719,7 @@ func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, a
718719 vs := req .VolumeContentSource
719720 switch vs .Type .(type ) {
720721 case * csi.VolumeContentSource_Snapshot :
721- return status . Errorf ( codes . InvalidArgument , "copy volume from volumeSnapshot is not supported" )
722+ return d . restoreSnapshot ( ctx , req , accountKey , shareOptions , storageEndpointSuffix )
722723 case * csi.VolumeContentSource_Volume :
723724 return d .copyFileShare (ctx , req , accountKey , shareOptions , storageEndpointSuffix )
724725 default :
@@ -1072,6 +1073,65 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques
10721073 return nil , status .Error (codes .Unimplemented , "" )
10731074}
10741075
1076+ // restoreSnapshot restores from a snapshot
1077+ func (d * Driver ) restoreSnapshot (ctx context.Context , req * csi.CreateVolumeRequest , accountKey string , shareOptions * fileclient.ShareOptions , storageEndpointSuffix string ) error {
1078+ if shareOptions .Protocol == storage .EnabledProtocolsNFS {
1079+ return fmt .Errorf ("protocol nfs is not supported for snapshot restore" )
1080+ }
1081+ var sourceSnapshotID string
1082+ if req .GetVolumeContentSource () != nil && req .GetVolumeContentSource ().GetSnapshot () != nil {
1083+ sourceSnapshotID = req .GetVolumeContentSource ().GetSnapshot ().GetSnapshotId ()
1084+ }
1085+ resourceGroupName , accountName , srcFileShareName , _ , _ , _ , err := GetFileShareInfo (sourceSnapshotID ) //nolint:dogsled
1086+ if err != nil {
1087+ return status .Error (codes .NotFound , err .Error ())
1088+ }
1089+ dstFileShareName := shareOptions .Name
1090+ if srcFileShareName == "" || dstFileShareName == "" {
1091+ return fmt .Errorf ("srcFileShareName(%s) or dstFileShareName(%s) is empty" , srcFileShareName , dstFileShareName )
1092+ }
1093+
1094+ klog .V (2 ).Infof ("generate sas token for account(%s)" , accountName )
1095+ accountSasToken , genErr := generateSASToken (accountName , accountKey , storageEndpointSuffix , d .sasTokenExpirationMinutes )
1096+ if genErr != nil {
1097+ return genErr
1098+ }
1099+
1100+ timeAfter := time .After (waitForCopyTimeout )
1101+ timeTick := time .Tick (waitForCopyInterval )
1102+ srcPath := fmt .Sprintf ("https://%s.file.%s/%s%s" , accountName , storageEndpointSuffix , srcFileShareName , accountSasToken )
1103+ dstPath := fmt .Sprintf ("https://%s.file.%s/%s%s" , accountName , storageEndpointSuffix , dstFileShareName , accountSasToken )
1104+
1105+ jobState , percent , err := getAzcopyJob (dstFileShareName )
1106+ klog .V (2 ).Infof ("azcopy job status: %s, copy percent: %s%%, error: %v" , jobState , percent , err )
1107+ if jobState == AzcopyJobError || jobState == AzcopyJobCompleted {
1108+ return err
1109+ }
1110+ klog .V (2 ).Infof ("begin to copy fileshare %s to %s" , srcFileShareName , dstFileShareName )
1111+ for {
1112+ select {
1113+ case <- timeTick :
1114+ jobState , percent , err := getAzcopyJob (dstFileShareName )
1115+ klog .V (2 ).Infof ("azcopy job status: %s, copy percent: %s%%, error: %v" , jobState , percent , err )
1116+ switch jobState {
1117+ case AzcopyJobError , AzcopyJobCompleted :
1118+ return err
1119+ case AzcopyJobNotFound :
1120+ klog .V (2 ).Infof ("copy fileshare %s to %s" , srcFileShareName , dstFileShareName )
1121+ out , copyErr := exec .Command ("azcopy" , "copy" , srcPath , dstPath , "--recursive" , "--check-length=false" ).CombinedOutput ()
1122+ if copyErr != nil {
1123+ klog .Warningf ("CopyFileShare(%s, %s, %s) failed with error(%v): %v" , resourceGroupName , accountName , dstFileShareName , copyErr , string (out ))
1124+ } else {
1125+ klog .V (2 ).Infof ("copied fileshare %s to %s successfully" , srcFileShareName , dstFileShareName )
1126+ }
1127+ return copyErr
1128+ }
1129+ case <- timeAfter :
1130+ return fmt .Errorf ("timeout waiting for copy fileshare %s to %s succeed" , srcFileShareName , dstFileShareName )
1131+ }
1132+ }
1133+ }
1134+
10751135// ControllerExpandVolume controller expand volume
10761136func (d * Driver ) ControllerExpandVolume (ctx context.Context , req * csi.ControllerExpandVolumeRequest ) (* csi.ControllerExpandVolumeResponse , error ) {
10771137 volumeID := req .GetVolumeId ()
0 commit comments