Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions Microsoft.NET.Build.Containers/AuthHandshakeMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private record TokenResponse(string? token, string? access_token, int? expires_i
/// <param name="scope"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
private async Task<AuthenticationHeaderValue?> GetAuthenticationAsync(string scheme, Uri uri, string service, string? scope, CancellationToken cancellationToken)
private async Task<AuthenticationHeaderValue?> GetAuthenticationAsync(string registry, string scheme, Uri realm, string service, string? scope, CancellationToken cancellationToken)
{
// Allow overrides for auth via environment variables
string? credU = Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectUser);
Expand All @@ -102,26 +102,26 @@ private record TokenResponse(string? token, string? access_token, int? expires_i
{
try
{
privateRepoCreds = await CredsProvider.GetCredentialsAsync(uri.Host);
privateRepoCreds = await CredsProvider.GetCredentialsAsync(registry);
}
catch (Exception e)
{
throw new CredentialRetrievalException(uri.Host, e);
throw new CredentialRetrievalException(registry, e);
}
}

if (scheme is "Basic")
{
var basicAuth = new AuthenticationHeaderValue("Basic", Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}")));
return AuthHeaderCache.AddOrUpdate(uri, basicAuth);
return AuthHeaderCache.AddOrUpdate(realm, basicAuth);
}
else if (scheme is "Bearer")
{
// use those creds when calling the token provider
var header = privateRepoCreds.Username == "<token>"
? new AuthenticationHeaderValue("Bearer", privateRepoCreds.Password)
: new AuthenticationHeaderValue("Basic", Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}")));
var builder = new UriBuilder(uri);
var builder = new UriBuilder(realm);
var queryDict = System.Web.HttpUtility.ParseQueryString("");
queryDict["service"] = service;
if (scope is string s)
Expand All @@ -143,7 +143,7 @@ private record TokenResponse(string? token, string? access_token, int? expires_i

// save the retrieved token in the cache
var bearerAuth = new AuthenticationHeaderValue("Bearer", token.ResolvedToken);
return AuthHeaderCache.AddOrUpdate(uri, bearerAuth);
return AuthHeaderCache.AddOrUpdate(realm, bearerAuth);
}
else
{
Expand Down Expand Up @@ -177,7 +177,7 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
}
else if (response is { StatusCode: HttpStatusCode.Unauthorized } && TryParseAuthenticationInfo(response, out string? scheme, out AuthInfo? authInfo))
{
if (await GetAuthenticationAsync(scheme, authInfo.Realm, authInfo.Service, authInfo.Scope, cancellationToken) is AuthenticationHeaderValue authentication)
if (await GetAuthenticationAsync(request.RequestUri.Host, scheme, authInfo.Realm, authInfo.Service, authInfo.Scope, cancellationToken) is AuthenticationHeaderValue authentication)
{
request.Headers.Authorization = AuthHeaderCache.AddOrUpdate(request.RequestUri, authentication);
return await base.SendAsync(request, cancellationToken);
Expand Down