Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Doc/library/functools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ The :mod:`!functools` module defines the following functions:
.. versionchanged:: 3.14
Added support for :data:`Placeholder` in positional arguments.

.. versionchanged:: 3.15
:class:`partial` now stores keywords in a :class:`frozendict`

.. class:: partialmethod(func, /, *args, **keywords)

Return a new :class:`partialmethod` descriptor which behaves
Expand Down
23 changes: 18 additions & 5 deletions Lib/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,12 @@ def _partial_new(cls, func, /, *args, **keywords):
phcount, merger = _partial_prepare_merger(tot_args)
else: # works for both pto_phcount == 0 and != 0
phcount, merger = pto_phcount, func._merger
keywords = {**func.keywords, **keywords}
keywords = frozendict(**func.keywords, **keywords)
func = func.func
else:
tot_args = args
phcount, merger = _partial_prepare_merger(tot_args)
keywords = frozendict(**keywords)

self = object.__new__(cls)
self.func = func
Expand Down Expand Up @@ -397,6 +398,12 @@ def __get__(self, obj, objtype=None):
return self
return MethodType(self, obj)

def __reduce_ex__(self, protocol):
if protocol >= 2:
return self.__reduce__()
return type(self), (self.func,), (self.func, self.args,
dict(self.keywords) or None, self.__dict__ or None)

def __reduce__(self):
return type(self), (self.func,), (self.func, self.args,
self.keywords or None, self.__dict__ or None)
Expand All @@ -408,19 +415,25 @@ def __setstate__(self, state):
raise TypeError(f"expected 4 items in state, got {len(state)}")
func, args, kwds, namespace = state
if (not callable(func) or not isinstance(args, tuple) or
(kwds is not None and not isinstance(kwds, dict)) or
(namespace is not None and not isinstance(namespace, dict))):
raise TypeError("invalid partial state")
if kwds is not None and not (
isinstance(kwds, dict) or isinstance(kwds, frozendict)):
raise TypeError(f"keywords must be an instance of dict or frozendict, not {type(kwds)}")

if args and args[-1] is Placeholder:
raise TypeError("trailing Placeholders are not allowed")
phcount, merger = _partial_prepare_merger(args)

args = tuple(args) # just in case it's a subclass
if kwds is None:
kwds = {}
elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
kwds = dict(kwds)
kwds = frozendict()
else:
for key in kwds:
if type(key) is not str:
raise TypeError("keywords must be a string")
kwds = frozendict(kwds)

if namespace is None:
namespace = {}

Expand Down
53 changes: 31 additions & 22 deletions Lib/test/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __add__(self, other):
class MyDict(dict):
pass

class MyFrozenDict(frozendict):
pass

class TestImportTime(unittest.TestCase):

@cpython_only
Expand Down Expand Up @@ -404,6 +407,16 @@ def test_setstate(self):
with self.assertRaisesRegex(TypeError, f'^{msg_regex}$') as cm:
f.__setstate__((capture, (1, PH), dict(a=10), dict(attr=[])))

with self.assertRaises(TypeError):
f.__setstate__((capture, (1,), {1234: 1234}, dict(attr=[])))

class FakeString(str):
pass

with self.assertRaises(TypeError):
f.__setstate__((capture, (1,), {FakeString("string"): 1234}, dict(attr=[])))


def test_setstate_errors(self):
f = self.partial(signature)

Expand All @@ -423,7 +436,18 @@ def test_setstate_subclasses(self):
s = signature(f)
self.assertEqual(s, (capture, (1,), dict(a=10), {}))
self.assertIs(type(s[1]), tuple)
self.assertIs(type(s[2]), dict)
self.assertIs(type(s[2]), frozendict)
r = f()
self.assertEqual(r, ((1,), {'a': 10}))
self.assertIs(type(r[0]), tuple)
self.assertIs(type(r[1]), dict)


f.__setstate__((capture, MyTuple((1,)), MyFrozenDict(a=10), None))
s = signature(f)
self.assertEqual(s, (capture, (1,), dict(a=10), {}))
self.assertIs(type(s[1]), tuple)
self.assertIs(type(s[2]), frozendict)
r = f()
self.assertEqual(r, ((1,), {'a': 10}))
self.assertIs(type(r[0]), tuple)
Expand Down Expand Up @@ -588,30 +612,15 @@ def test_attributes_unwritable(self):
else:
self.fail('partial object allowed __dict__ to be deleted')

def test_manually_adding_non_string_keyword(self):
def test_keyword_mutations(self):
p = self.partial(capture)
# Adding a non-string/unicode keyword to partial kwargs
p.keywords[1234] = 'value'
r = repr(p)
self.assertIn('1234', r)
self.assertIn("'value'", r)
with self.assertRaises(TypeError):
p()

def test_keystr_replaces_value(self):
p = self.partial(capture)
with self.assertRaises(TypeError):
p.keywords["new key"] = ['sth']

class MutatesYourDict(object):
def __str__(self):
p.keywords[self] = ['sth2']
return 'astr'

# Replacing the value during key formatting should keep the original
# value alive (at least long enough).
p.keywords[MutatesYourDict()] = ['sth']
r = repr(p)
self.assertIn('astr', r)
self.assertIn("['sth']", r)
# Adding a non-string/unicode keyword to partial kwargs
with self.assertRaises(TypeError):
p.keywords[1234] = 'value'

def test_placeholders_refcount_smoke(self):
PH = self.module.Placeholder
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Now :func:`functools.partial` stores keywords in a :class:`frozendict` and enforces that they are always strings.
Loading
Loading