Skip to content

Commit

Permalink
fix: RedirectURIValidator Encapsulation (#1345)
Browse files Browse the repository at this point in the history
  • Loading branch information
dopry authored Oct 20, 2023
1 parent 584627d commit 4c13679
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 54 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* #1322 Instructions in documentation on how to create a code challenge and code verifier
* #1284 Allow to logout with no id_token_hint even if the browser session already expired
* #1296 Added reverse function in migration 0006_alter_application_client_secret
* #1336 Fix encapsulation for Redirect URI scheme validation

## [2.3.0] 2023-05-31

Expand Down
11 changes: 5 additions & 6 deletions oauth2_provider/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .scopes import get_scopes_backend
from .settings import oauth2_settings
from .utils import jwk_from_pem
from .validators import AllowedURIValidator, RedirectURIValidator, WildcardSet
from .validators import AllowedURIValidator


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -202,12 +202,11 @@ def clean(self):
allowed_schemes = set(s.lower() for s in self.get_allowed_schemes())

if redirect_uris:
validator = RedirectURIValidator(WildcardSet())
validator = AllowedURIValidator(
allowed_schemes, name="redirect uri", allow_path=True, allow_query=True
)
for uri in redirect_uris:
validator(uri)
scheme = urlparse(uri).scheme
if scheme not in allowed_schemes:
raise ValidationError(_("Unauthorized redirect scheme: {scheme}").format(scheme=scheme))

elif self.authorization_grant_type in grant_types:
raise ValidationError(
Expand All @@ -218,7 +217,7 @@ def clean(self):
allowed_origins = self.allowed_origins.strip().split()
if allowed_origins:
# oauthlib allows only https scheme for CORS
validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "Origin")
validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "allowed origin")
for uri in allowed_origins:
validator(uri)

Expand Down
1 change: 0 additions & 1 deletion oauth2_provider/oauth2_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ def authenticate_client_id(self, client_id, request, *args, **kwargs):
proceed only if the client exists and is not of type "Confidential".
"""
if self._load_application(client_id, request) is not None:
log.debug("Application %r has type %r" % (client_id, request.client.client_type))
return request.client.client_type != AbstractApplication.CLIENT_CONFIDENTIAL
return False

Expand Down
48 changes: 43 additions & 5 deletions oauth2_provider/validators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import warnings
from urllib.parse import urlsplit

from django.core.exceptions import ValidationError
Expand All @@ -20,6 +21,7 @@ class URIValidator(URLValidator):

class RedirectURIValidator(URIValidator):
def __init__(self, allowed_schemes, allow_fragments=False):
warnings.warn("This class is deprecated and will be removed in version 2.5.0.", DeprecationWarning)
super().__init__(schemes=allowed_schemes)
self.allow_fragments = allow_fragments

Expand All @@ -32,6 +34,8 @@ def __call__(self, value):


class AllowedURIValidator(URIValidator):
# TODO: find a way to get these associated with their form fields in place of passing name
# TODO: submit PR to get `cause` included in the parent class ValidationError params`
def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False):
"""
:param schemes: List of allowed schemes. E.g.: ["https"]
Expand All @@ -47,15 +51,45 @@ def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fra
self.allow_fragments = allow_fragments

def __call__(self, value):
super().__call__(value)
value = force_str(value)
scheme, netloc, path, query, fragment = urlsplit(value)
try:
scheme, netloc, path, query, fragment = urlsplit(value)
except ValueError as e:
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={"name": self.name, "value": value, "cause": e},
)

# send better validation errors
if scheme not in self.schemes:
raise ValidationError(
"%(name)s URI Validation error. %(cause)s: %(value)s",
params={"name": self.name, "value": value, "cause": "invalid_scheme"},
)

if query and not self.allow_query:
raise ValidationError("{} URIs must not contain query".format(self.name))
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={"name": self.name, "value": value, "cause": "query string not allowed"},
)
if fragment and not self.allow_fragments:
raise ValidationError("{} URIs must not contain fragments".format(self.name))
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={"name": self.name, "value": value, "cause": "fragment not allowed"},
)
if path and not self.allow_path:
raise ValidationError("{} URIs must not contain path".format(self.name))
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={"name": self.name, "value": value, "cause": "path not allowed"},
)

try:
super().__call__(value)
except ValidationError as e:
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={"name": self.name, "value": value, "cause": e},
)


