Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions opentreemap/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import urllib.parse
import urllib.error

from django.http import HttpResponse
from django.http import HttpResponse, RawPostDataException
from django.contrib.auth import authenticate


Expand Down Expand Up @@ -37,19 +37,32 @@ def get_signature_for_request(request, secret_key):

sign_string = '\n'.join([httpverb, hostheader, request_uri, paramstr])

# Sometimes reeading from body fails, so try reading as a file-like
# Sometimes reading from body fails, so try reading as a file-like object
try:
body_encoded = base64.b64encode(request.body)
except:
body_encoded = base64.b64encode(request.read())
body_decoded = base64.b64encode(request.body).decode()
except RawPostDataException:
body_decoded = base64.b64encode(request.read()).decode()

if body_encoded:
sign_string += body_encoded
if body_decoded:
sign_string += body_decoded

try:
binary_secret_key = secret_key.encode()
except (AttributeError, UnicodeEncodeError):
binary_secret_key = secret_key

sig = base64.b64encode(
hmac.new(secret_key, sign_string, hashlib.sha256).digest())
hmac.new(
binary_secret_key,
sign_string.encode(),
hashlib.sha256
).digest()
)

if sig is None:
return sig

return sig
return sig.decode()


def create_401unauthorized(body="Unauthorized"):
Expand All @@ -67,19 +80,38 @@ def firstmatch(regx, strg):
return m.group(1)


def decodebasicauth(strg):
if strg is None:
def split_basicauth(strg):
"""
Returns username, password from decoded,
stringified, basic auth credentials
"""
if strg is None or len(strg) == 0:
return None
else:
m = re.match(r'([^:]*)\:(.*)', base64.decodestring(strg))
m = re.match(r'([^:]*)\:(.*)', strg)
if m is not None:
return (m.group(1), m.group(2))
else:
return None


def parse_basicauth(authstr):
auth = decodebasicauth(firstmatch('Basic (.*)', authstr))
string_wrapped_binary_credentials = firstmatch("Basic (.*)", authstr)
if string_wrapped_binary_credentials is None:
return None

# tease bytes-like object out of string, i.e. "b'credentials'"
reg_exp = r"'(.*?)\'"
parsed_credentials = re.search(r"'(.*?)\'", string_wrapped_binary_credentials)
str_credentials = parsed_credentials.groups()[0]

# decode the binary encoded credentials
decoded_str_credentials = base64.decodebytes(
bytes(str_credentials, 'utf-8')
).decode()

auth = split_basicauth(decoded_str_credentials)

if auth is None:
return None
else:
Expand Down
2 changes: 1 addition & 1 deletion opentreemap/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __unicode__(self):
def create(clz, user=None):
secret_key = base64.urlsafe_b64encode(os.urandom(64))
access_key = base64.urlsafe_b64encode(uuid.uuid4().bytes)\
.replace('=', '')
.replace(b'=', b'')

return APIAccessCredential.objects.create(
user=user, access_key=access_key, secret_key=secret_key)
15 changes: 12 additions & 3 deletions opentreemap/api/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.shortcuts import get_object_or_404
from django.contrib.gis.geos import Point
from django.contrib.gis.measure import D
from django.http import RawPostDataException

from django_tinsel.exceptions import HttpBadRequestException

Expand Down Expand Up @@ -50,8 +51,7 @@ def plots_closest_to_point(request, instance, lat, lng):
distance = float(request.GET.get(
'distance', settings.MAP_CLICK_RADIUS))
except ValueError:
raise HttpBadRequestException(
'The distance parameter must be a number')
raise HttpBadRequestException('The distance parameter must be a number')

plots = Plot.objects.distance(point)\
.filter(instance=instance)\
Expand All @@ -72,7 +72,16 @@ def update_or_create_plot(request, instance, plot_id=None):
# The API communicates via nested dictionaries but
# our internal functions prefer dotted pairs (which
# is what inline edit form users)
request_dict = json.loads(request.body)
# Sometimes reading from body fails, so try reading as a file-like object
try:
body_decoded = request.body.decode()
except RawPostDataException:
body_decoded = request.read().decode()

if body_decoded:
request_dict = json.loads(body_decoded)
else:
request_dict = {}

data = {}

Expand Down
62 changes: 31 additions & 31 deletions opentreemap/api/tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-


from io import StringIO
from io import BytesIO
from json import loads, dumps
from urllib.parse import urlparse

Expand All @@ -14,6 +14,7 @@
import datetime
import psycopg2
from unittest.case import skip
from django_tinsel.exceptions import HttpBadRequestException

