Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Dylan Tack
Eduardo Oliveira
Egor Poderiagin
Emanuele Palazzetti
Fazeel Ghafoor
Federico Dolce
Florian Demmer
Frederico Vieira
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]
### Added
* Add migration to include `token_checksum` field in AbstractAccessToken model.
* #1404 Add a new setting `REFRESH_TOKEN_REUSE_PROTECTION`
### Changed
* Update token to TextField from CharField with 255 character limit and SHA-256 checksum in AbstractAccessToken model. Removing the 255 character limit enables supporting JWT tokens with additional claims

* Update middleware, validators, and views to use token checksums instead of token for token retrieval and validation.
### Deprecated
### Removed
* #1425 Remove deprecated `RedirectURIValidator`, `WildcardSet` per #1345; `validate_logout_request` per #1274
Expand Down
4 changes: 3 additions & 1 deletion oauth2_provider/middleware.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import logging

from django.contrib.auth import authenticate
Expand Down Expand Up @@ -55,7 +56,8 @@ def __call__(self, request):
tokenstring = authheader.split()[1]
AccessToken = get_access_token_model()
try:
token = AccessToken.objects.get(token=tokenstring)
token_checksum = hashlib.sha256(tokenstring.encode("utf-8")).hexdigest()
token = AccessToken.objects.get(token_checksum=token_checksum)
request.access_token = token
except AccessToken.DoesNotExist as e:
log.exception(e)
Expand Down
26 changes: 26 additions & 0 deletions oauth2_provider/migrations/0012_add_token_checksum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Generated by Django 5.0.7 on 2024-07-29 23:13

import oauth2_provider.models
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("oauth2_provider", "0011_refreshtoken_token_family"),
migrations.swappable_dependency(settings.OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL),
]

operations = [
migrations.AddField(
model_name="accesstoken",
name="token_checksum",
field=oauth2_provider.models.TokenChecksumField(
blank=True, db_index=True, max_length=64, unique=True
),
),
migrations.AlterField(
model_name="accesstoken",
name="token",
field=models.TextField(),
),
]
15 changes: 13 additions & 2 deletions oauth2_provider/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import logging
import time
import uuid
Expand Down Expand Up @@ -44,6 +45,14 @@ def pre_save(self, model_instance, add):
return super().pre_save(model_instance, add)


class TokenChecksumField(models.CharField):
def pre_save(self, model_instance, add):
token = getattr(model_instance, "token")
checksum = hashlib.sha256(token.encode("utf-8")).hexdigest()
setattr(model_instance, self.attname, checksum)
return super().pre_save(model_instance, add)


class AbstractApplication(models.Model):
"""
An Application instance represents a Client on the Authorization server.
Expand Down Expand Up @@ -379,8 +388,10 @@ class AbstractAccessToken(models.Model):
null=True,
related_name="refreshed_access_token",
)
token = models.CharField(
max_length=255,
token = models.TextField()
token_checksum = TokenChecksumField(
max_length=64,
blank=True,
unique=True,
db_index=True,
)
Expand Down
8 changes: 7 additions & 1 deletion oauth2_provider/oauth2_validators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import binascii
import hashlib
import http.client
import inspect
import json
Expand Down Expand Up @@ -461,7 +462,12 @@ def validate_bearer_token(self, token, scopes, request):
return False

def _load_access_token(self, token):
return AccessToken.objects.select_related("application", "user").filter(token=token).first()
token_checksum = hashlib.sha256(token.encode("utf-8")).hexdigest()
return (
AccessToken.objects.select_related("application", "user")
.filter(token_checksum=token_checksum)
.first()
)

def validate_code(self, client_id, code, client, request, *args, **kwargs):
try:
Expand Down
4 changes: 3 additions & 1 deletion oauth2_provider/views/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import json
import logging
from urllib.parse import parse_qsl, urlencode, urlparse
Expand Down Expand Up @@ -289,7 +290,8 @@ def post(self, request, *args, **kwargs):
if status == 200:
access_token = json.loads(body).get("access_token")
if access_token is not None:
token = get_access_token_model().objects.get(token=access_token)
token_checksum = hashlib.sha256(access_token.encode("utf-8")).hexdigest()
token = get_access_token_model().objects.get(token_checksum=token_checksum)
app_authorized.send(sender=self, request=request, token=token)
response = HttpResponse(content=body, status=status)

Expand Down
6 changes: 5 additions & 1 deletion oauth2_provider/views/introspect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import calendar
import hashlib

from django.core.exceptions import ObjectDoesNotExist
from django.http import JsonResponse
Expand All @@ -24,8 +25,11 @@ class IntrospectTokenView(ClientProtectedScopedResourceView):
@staticmethod
def get_token_response(token_value=None):
try:
token_checksum = hashlib.sha256(token_value.encode("utf-8")).hexdigest()
token = (
get_access_token_model().objects.select_related("user", "application").get(token=token_value)
get_access_token_model()
.objects.select_related("user", "application")
.get(token_checksum=token_checksum)
)
except ObjectDoesNotExist:
return JsonResponse({"active": False}, status=200)
Expand Down
12 changes: 8 additions & 4 deletions tests/migrations/0002_swapped_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,14 @@ class Migration(migrations.Migration):
field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='s_refreshed_access_token', to=settings.OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL),
),
migrations.AddField(
model_name='sampleaccesstoken',
name='token',
field=models.CharField(max_length=255, unique=True),
preserve_default=False,
model_name="sampleaccesstoken",
name="token",
field=models.TextField(),
),
migrations.AddField(
model_name="sampleaccesstoken",
name="token_checksum",
field=models.CharField(max_length=64, unique=True, db_index=True),
),
migrations.AddField(
model_name='sampleaccesstoken',
Expand Down
13 changes: 13 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import secrets
from datetime import timedelta

import pytest
Expand Down Expand Up @@ -310,6 +312,17 @@ def test_expires_can_be_none(self):
self.assertIsNone(access_token.expires)
self.assertTrue(access_token.is_expired())

def test_token_checksum_field(self):
token = secrets.token_urlsafe(32)
access_token = AccessToken.objects.create(
user=self.user,
token=token,
expires=timezone.now() + timedelta(hours=1),
)
expected_checksum = hashlib.sha256(token.encode()).hexdigest()

self.assertEqual(access_token.token_checksum, expected_checksum)


class TestRefreshTokenModel(BaseTestModels):
def test_str(self):
Expand Down
Loading