##
Expand All @@ -69,5 +103,9 @@ class WildcardSet(set):
A set that always returns True on `in`.
"""

def __init__(self, *args, **kwargs):
warnings.warn("This class is deprecated and will be removed in version 2.5.0.", DeprecationWarning)
super().__init__(*args, **kwargs)

def __contains__(self, item):
return True
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def test_application_clean(oauth2_settings, application):
application.allowed_origins = "http://example.com"
with pytest.raises(ValidationError) as exc:
application.clean()
assert "Enter a valid URL" in str(exc.value)
assert "allowed origin URI Validation error. invalid_scheme: http://example.com" in str(exc.value)
application.allowed_origins = "https://example.com"
application.clean()

Expand Down
216 changes: 175 additions & 41 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from django.core.validators import ValidationError
from django.test import TestCase

from oauth2_provider.validators import AllowedURIValidator, RedirectURIValidator
from oauth2_provider.validators import AllowedURIValidator, RedirectURIValidator, WildcardSet


@pytest.mark.usefixtures("oauth2_settings")
Expand Down Expand Up @@ -36,11 +36,6 @@ def test_validate_custom_uri_scheme(self):
# Check ValidationError not thrown
validator(uri)

validator = AllowedURIValidator(["my-scheme", "https", "git+ssh"], "Origin")
for uri in good_uris:
# Check ValidationError not thrown
validator(uri)

def test_validate_bad_uris(self):
validator = RedirectURIValidator(allowed_schemes=["https"])
self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"]
Expand All @@ -67,47 +62,73 @@ def test_validate_bad_uris(self):
with self.assertRaises(ValidationError):
validator(uri)

def test_validate_good_origin_uris(self):
"""
Test AllowedURIValidator validates origin URIs if they match requirements
"""
validator = AllowedURIValidator(
["https"],
"Origin",
allow_path=False,
allow_query=False,
allow_fragments=False,
)
def test_validate_wildcard_scheme__bad_uris(self):
validator = RedirectURIValidator(allowed_schemes=WildcardSet())
bad_uris = [
"http:/example.com#fragment",
"HTTP://localhost#fragment",
"http://example.com/#fragment",
"good://example.com/#fragment",
" ",
"",
# Bad IPv6 URL, urlparse behaves differently for these
'https://["><script>alert()</script>',
]

for uri in bad_uris:
with self.assertRaises(ValidationError, msg=uri):
validator(uri)

def test_validate_wildcard_scheme_good_uris(self):
validator = RedirectURIValidator(allowed_schemes=WildcardSet())
good_uris = [
"my-scheme://example.com",
"my-scheme://example",
"my-scheme://localhost",
"https://example.com",
"https://example.com:8080",
"https://example",
"https://localhost",
"https://1.1.1.1",
"https://127.0.0.1",
"https://255.255.255.255",
"HTTPS://example.com",
"HTTPS://example.com.",
"git+ssh://example.com",
"ANY://localhost",
"scheme://example.com",
"at://example.com",
"all://example.com",
]
for uri in good_uris:
# Check ValidationError not thrown
validator(uri)

def test_validate_bad_origin_uris(self):
"""
Test AllowedURIValidator rejects origin URIs if they do not match requirements
"""
validator = AllowedURIValidator(
["https"],
"Origin",
allow_path=False,
allow_query=False,
allow_fragments=False,
)

@pytest.mark.usefixtures("oauth2_settings")
class TestAllowedURIValidator(TestCase):
# TODO: verify the specifics of the ValidationErrors
def test_valid_schemes(self):
validator = AllowedURIValidator(["my-scheme", "https", "git+ssh"], "test")
good_uris = [
"my-scheme://example.com",
"my-scheme://example",
"my-scheme://localhost",
"https://example.com",
"HTTPS://example.com",
"git+ssh://example.com",
]
for uri in good_uris:
# Check ValidationError not thrown
validator(uri)

def test_invalid_schemes(self):
validator = AllowedURIValidator(["https"], "test")
bad_uris = [
"http:/example.com",
"HTTP://localhost",
"HTTP://example.com",
"https://-exa", # triggers an exception in the upstream validators
"HTTP://example.com/path",
"HTTP://example.com/path?query=string",
"HTTP://example.com/path?query=string#fragmemt",
"HTTP://example.com.",
"http://example.com/#fragment",
"http://example.com/path/#fragment",
"http://example.com?query=string#fragment",
"123://example.com",
"http://fe80::1",
"git+ssh://example.com",
Expand All @@ -119,12 +140,125 @@ def test_validate_bad_origin_uris(self):
"",
# Bad IPv6 URL, urlparse behaves differently for these
'https://["><script>alert()</script>',
# Origin uri should not contain path, query of fragment parts
# https://www.rfc-editor.org/rfc/rfc6454#section-7.1
"https://example.com/",
"https://example.com/test",
"https://example.com/?q=test",
"https://example.com/#test",
]

for uri in bad_uris:
with self.assertRaises(ValidationError):
validator(uri)

def test_allow_paths_valid_urls(self):
validator = AllowedURIValidator(["https", "myapp"], "test", allow_path=True)
good_uris = [
"https://example.com",
"https://example.com:8080",
"https://example",
"https://example.com/path",
"https://example.com:8080/path",
"https://example/path",
"https://localhost/path",
"myapp://host/path",
]
for uri in good_uris:
# Check ValidationError not thrown
validator(uri)

def test_allow_paths_invalid_urls(self):
validator = AllowedURIValidator(["https", "myapp"], "test", allow_path=True)
bad_uris = [
"https://example.com?query=string",
"https://example.com#fragment",
"https://example.com/path?query=string",
"https://example.com/path#fragment",
"https://example.com/path?query=string#fragment",
"myapp://example.com/path?query=string",
"myapp://example.com/path#fragment",
"myapp://example.com/path?query=string#fragment",
"bad://example.com/path",
]

for uri in bad_uris:
with self.assertRaises(ValidationError):
validator(uri)

def test_allow_query_valid_urls(self):
validator = AllowedURIValidator(["https", "myapp"], "test", allow_query=True)
good_uris = [
"https://example.com",
"https://example.com:8080",
"https://example.com?query=string",
"https://example",
"myapp://example.com?query=string",
"myapp://example?query=string",
]
for uri in good_uris:
# Check ValidationError not thrown
validator(uri)

def test_allow_query_invalid_urls(self):
validator = AllowedURIValidator(["https", "myapp"], "test", allow_query=True)
bad_uris = [
"https://example.com/path",
"https://example.com#fragment",
"https://example.com/path?query=string",
"https://example.com/path#fragment",
"https://example.com/path?query=string#fragment",
"https://example.com:8080/path",
"https://example/path",
"https://localhost/path",
"myapp://example.com/path?query=string",
"myapp://example.com/path#fragment",
"myapp://example.com/path?query=string#fragment",
"bad://example.com/path",
]

for uri in bad_uris:
with self.assertRaises(ValidationError):
validator(uri)

def test_allow_fragment_valid_urls(self):
validator = AllowedURIValidator(["https", "myapp"], "test", allow_fragments=True)
good_uris = [
"https://example.com",
"https://example.com#fragment",
"https://example.com:8080",
"https://example.com:8080#fragment",
"https://example",
"https://example#fragment",
"myapp://example",
"myapp://example#fragment",
"myapp://example.com",
"myapp://example.com#fragment",
]
for uri in good_uris:
# Check ValidationError not thrown
validator(uri)

def test_allow_fragment_invalid_urls(self):
validator = AllowedURIValidator(["https", "myapp"], "test", allow_fragments=True)
bad_uris = [
"https://example.com?query=string",
"https://example.com?query=string#fragment",
"https://example.com/path",
"https://example.com/path?query=string",
"https://example.com/path#fragment",
"https://example.com/path?query=string#fragment",
"https://example.com:8080/path",
"https://example?query=string",
"https://example?query=string#fragment",
"https://example/path",
"https://example/path?query=string",
"https://example/path#fragment",
"https://example/path?query=string#fragment",
"myapp://example?query=string",
"myapp://example?query=string#fragment",
"myapp://example/path",
"myapp://example/path?query=string",
"myapp://example/path#fragment",
"myapp://example.com/path?query=string",
"myapp://example.com/path#fragment",
"myapp://example.com/path?query=string#fragment",
"myapp://example.com?query=string",
"bad://example.com",
]

for uri in bad_uris:
Expand Down

0 comments on commit 4c13679

Please sign in to comment.