def k_mul(x, y):
    """Return x * y using a single Karatsuba split.

    Both operands are cut at bit position n (half the bit length of the
    longer one):

        x == (x_hi << n) + x_lo        y == (y_hi << n) + y_lo

    so that

        x * y == (x_hi*y_hi << 2n)
                 + (((x_hi + x_lo)*(y_hi + y_lo) - x_hi*y_hi - x_lo*y_lo) << n)
                 + x_lo*y_lo

    which needs three half-size products instead of four.  The sub-products
    are computed with the ordinary ``*`` operator.

    Negative operands need no special casing: ``>>`` floors, so the low
    part is always in range(0, 2**n) and the identities above still hold.
    """
    n = max(x.bit_length(), y.bit_length()) // 2

    # The squaring test must look at the *original* operands.  Doing it
    # after x has been reduced to its low half would wrongly treat
    # "y == low half of x" as a squaring and return x*x instead of x*y.
    squaring = x == y

    x_hi = x >> n
    x_lo = x - (x_hi << n)
    if squaring:
        # Reuse the very same objects so the sub-multiplications are
        # squarings too.
        y_hi, y_lo = x_hi, x_lo
    else:
        y_hi = y >> n
        y_lo = y - (y_hi << n)

    high = x_hi * y_hi
    low = x_lo * y_lo
    if squaring:
        half_sum = x_hi + x_lo
        cross = half_sum * half_sum
    else:
        cross = (x_hi + x_lo) * (y_hi + y_lo)
    middle = cross - high - low

    return (high << (n << 1)) + (middle << n) + low