from django.db import connection
from django.contrib.auth.models import AnonymousUser
Expand Down Expand Up @@ -96,11 +97,11 @@ def send_json_body(url, body_object, client, method, user=None):
are posting form data, so you need to manually setup the parameters
to override that default functionality.
"""
body_string = dumps(body_object)
body_stream = StringIO(body_string)
body_binary_string = dumps(body_object).encode()
body_stream = BytesIO(body_binary_string)
parsed_url = urlparse(url)
client_params = {
'CONTENT_LENGTH': len(body_string),
'CONTENT_LENGTH': len(body_binary_string),
'CONTENT_TYPE': 'application/json',
'PATH_INFO': _get_path(parsed_url),
'QUERY_STRING': parsed_url[4],
Expand Down Expand Up @@ -375,8 +376,8 @@ def test_locations_plots_endpoint_max_plots_param_must_be_a_number(self):
API_PFX, self.instance.url_name))
self.assertEqual(response.status_code, 400)
self.assertEqual(response.content,
'The max_plots parameter must be '
'a number between 1 and 500')
b'The max_plots parameter must be '
b'a number between 1 and 500')

def test_locations_plots_max_plots_param_cannot_be_greater_than_500(self):
response = get_signed(
Expand All @@ -385,8 +386,8 @@ def test_locations_plots_max_plots_param_cannot_be_greater_than_500(self):
API_PFX, self.instance.url_name))
self.assertEqual(response.status_code, 400)
self.assertEqual(response.content,
'The max_plots parameter must be '
'a number between 1 and 500')
b'The max_plots parameter must be '
b'a number between 1 and 500')
response = get_signed(
self.client,
"%s/instance/%s/locations/0,0/plots?max_plots=500" %
Expand All @@ -402,8 +403,8 @@ def test_locations_plots_endpoint_max_plots_param_cannot_be_less_than_1(

self.assertEqual(response.status_code, 400)
self.assertEqual(response.content,
'The max_plots parameter must be a '
'number between 1 and 500')
b'The max_plots parameter must be a '
b'number between 1 and 500')
response = get_signed(
self.client,
"%s/instance/%s/locations/0,0/plots?max_plots=1" %
Expand All @@ -419,7 +420,7 @@ def test_locations_plots_endpoint_distance_param_must_be_a_number(self):

self.assertEqual(response.status_code, 400)
self.assertEqual(response.content,
'The distance parameter must be a number')
b'The distance parameter must be a number')

response = get_signed(
self.client,
Expand Down Expand Up @@ -472,7 +473,7 @@ def test_create_plot_with_tree(self):
data, self.client, self.user)

self.assertEqual(200, response.status_code,
"Create failed:" + response.content)
"Create failed:" + response.content.decode())

# Assert that a plot was added
self.assertEqual(plot_count + 1, Plot.objects.count())
Expand Down Expand Up @@ -509,7 +510,7 @@ def test_create_plot_with_invalid_tree_returns_400(self):
self.assertEqual(400,
response.status_code,
"Expected creating a million foot "
"tall tree to return 400:" + response.content)
"tall tree to return 400:" + response.content.decode())

body_dict = loads(response.content)
self.assertTrue('fieldErrors' in body_dict,
Expand Down Expand Up @@ -548,7 +549,7 @@ def test_create_plot_with_geometry(self):
data, self.client, self.user)

self.assertEqual(200, response.status_code,
"Create failed:" + response.content)
"Create failed:" + response.content.decode())

# Assert that a plot was added
self.assertEqual(plot_count + 1, Plot.objects.count())
Expand Down Expand Up @@ -1334,8 +1335,8 @@ def setUp(self):
'sort_key': 'Date'}
]
self.instance.save()
self.instance.logo.save(Instance.test_png_name,
File(open(Instance.test_png_path, 'r')))
with open(Instance.test_png_path, 'rb') as f:
self.instance.logo.save(Instance.test_png_name, f)

def test_returns_config_colors(self):
request = sign_request_as_user(make_request(), self.user)
Expand Down Expand Up @@ -1568,7 +1569,7 @@ def _test_post_photo(self, path):
self.instance.url_name,
plot_id)

with open(path) as img:
with open(path, 'rb') as img:
req = self.factory.post(
url, {'name': 'afile', 'file': img})

Expand Down Expand Up @@ -1621,7 +1622,7 @@ def testUploadPhoto(self):
url = reverse('update_user_photo', kwargs={'version': 3,
'user_id': peon.pk})

with open(TreePhotoTest.test_jpeg_path) as img:
with open(TreePhotoTest.test_jpeg_path, 'rb') as img:
req = self.factory.post(
url, {'name': 'afile', 'file': img})

Expand All @@ -1648,7 +1649,7 @@ def testCanOnlyUploadAsSelf(self):
grunt = make_user(username='grunt', password='pw')
grunt.save()

with open(TreePhotoTest.test_jpeg_path) as img:
with open(TreePhotoTest.test_jpeg_path, 'rb') as img:
req = self.factory.post(
url, {'name': 'afile', 'file': img})

Expand Down Expand Up @@ -1883,7 +1884,7 @@ def testTimestampVoidsSignature(self):
acred = APIAccessCredential.create()
url = ('http://testserver.com/test/blah?'
'timestamp=%%s&'
'k1=4&k2=a&access_key=%s' % acred.access_key)
'k1=4&k2=a&access_key=%s' % acred.access_key.decode())

curtime = datetime.datetime.now()
invalid = curtime - datetime.timedelta(minutes=100)
Expand All @@ -1902,7 +1903,7 @@ def testPOSTBodyChangesSig(self):
url = "%s/i/plots/1/tree/photo" % API_PFX

def get_sig(path):
with open(path) as img:
with open(path, 'rb') as img:
req = self.factory.post(
url, {'name': 'afile', 'file': img})

Expand Down Expand Up @@ -1949,14 +1950,14 @@ def testMalformedTimestamp(self):

url = ('http://testserver.com/test/blah?'
'timestamp=%%s&'
'k1=4&k2=a&access_key=%s' % acred.access_key)
'k1=4&k2=a&access_key=%s' % acred.access_key.decode())

req = self.sign_and_send(url % ('%sFAIL' % timestamp),
acred.secret_key)

self.assertEqual(req.status_code, 400)

req = self.sign_and_send(url % timestamp, acred.secret_key)
req = self.sign_and_send(url % timestamp, acred.secret_key.decode())

self.assertRequestWasSuccess(req)

Expand All @@ -1972,7 +1973,7 @@ def testMissingAccessKey(self):

self.assertEqual(req.status_code, 400)

req = self.sign_and_send('%s&access_key=%s' % (url, acred.access_key),
req = self.sign_and_send('%s&access_key=%s' % (url, acred.access_key.decode()),
acred.secret_key)

self.assertRequestWasSuccess(req)
Expand All @@ -1987,9 +1988,8 @@ def testAuthenticatesAsUser(self):
req = self.sign_and_send('http://testserver.com/test/blah?'
'timestamp=%s&'
'k1=4&k2=a&access_key=%s' %
(timestamp, acred.access_key),
acred.secret_key)

(timestamp, acred.access_key.decode()),
acred.secret_key.decode())
self.assertEqual(req.user.pk, peon.pk)


Expand All @@ -2003,7 +2003,7 @@ def test_401(self):
self.assertEqual(ret.status_code, 401)

def test_ok(self):
auth = base64.b64encode("jim:password")
auth = base64.b64encode(b"jim:password")
withauth = {"HTTP_AUTHORIZATION": "Basic %s" % auth}

ret = get_signed(self.client, "%s/user" % API_PFX, **withauth)
Expand All @@ -2015,14 +2015,14 @@ def test_malformed_auth(self):
ret = get_signed(self.client, "%s/user" % API_PFX, **withauth)
self.assertEqual(ret.status_code, 401)

auth = base64.b64encode("foobar")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider updating the commit message to explain why we are removing this base 64 encoding. It is my understanding that basic auth requires base 64.

In basic HTTP authentication, a request contains a header field in the form of Authorization: Basic <credentials>, where credentials is the base64 encoding of id and password joined by a single colon :.

from https://en.wikipedia.org/wiki/Basic_access_authentication

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for this detail. I took another dive in with this better baseline understanding, and get what is going on 100% now. In py3, there are ripple effect from having to have the initial credentials start as bytes. This complication appears on these lines: https://github.com/OpenTreeMap/otm-core/pull/3288/files#diff-c7aac248bb6ed915db70251295a116a3R96-R103. I pushed up changes to this effect.

auth = base64.b64encode(b"foobar")
withauth = {"HTTP_AUTHORIZATION": "Basic %s" % auth}

ret = get_signed(self.client, "%s/user" % API_PFX, **withauth)
self.assertEqual(ret.status_code, 401)

def test_bad_cred(self):
auth = base64.b64encode("jim:passwordz")
auth = base64.b64encode(b"jim:passwordz")
withauth = {"HTTP_AUTHORIZATION": "Basic %s" % auth}

ret = get_signed(self.client, "%s/user" % API_PFX, **withauth)
Expand All @@ -2034,7 +2034,7 @@ def test_user_has_rep(self):
ijim.reputation = 1001
ijim.save()

auth = base64.b64encode("jim:password")
auth = base64.b64encode(b"jim:password")
withauth = dict(list(self.sign.items()) +
[("HTTP_AUTHORIZATION", "Basic %s" % auth)])

Expand Down
2 changes: 1 addition & 1 deletion opentreemap/api/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def wrapper(request, *args, **kwargs):
# You can't directly set a new request body
# (http://stackoverflow.com/a/22745559)
request._body = body
request._stream = BytesIO(body)
request._stream = BytesIO(body.encode())

return user_view_fn(request, *args, **kwargs)

Expand Down
Loading