Skip to content
Open
47 changes: 43 additions & 4 deletions Lib/multiprocessing/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ def rebuild_exc(exc, tb):
# Code run by worker processes
#

class MaybeDecodingError(Exception):
def __init__(self, exc):
self.exc = repr(exc)
self.__cause__ = exc
super(MaybeDecodingError, self).__init__(self.exc)

def __str__(self):
return "Error receiving result. Reason: '%s'" % (self.exc,
self.exc)

def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self)

class MaybeEncodingError(Exception):
"""Wraps possible unpickleable errors, so they can be
safely sent through the socket."""
Expand Down Expand Up @@ -158,6 +171,19 @@ def __init__(self, /, *args, notifier=None, **kwds):
self.notifier = notifier
super().__init__(*args, **kwds)

self._cache_failed = False
self._cache_failed_reason = None

def _disable_cache(self, exec):
self._cache_failed = True
self._cache_failed_reason = exec

def __setitem__(self, key, value):
if self._cache_failed:
raise RuntimeError("Pool cache is disabled due to previous error") \
from self._cache_failed_reason
super().__setitem__(key, value)

def __delitem__(self, item):
super().__delitem__(item)

Expand Down Expand Up @@ -572,13 +598,26 @@ def _handle_tasks(taskqueue, put, outqueue, pool, cache):

@staticmethod
def _handle_results(outqueue, get, cache):
def _handle_results_failure(cache, e):
exc = MaybeDecodingError(e)
cache._disable_cache(exc)
_cache = cache.copy()
for value in _cache.values():
if isinstance(value, ApplyResult):
chunk_number_left = getattr(value, '_number_left', 1)
for _ in range(chunk_number_left):
value._set(None, (False, exc))
elif isinstance(value, IMapIterator):
value._set_length(value._index + 1)
value._set(value._index, (False, exc))

thread = threading.current_thread()

while 1:
try:
task = get()
except (OSError, EOFError):
util.debug('result handler got EOFError/OSError -- exiting')
except Exception as e:
_handle_results_failure(cache, e)
return

if thread._state != RUN:
Expand All @@ -600,8 +639,8 @@ def _handle_results(outqueue, get, cache):
while cache and thread._state != TERMINATE:
try:
task = get()
except (OSError, EOFError):
util.debug('result handler got EOFError/OSError -- exiting')
except Exception as e:
_handle_results_failure(cache, e)
return

if task is None:
Expand Down
23 changes: 23 additions & 0 deletions Lib/test/_test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3236,12 +3236,24 @@ def test_resource_warning(self):
pool = None
support.gc_collect()


class _Undepickleable:
def __init__(self, fail=False):
if fail:
raise RuntimeError()

def __reduce__(self):
return self.__class__, (True,)

def raising():
raise KeyError("key")

def unpickleable_result():
return lambda: 42

def undepickleable_result():
return _Undepickleable()

class _TestPoolWorkerErrors(BaseTestCase):
ALLOWED_TYPES = ('processes', )

Expand Down Expand Up @@ -3284,6 +3296,17 @@ def errback(exc):
p.close()
p.join()

@warnings_helper.ignore_fork_in_thread_deprecation_warnings()
def test_undepickleable_result(self):
from multiprocessing.pool import MaybeDecodingError
p = multiprocessing.Pool(2)
res = p.apply_async(undepickleable_result)
self.assertRaises(MaybeDecodingError, res.get, 10)
self.assertRaises(RuntimeError, p.apply_async, undepickleable_result)
p.close()
p.join()


class _TestPoolWorkerLifetime(BaseTestCase):
ALLOWED_TYPES = ('processes', )

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix :code:`multiprocessing.Pool` hang when a unpicklable object is returned
Loading