Skip to content
7 changes: 7 additions & 0 deletions Doc/library/math.integer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ computed exactly and are integers.
:trim:


.. function:: isqrt_rem(n, /)

For a nonnegative integer *n*, return the pair of integers ``(s, t)``
such that ``s = isqrt(n)`` and ``t = n - s*s``.
The remainder *t* is zero, if *n* is a perfect square.


.. function:: lcm(*integers)

Return the least common multiple of the specified integer arguments.
Expand Down
7 changes: 7 additions & 0 deletions Doc/whatsnew/3.16.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ lzma
requires ``lzma`` 5.4.0 or newer while RISC-V requires 5.6.0 or newer.
(Contributed by Chien Wong in :gh:`115988`.)

math.integer
------------

* Add :func:`math.integer.isqrt_rem` to compute integer square root with
a remainder.
(Contributed by Sergey B Kirpichev in :gh:`90345`.)

os
--

Expand Down
39 changes: 39 additions & 0 deletions Lib/test/test_math_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from fractions import Fraction
import unittest
from test import support
from math.integer import isqrt_rem


class IntSubclass(int):
Expand Down Expand Up @@ -249,6 +250,44 @@ def test_isqrt_huge(self, size):
self.assertEqual(w.bit_length(), size // 2 + 1)
self.assertEqual(w.bit_count(), 1)

def test_isqrt_rem(self):
test_values = (
list(range(1000))
+ list(range(10**6 - 1000, 10**6 + 1000))
+ [2**e + i for e in range(60, 200) for i in range(-40, 40)]
+ [3**9999, 10**5001]
)
for value in test_values:
with self.subTest(value=value):
root, rem = isqrt_rem(value)
self.assertIs(type(root), int)
self.assertLessEqual(root*root, value)
self.assertLess(value, (root+1)*(root+1))
self.assertIs(type(rem), int)
self.assertEqual(rem, value - root*root)

# Negative values
with self.assertRaises(ValueError):
isqrt_rem(-1)

# Integer-like things
self.assertEqual(isqrt_rem(True), (1, 0))
self.assertEqual(isqrt_rem(False), (0, 0))
self.assertEqual(isqrt_rem(MyIndexable(1729)), (41, 48))
Comment thread
skirpichev marked this conversation as resolved.

with self.assertRaises(ValueError):
isqrt_rem(MyIndexable(-3))

# Non-integer-like things
bad_values = [
3.5, "3.5", Decimal("3.5"), 3.5j,
100.0, -4.0,
]
for value in bad_values:
with self.subTest(value=value):
with self.assertRaises(TypeError):
isqrt_rem(value)

def test_perm(self):
perm = self.module.perm
factorial = self.module.factorial
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add :func:`math.integer.isqrt_rem`.
11 changes: 10 additions & 1 deletion Modules/clinic/mathintegermodule.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

97 changes: 84 additions & 13 deletions Modules/mathintegermodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -341,18 +341,8 @@ _approximate_isqrt(uint64_t n)
return (u << 15) + (uint32_t)((n >> 17) / u);
}

/*[clinic input]
math.integer.isqrt

n: object
/

Return the integer part of the square root of the input.
[clinic start generated code]*/

static PyObject *
math_integer_isqrt(PyObject *module, PyObject *n)
/*[clinic end generated code: output=551031e41a0f5d9e input=921ddd9853133d8d]*/
_isqrt_rem(PyObject *n, PyObject **rem)
{
int a_too_large, c_bit_length;
int64_t c, d;
Expand All @@ -373,6 +363,9 @@ math_integer_isqrt(PyObject *module, PyObject *n)
}
if (_PyLong_IsZero((PyLongObject *)n)) {
Py_DECREF(n);
if (rem) {
*rem = PyLong_FromLong(0);
}
return PyLong_FromLong(0);
}

Expand All @@ -392,7 +385,15 @@ math_integer_isqrt(PyObject *module, PyObject *n)
return NULL;
}
u = _approximate_isqrt(m << 2*shift) >> shift;
u -= (uint64_t)u * u > m;
uint64_t sq = (uint64_t)u * u;
u -= sq > m;
if (rem) {
if (sq > m) {
sq -= 2*(uint64_t)u + 1;
}
m -= sq;
*rem = PyLong_FromUnsignedLongLong(m);
}
return PyLong_FromUnsignedLong(u);
}

Expand Down Expand Up @@ -460,13 +461,40 @@ math_integer_isqrt(PyObject *module, PyObject *n)
goto error;
}
a_too_large = PyObject_RichCompareBool(n, b, Py_LT);
Py_DECREF(b);
if (a_too_large == -1) {
Py_DECREF(b);
goto error;
}

if (a_too_large) {
if (rem) {
PyObject *tmp = PyNumber_Add(b, _PyLong_GetOne());

if (tmp == NULL) {
Py_DECREF(b);
goto error;
}
Py_SETREF(b, tmp);
tmp = PyNumber_Add(a, a);
if (tmp == NULL) {
Py_DECREF(b);
goto error;
}
Py_SETREF(b, PyNumber_Subtract(b, tmp));
Py_DECREF(tmp);
}
Py_SETREF(a, PyNumber_Subtract(a, _PyLong_GetOne()));
if (a == NULL) {
Py_DECREF(b);
goto error;
}
}
if (rem) {
Py_SETREF(b, PyNumber_Subtract(n, b));
Comment thread
skirpichev marked this conversation as resolved.
*rem = b;
}
else {
Py_DECREF(b);
}
Py_DECREF(n);
return a;
Expand All @@ -478,6 +506,48 @@ math_integer_isqrt(PyObject *module, PyObject *n)
}


/*[clinic input]
math.integer.isqrt

n: object
/

Return the integer part of the square root of the input.
[clinic start generated code]*/

static PyObject *
math_integer_isqrt(PyObject *module, PyObject *n)
/*[clinic end generated code: output=551031e41a0f5d9e input=921ddd9853133d8d]*/
{
return _isqrt_rem(n, NULL);
}

/*[clinic input]
math.integer.isqrt_rem

n: object
/

Return a pair of values (s,t) such that s=isqrt(n) and t=n-s*s.
[clinic start generated code]*/

static PyObject *
math_integer_isqrt_rem(PyObject *module, PyObject *n)
/*[clinic end generated code: output=b17d11479d08cdc4 input=7ed2dd870818d2bb]*/
{
PyObject *rem = NULL;
PyObject *root = _isqrt_rem(n, &rem);
PyObject *res = NULL;

if (root && rem) {
res = PyTuple_Pack(2, root, rem);
}
Py_XDECREF(root);
Py_XDECREF(rem);
return res;
}


static unsigned long
count_set_bits(unsigned long n)
{
Expand Down Expand Up @@ -1231,6 +1301,7 @@ static PyMethodDef math_integer_methods[] = {
MATH_INTEGER_FACTORIAL_METHODDEF
MATH_INTEGER_GCD_METHODDEF
MATH_INTEGER_ISQRT_METHODDEF
MATH_INTEGER_ISQRT_REM_METHODDEF
MATH_INTEGER_LCM_METHODDEF
MATH_INTEGER_PERM_METHODDEF
{NULL, NULL} /* sentinel */
Expand Down
Loading