Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Lib/test/test_free_threading/test_itertools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from itertools import accumulate, batched, chain, combinations_with_replacement, cycle, permutations, zip_longest
from itertools import accumulate, batched, chain, combinations_with_replacement, cycle, pairwise, permutations, zip_longest
from test.support import threading_helper


Expand Down Expand Up @@ -55,6 +55,13 @@ def test_combinations_with_replacement(self):
it = combinations_with_replacement(tuple(range(2)), 2)
threading_helper.run_concurrently(work_iterator, nthreads=6, args=[it])

@threading_helper.reap_threads
def test_pairwise(self):
number_of_iterations = 10
for _ in range(number_of_iterations):
it = pairwise(tuple(range(100)))
threading_helper.run_concurrently(work_iterator, nthreads=10, args=[it])

@threading_helper.reap_threads
def test_permutations(self):
number_of_iterations = 6
Expand Down
89 changes: 31 additions & 58 deletions Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,76 +968,49 @@ def test_pairwise(self):
pairwise(None) # non-iterable argument

def test_pairwise_reenter(self):
def check(reenter_at, expected):
# gh-149557: re-entering pairwise.__next__ from within the input
# iterator's __next__ now raises RuntimeError, matching the behavior
# of itertools.tee and of generators.
def check(reenter_at):
class I:
count = 0
def __iter__(self):
return self
def __next__(self):
self.count +=1
self.count += 1
if self.count in reenter_at:
return next(it)
return [self.count] # new object

it = pairwise(I())
for item in expected:
self.assertEqual(next(it), item)

check({1}, [
(([2], [3]), [4]),
([4], [5]),
])
check({2}, [
([1], ([1], [3])),
(([1], [3]), [4]),
([4], [5]),
])
check({3}, [
([1], [2]),
([2], ([2], [4])),
(([2], [4]), [5]),
([5], [6]),
])
check({1, 2}, [
((([3], [4]), [5]), [6]),
([6], [7]),
])
check({1, 3}, [
(([2], ([2], [4])), [5]),
([5], [6]),
])
check({1, 4}, [
(([2], [3]), (([2], [3]), [5])),
((([2], [3]), [5]), [6]),
([6], [7]),
])
check({2, 3}, [
([1], ([1], ([1], [4]))),
(([1], ([1], [4])), [5]),
([5], [6]),
])
with self.assertRaisesRegex(RuntimeError, "pairwise"):
list(it)

def test_pairwise_reenter2(self):
def check(maxcount, expected):
class I:
count = 0
def __iter__(self):
return self
def __next__(self):
if self.count >= maxcount:
raise StopIteration
self.count +=1
if self.count == 1:
return next(it, None)
return [self.count] # new object
# Re-entry on any of the first few __next__ calls must raise.
for reenter_at in [{1}, {2}, {3}, {1, 2}, {1, 3}, {1, 4}, {2, 3}]:
with self.subTest(reenter_at=reenter_at):
check(reenter_at)

it = pairwise(I())
self.assertEqual(list(it), expected)

check(1, [])
check(2, [])
check(3, [])
check(4, [(([2], [3]), [4])])
def test_pairwise_reenter2(self):
# gh-149557: variant of test_pairwise_reenter where the re-entrant
# call uses next(it, None) (i.e. supplies a default). Even with a
# default that suppresses StopIteration, the re-entry itself must
# raise RuntimeError, propagating out of the outer pairwise call.
class I:
count = 0
def __iter__(self):
return self
def __next__(self):
if self.count >= 4:
raise StopIteration
self.count += 1
if self.count == 1:
return next(it, None)
return [self.count] # new object

it = pairwise(I())
with self.assertRaisesRegex(RuntimeError, "pairwise"):
list(it)

def test_product(self):
for args, result in [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
:class:`itertools.pairwise` iterators are now safe to share between
threads in the :term:`free threaded <free threading>` build: concurrent
calls to ``__next__`` are serialized via a critical section
(:gh:`123471`). Re-entrant calls (e.g. when the input iterator's
``__next__`` calls back into the :class:`!pairwise` object) now raise
:exc:`RuntimeError`, matching the behavior of generators and
:class:`itertools.tee` and fixing a use-after-free crash
(:gh:`149557`). Patch by Pieter Eendebak.
49 changes: 35 additions & 14 deletions Modules/itertoolsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ typedef struct {
PyObject *it;
PyObject *old;
PyObject *result;
uint8_t running;
} pairwiseobject;

#define pairwiseobject_CAST(op) ((pairwiseobject *)(op))
Expand Down Expand Up @@ -344,36 +345,44 @@ pairwise_traverse(PyObject *op, visitproc visit, void *arg)
}

static PyObject *
pairwise_next(PyObject *op)
pairwise_next_lock_held(PyObject *op)
{
pairwiseobject *po = pairwiseobject_CAST(op);

// Re-entrancy guard. The enclosing critical section serializes calls
// from different threads; this flag catches same-thread re-entry (the
// input iterator's __next__ calling back into pairwise.__next__) and
// raises, matching itertools.tee and generators.
if (po->running) {
PyErr_SetString(PyExc_RuntimeError,
"pairwise() iterator already executing");
return NULL;
}
po->running = 1;

PyObject *it = po->it;
PyObject *old = po->old;
PyObject *new, *result;

if (it == NULL) {
return NULL;
result = NULL;
goto done;
}
if (old == NULL) {
old = (*Py_TYPE(it)->tp_iternext)(it);
Py_XSETREF(po->old, old);
if (old == NULL) {
Py_CLEAR(po->it);
return NULL;
}
it = po->it;
if (it == NULL) {
Py_CLEAR(po->old);
return NULL;
result = NULL;
goto done;
}
po->old = old; // po->old was NULL; no decref needed
}
Py_INCREF(old);
new = (*Py_TYPE(it)->tp_iternext)(it);
if (new == NULL) {
Py_CLEAR(po->it);
Py_CLEAR(po->old);
Py_DECREF(old);
return NULL;
result = NULL;
goto done;
}

result = po->result;
Expand All @@ -393,8 +402,20 @@ pairwise_next(PyObject *op)
result = _PyTuple_FromPair(old, new);
}

Py_XSETREF(po->old, new);
Py_DECREF(old);
Py_SETREF(po->old, new); // po->old == old, known non-NULL

done:
po->running = 0;
return result;
}

static PyObject *
pairwise_next(PyObject *op)
{
PyObject *result;
Py_BEGIN_CRITICAL_SECTION(op);
result = pairwise_next_lock_held(op);
Py_END_CRITICAL_SECTION();
return result;
}

Expand Down
Loading