Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,16 @@ class CASClientBase(object):

def __init__(self, service_url=None, server_url=None,
extra_login_params=None, renew=False,
username_attribute=None, verify_ssl_certificate=True):
username_attribute=None, verify_ssl_certificate=True,
session=None):

self.service_url = service_url
self.server_url = server_url
self.extra_login_params = extra_login_params or {}
self.renew = renew
self.username_attribute = username_attribute
self.verify_ssl_certificate = verify_ssl_certificate
pass
self.session = session or requests.Session()

def verify_ticket(self, ticket):
"""Verify ticket.
Expand Down Expand Up @@ -136,7 +137,7 @@ def get_proxy_ticket(self, pgt):
Raises:
CASError: Non 200 http code or bad XML body.
"""
response = requests.get(self.get_proxy_url(pgt), verify=self.verify_ssl_certificate)
response = self.session.get(self.get_proxy_url(pgt), verify=self.verify_ssl_certificate)
if response.status_code == 200:
from lxml import etree
root = etree.fromstring(response.content)
Expand Down Expand Up @@ -168,7 +169,7 @@ def verify_ticket(self, ticket):
params = [('ticket', ticket), ('service', self.service_url)]
url = (urllib_parse.urljoin(self.server_url, 'validate') + '?' +
urllib_parse.urlencode(params))
page = requests.get(
page = self.session.get(
url,
stream=True,
verify=self.verify_ssl_certificate
Expand Down Expand Up @@ -208,7 +209,7 @@ def get_verification_response(self, ticket):
if self.proxy_callback:
params.update({'pgtUrl': self.proxy_callback})
base_url = urllib_parse.urljoin(self.server_url, self.url_suffix)
page = requests.get(
page = self.session.get(
base_url,
params=params,
verify=self.verify_ssl_certificate
Expand Down Expand Up @@ -376,7 +377,7 @@ def fetch_saml_validation(self, ticket):
saml_validate_url = urllib_parse.urljoin(
self.server_url, 'samlValidate',
)
return requests.post(
return self.session.post(
saml_validate_url,
self.get_saml_assertion(ticket),
params=params,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,28 @@ def test_can_saml_assertion_is_encoded():
else:
assert ticket in saml

# Test session= constructor argument with a mock session
@pytest.mark.skipif(sys.version_info < (3, 3), reason="Mock class not available")
def test_v3_custom_session():
from unittest.mock import Mock

response = Mock()
response.content = SUCCESS_RESPONSE
session = Mock()
session.get = Mock(return_value=response)

client = cas.CASClient(
version='3',
server_url='https://cas.example.com/cas/',
service_url='https://example.com/login',
session=session)
user, attributes, pgtiou = client.verify_ticket('ABC123')

assert user == 'user@example.com'
assert not attributes
assert not pgtiou



@fixture
def client_v2():
Expand Down