- */ -static digit -v_isub(digit *x, Py_ssize_t m, digit *y, Py_ssize_t n) -{ - Py_ssize_t i; - digit borrow = 0; - - assert(m >= n); - for (i = 0; i < n; ++i) { - borrow = x[i] - y[i] - borrow; - x[i] = borrow & PyLong_MASK; - borrow >>= PyLong_SHIFT; - borrow &= 1; /* keep only 1 sign bit */ - } - for (; borrow && i < m; ++i) { - borrow = x[i] - borrow; - x[i] = borrow & PyLong_MASK; - borrow >>= PyLong_SHIFT; - borrow &= 1; - } - return borrow; -} /* Shift digit vector a[0:m] d bits left, with 0 <= d < PyLong_SHIFT. Put * result in z[0:m], and return the d bits shifted out of the top. @@ -3612,339 +3561,54 @@ x_mul(PyLongObject *a, PyLongObject *b) return long_normalize(z); } -/* A helper for Karatsuba multiplication (k_mul). - Takes an int "n" and an integer "size" representing the place to - split, and sets low and high such that abs(n) == (high << size) + low, - viewing the shift as being by digits. The sign bit is ignored, and - the return values are >= 0. - Returns 0 on success, -1 on failure. 
-*/ -static int -kmul_split(PyLongObject *n, - Py_ssize_t size, - PyLongObject **high, - PyLongObject **low) +PyObject * +_PyLong_Multiply(PyLongObject *a, PyLongObject *b) { - PyLongObject *hi, *lo; - Py_ssize_t size_lo, size_hi; - const Py_ssize_t size_n = _PyLong_DigitCount(n); - - size_lo = Py_MIN(size_n, size); - size_hi = size_n - size_lo; - - if ((hi = _PyLong_New(size_hi)) == NULL) - return -1; - if ((lo = _PyLong_New(size_lo)) == NULL) { - Py_DECREF(hi); - return -1; + /* fast path for single-digit multiplication */ + if (_PyLong_BothAreCompact(a, b)) { + stwodigits v = medium_value(a) * medium_value(b); + return _PyLong_FromSTwoDigits(v); } - memcpy(lo->long_value.ob_digit, n->long_value.ob_digit, size_lo * sizeof(digit)); - memcpy(hi->long_value.ob_digit, n->long_value.ob_digit + size_lo, size_hi * sizeof(digit)); - - *high = long_normalize(hi); - *low = long_normalize(lo); - return 0; -} - -static PyLongObject *k_lopsided_mul(PyLongObject *a, PyLongObject *b); - -/* Karatsuba multiplication. Ignores the input signs, and returns the - * absolute value of the product (or NULL if error). - * See Knuth Vol. 2 Chapter 4.3.3 (Pp. 294-295). - */ -static PyLongObject * -k_mul(PyLongObject *a, PyLongObject *b) -{ Py_ssize_t asize = _PyLong_DigitCount(a); Py_ssize_t bsize = _PyLong_DigitCount(b); - PyLongObject *ah = NULL; - PyLongObject *al = NULL; - PyLongObject *bh = NULL; - PyLongObject *bl = NULL; - PyLongObject *ret = NULL; - PyLongObject *t1, *t2, *t3; - Py_ssize_t shift; /* the number of digits we split off */ - Py_ssize_t i; + Py_ssize_t cutoff = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF; - /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl - * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh + ah*bh + al*bl - * Then the original product is - * ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl - * By picking X to be a power of 2, "*X" is just shifting, and it's - * been reduced to 3 multiplies on numbers half the size. 
- */ - - /* We want to split based on the larger number; fiddle so that b - * is largest. - */ - if (asize > bsize) { - t1 = a; - a = b; - b = t1; - - i = asize; - asize = bsize; - bsize = i; - } + asize = asize > bsize ? bsize : asize; /* Use gradeschool math when either number is too small. */ - i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF; - if (asize <= i) { - if (asize == 0) - return (PyLongObject *)PyLong_FromLong(0); - else - return x_mul(a, b); - } - - /* If a is small compared to b, splitting on b gives a degenerate - * case with ah==0, and Karatsuba may be (even much) less efficient - * than "grade school" then. However, we can still win, by viewing - * b as a string of "big digits", each of the same width as a. That - * leads to a sequence of balanced calls to k_mul. - */ - if (2 * asize <= bsize) - return k_lopsided_mul(a, b); - - /* Split a & b into hi & lo pieces. */ - shift = bsize >> 1; - if (kmul_split(a, shift, &ah, &al) < 0) goto fail; - assert(_PyLong_IsPositive(ah)); /* the split isn't degenerate */ - - if (a == b) { - bh = (PyLongObject*)Py_NewRef(ah); - bl = (PyLongObject*)Py_NewRef(al); - } - else if (kmul_split(b, shift, &bh, &bl) < 0) goto fail; - - /* The plan: - * 1. Allocate result space (asize + bsize digits: that's always - * enough). - * 2. Compute ah*bh, and copy into result at 2*shift. - * 3. Compute al*bl, and copy into result at 0. Note that this - * can't overlap with #2. - * 4. Subtract al*bl from the result, starting at shift. This may - * underflow (borrow out of the high digit), but we don't care: - * we're effectively doing unsigned arithmetic mod - * BASE**(sizea + sizeb), and so long as the *final* result fits, - * borrows and carries out of the high digit can be ignored. - * 5. Subtract ah*bh from the result, starting at shift. - * 6. Compute (ah+al)*(bh+bl), and add it into the result starting - * at shift. - */ - - /* 1. Allocate result space. 
*/ - ret = _PyLong_New(asize + bsize); - if (ret == NULL) goto fail; -#ifdef Py_DEBUG - /* Fill with trash, to catch reference to uninitialized digits. */ - memset(ret->long_value.ob_digit, 0xDF, _PyLong_DigitCount(ret) * sizeof(digit)); -#endif - - /* 2. t1 <- ah*bh, and copy into high digits of result. */ - if ((t1 = k_mul(ah, bh)) == NULL) goto fail; - assert(!_PyLong_IsNegative(t1)); - assert(2*shift + _PyLong_DigitCount(t1) <= _PyLong_DigitCount(ret)); - memcpy(ret->long_value.ob_digit + 2*shift, t1->long_value.ob_digit, - _PyLong_DigitCount(t1) * sizeof(digit)); - - /* Zero-out the digits higher than the ah*bh copy. */ - i = _PyLong_DigitCount(ret) - 2*shift - _PyLong_DigitCount(t1); - if (i) - memset(ret->long_value.ob_digit + 2*shift + _PyLong_DigitCount(t1), 0, - i * sizeof(digit)); - - /* 3. t2 <- al*bl, and copy into the low digits. */ - if ((t2 = k_mul(al, bl)) == NULL) { - Py_DECREF(t1); - goto fail; - } - assert(!_PyLong_IsNegative(t2)); - assert(_PyLong_DigitCount(t2) <= 2*shift); /* no overlap with high digits */ - memcpy(ret->long_value.ob_digit, t2->long_value.ob_digit, _PyLong_DigitCount(t2) * sizeof(digit)); - - /* Zero out remaining digits. */ - i = 2*shift - _PyLong_DigitCount(t2); /* number of uninitialized digits */ - if (i) - memset(ret->long_value.ob_digit + _PyLong_DigitCount(t2), 0, i * sizeof(digit)); - - /* 4 & 5. Subtract ah*bh (t1) and al*bl (t2). We do al*bl first - * because it's fresher in cache. - */ - i = _PyLong_DigitCount(ret) - shift; /* # digits after shift */ - (void)v_isub(ret->long_value.ob_digit + shift, i, t2->long_value.ob_digit, _PyLong_DigitCount(t2)); - _Py_DECREF_INT(t2); - - (void)v_isub(ret->long_value.ob_digit + shift, i, t1->long_value.ob_digit, _PyLong_DigitCount(t1)); - _Py_DECREF_INT(t1); - - /* 6. t3 <- (ah+al)(bh+bl), and add into result. 
*/ - if ((t1 = x_add(ah, al)) == NULL) goto fail; - _Py_DECREF_INT(ah); - _Py_DECREF_INT(al); - ah = al = NULL; - - if (a == b) { - t2 = (PyLongObject*)Py_NewRef(t1); - } - else if ((t2 = x_add(bh, bl)) == NULL) { - Py_DECREF(t1); - goto fail; - } - _Py_DECREF_INT(bh); - _Py_DECREF_INT(bl); - bh = bl = NULL; - - t3 = k_mul(t1, t2); - _Py_DECREF_INT(t1); - _Py_DECREF_INT(t2); - if (t3 == NULL) goto fail; - assert(!_PyLong_IsNegative(t3)); - - /* Add t3. It's not obvious why we can't run out of room here. - * See the (*) comment after this function. - */ - (void)v_iadd(ret->long_value.ob_digit + shift, i, t3->long_value.ob_digit, _PyLong_DigitCount(t3)); - _Py_DECREF_INT(t3); - - return long_normalize(ret); - - fail: - Py_XDECREF(ret); - Py_XDECREF(ah); - Py_XDECREF(al); - Py_XDECREF(bh); - Py_XDECREF(bl); - return NULL; -} - -/* (*) Why adding t3 can't "run out of room" above. - -Let f(x) mean the floor of x and c(x) mean the ceiling of x. Some facts -to start with: - -1. For any integer i, i = c(i/2) + f(i/2). In particular, - bsize = c(bsize/2) + f(bsize/2). -2. shift = f(bsize/2) -3. asize <= bsize -4. Since we call k_lopsided_mul if asize*2 <= bsize, asize*2 > bsize in this - routine, so asize > bsize/2 >= f(bsize/2) in this routine. - -We allocated asize + bsize result digits, and add t3 into them at an offset -of shift. This leaves asize+bsize-shift allocated digit positions for t3 -to fit into, = (by #1 and #2) asize + f(bsize/2) + c(bsize/2) - f(bsize/2) = -asize + c(bsize/2) available digit positions. - -bh has c(bsize/2) digits, and bl at most f(size/2) digits. So bh+hl has -at most c(bsize/2) digits + 1 bit. - -If asize == bsize, ah has c(bsize/2) digits, else ah has at most f(bsize/2) -digits, and al has at most f(bsize/2) digits in any case. So ah+al has at -most (asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 1 bit. - -The product (ah+al)*(bh+bl) therefore has at most - - c(bsize/2) + (asize == bsize ? 
c(bsize/2) : f(bsize/2)) digits + 2 bits - -and we have asize + c(bsize/2) available digit positions. We need to show -this is always enough. An instance of c(bsize/2) cancels out in both, so -the question reduces to whether asize digits is enough to hold -(asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 2 bits. If asize < bsize, -then we're asking whether asize digits >= f(bsize/2) digits + 2 bits. By #4, -asize is at least f(bsize/2)+1 digits, so this in turn reduces to whether 1 -digit is enough to hold 2 bits. This is so since PyLong_SHIFT=15 >= 2. If -asize == bsize, then we're asking whether bsize digits is enough to hold -c(bsize/2) digits + 2 bits, or equivalently (by #1) whether f(bsize/2) digits -is enough to hold 2 bits. This is so if bsize >= 2, which holds because -bsize >= KARATSUBA_CUTOFF >= 2. - -Note that since there's always enough room for (ah+al)*(bh+bl), and that's -clearly >= each of ah*bh and al*bl, there's always enough room to subtract -ah*bh and al*bl too. -*/ - -/* b has at least twice the digits of a, and a is big enough that Karatsuba - * would pay off *if* the inputs had balanced sizes. View b as a sequence - * of slices, each with the same number of digits as a, and multiply the - * slices by a, one at a time. This gives k_mul balanced inputs to work with, - * and is also cache-friendly (we compute one double-width slice of the result - * at a time, then move on, never backtracking except for the helpful - * single-width slice overlap between successive partial sums). - */ -static PyLongObject * -k_lopsided_mul(PyLongObject *a, PyLongObject *b) -{ - const Py_ssize_t asize = _PyLong_DigitCount(a); - Py_ssize_t bsize = _PyLong_DigitCount(b); - Py_ssize_t nbdone; /* # of b digits already multiplied */ - PyLongObject *ret; - PyLongObject *bslice = NULL; - - assert(asize > KARATSUBA_CUTOFF); - assert(2 * asize <= bsize); - - /* Allocate result space, and zero it out. 
*/ - ret = _PyLong_New(asize + bsize); - if (ret == NULL) - return NULL; - memset(ret->long_value.ob_digit, 0, _PyLong_DigitCount(ret) * sizeof(digit)); - - /* Successive slices of b are copied into bslice. */ - bslice = _PyLong_New(asize); - if (bslice == NULL) - goto fail; - - nbdone = 0; - while (bsize > 0) { - PyLongObject *product; - const Py_ssize_t nbtouse = Py_MIN(bsize, asize); - - /* Multiply the next slice of b by a. */ - memcpy(bslice->long_value.ob_digit, b->long_value.ob_digit + nbdone, - nbtouse * sizeof(digit)); - assert(nbtouse >= 0); - _PyLong_SetSignAndDigitCount(bslice, 1, nbtouse); - product = k_mul(a, bslice); - if (product == NULL) - goto fail; - - /* Add into result. */ - (void)v_iadd(ret->long_value.ob_digit + nbdone, _PyLong_DigitCount(ret) - nbdone, - product->long_value.ob_digit, _PyLong_DigitCount(product)); - _Py_DECREF_INT(product); - - bsize -= nbtouse; - nbdone += nbtouse; - } - - _Py_DECREF_INT(bslice); - return long_normalize(ret); - - fail: - Py_DECREF(ret); - Py_XDECREF(bslice); - return NULL; -} - -PyObject * -_PyLong_Multiply(PyLongObject *a, PyLongObject *b) -{ - PyLongObject *z; + if (asize <= cutoff) { + if (asize == 0) { + return PyLong_FromLong(0); + } + PyObject *z = (PyObject*)x_mul(a, b); - /* fast path for single-digit multiplication */ - if (_PyLong_BothAreCompact(a, b)) { - stwodigits v = medium_value(a) * medium_value(b); - return _PyLong_FromSTwoDigits(v); + /* Negate if exactly one of the inputs is negative. */ + if (!_PyLong_SameSign(a, b) && z) { + _PyLong_Negate((PyLongObject**)&z); + if (z == NULL) { + return NULL; + } + } + return z; } - - z = k_mul(a, b); - /* Negate if exactly one of the inputs is negative. 
*/ - if (!_PyLong_SameSign(a, b) && z) { - _PyLong_Negate(&z); - if (z == NULL) + else { + PyObject *mod = PyImport_ImportModule("_pylong"); + if (mod == NULL) { return NULL; + } + PyObject *z = PyObject_CallMethod(mod, "k_mul", "OO", a, b); + Py_DECREF(mod); + if (z == NULL) { + return NULL; + } + if (!PyLong_CheckExact(z)) { + PyErr_SetString(PyExc_TypeError, + "_pylong.k_mul did not return an int"); + return NULL; + } + return z; } - return (PyObject *)z; } static PyObject *