diff --git a/cas.py b/cas.py index 93e5b1c..193177d 100644 --- a/cas.py +++ b/cas.py @@ -70,7 +70,8 @@ class CASClientBase(object): def __init__(self, service_url=None, server_url=None, extra_login_params=None, renew=False, - username_attribute=None, verify_ssl_certificate=True): + username_attribute=None, verify_ssl_certificate=True, + session=None): self.service_url = service_url self.server_url = server_url @@ -78,7 +79,7 @@ def __init__(self, service_url=None, server_url=None, self.renew = renew self.username_attribute = username_attribute self.verify_ssl_certificate = verify_ssl_certificate - pass + self.session = session or requests.Session() def verify_ticket(self, ticket): """Verify ticket. @@ -136,7 +137,7 @@ def get_proxy_ticket(self, pgt): Raises: CASError: Non 200 http code or bad XML body. """ - response = requests.get(self.get_proxy_url(pgt), verify=self.verify_ssl_certificate) + response = self.session.get(self.get_proxy_url(pgt), verify=self.verify_ssl_certificate) if response.status_code == 200: from lxml import etree root = etree.fromstring(response.content) @@ -168,7 +169,7 @@ def verify_ticket(self, ticket): params = [('ticket', ticket), ('service', self.service_url)] url = (urllib_parse.urljoin(self.server_url, 'validate') + '?' + urllib_parse.urlencode(params)) - page = requests.get( + page = self.session.get( url, stream=True, verify=self.verify_ssl_certificate @@ -208,7 +209,7 @@ def get_verification_response(self, ticket): if self.proxy_callback: params.update({'pgtUrl': self.proxy_callback}) base_url = urllib_parse.urljoin(self.server_url, self.url_suffix) - page = requests.get( + page = self.session.get( base_url, params=params, verify=self.verify_ssl_certificate @@ -376,7 +377,7 @@ def fetch_saml_validation(self, ticket): saml_validate_url = urllib_parse.urljoin( self.server_url, 'samlValidate', ) - return requests.post( + return self.session.post( saml_validate_url, self.get_saml_assertion(ticket), params=params, diff --git a/tests/test_cas.py b/tests/test_cas.py index da70e53..3a4ff73 100644 --- a/tests/test_cas.py +++ b/tests/test_cas.py @@ -189,6 +189,28 @@ def test_can_saml_assertion_is_encoded(): else: assert ticket in saml +# Test session= constructor argument with a mock session +@pytest.mark.skipif(sys.version_info < (3, 3), reason="Mock class not available") +def test_v3_custom_session(): + from unittest.mock import Mock + + response = Mock() + response.content = SUCCESS_RESPONSE + session = Mock() + session.get = Mock(return_value=response) + + client = cas.CASClient( + version='3', + server_url='https://cas.example.com/cas/', + service_url='https://example.com/login', + session=session) + user, attributes, pgtiou = client.verify_ticket('ABC123') + + assert user == 'user@example.com' + assert not attributes + assert not pgtiou + + @fixture def client_v2():