Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions calculator/calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ namespace
auto const &v = x.GetInteger();
int hi = (int)v.size() - 1;
while (hi > 0 && v[hi] == 0) --hi;
char buf[16];
std::snprintf(buf, sizeof(buf), "%x", (unsigned)v[hi]);
char buf[32];
std::snprintf(buf, sizeof(buf), "%llx", (unsigned long long)v[hi]);
out += buf;

const char *fmt = (BigInteger::Base() == Base2_64) ? "%016llx" : "%08llx";
for (int i = hi - 1; i >= 0; --i)
{
std::snprintf(buf, sizeof(buf), "%08x", (unsigned)v[i]);
std::snprintf(buf, sizeof(buf), fmt, (unsigned long long)v[i]);
out += buf;
}
return out;
Expand All @@ -77,19 +79,21 @@ namespace
int hi = (int)v.size() - 1;
while (hi > 0 && v[hi] == 0) --hi;

auto pushBits = [&](uint32_t w, int bits) {
auto pushBits = [&](ULong w, int bits) {
for (int i = bits - 1; i >= 0; --i)
out.push_back(((w >> i) & 1u) ? '1' : '0');
out.push_back(((w >> i) & 1ULL) ? '1' : '0');
};

const int limbBits = (BigInteger::Base() == Base2_64) ? 64 : 32;

// Top limb: strip leading zero bits (but keep at least one).
uint32_t top = (uint32_t)v[hi];
int topBits = 32;
while (topBits > 1 && ((top >> (topBits - 1)) & 1u) == 0)
ULong top = v[hi];
int topBits = limbBits;
while (topBits > 1 && ((top >> (topBits - 1)) & 1ULL) == 0)
--topBits;
pushBits(top, topBits);
for (int i = hi - 1; i >= 0; --i)
pushBits((uint32_t)v[i], 32);
pushBits(v[i], limbBits);
return out;
}

Expand Down
104 changes: 103 additions & 1 deletion include/biginteger/BigInteger.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#ifndef BIGINTEGER
#define BIGINTEGER

#include <algorithm>
#include <span>
#include <vector>

#include "common/Util.h"
Expand All @@ -23,7 +25,7 @@ namespace BigMath
// True if the number is negative
bool isNegative;

// Constructor, desctructor, and assignment operator
// Constructor, destructor, and assignment operator
public:
explicit BigInteger(SizeT size = 0, bool negative = false) : theInteger(size == 0 ? 1 : size, 0), isNegative(negative)
{
Expand All @@ -40,6 +42,19 @@ namespace BigMath
}
}

// Zero-copy adoption of a raw limb vector (little-endian limb order).
// Precondition: every limb must be canonical for the current base —
// < 2^32 when built with BIGMATH_LIMB_64=0. No validation is performed;
// non-canonical limbs silently corrupt downstream arithmetic.
BigInteger(std::vector<DataT>&& aInt, bool negative) : theInteger(std::move(aInt)), isNegative(negative)
{
TrimZerosToOne(theInteger);
if (isNegative && Zero())
{
isNegative = false;
}
}

// Filled with specified data
BigInteger(SizeT size, bool negative, DataT fill) : theInteger(size), isNegative(negative)
{
Expand All @@ -65,6 +80,93 @@ namespace BigMath
return theInteger;
}

std::vector<DataT> ReleaseInteger()
{
std::vector<DataT> r = std::move(theInteger);
theInteger = {0};
isNegative = false;
return r;
}

// Byte order for magnitude serialization. Named enum instead of a bool so
// call sites read unambiguously next to FromByteArray's `negative` flag.
enum class ByteOrder
{
BigEndian,
LittleEndian
};

std::vector<uint8_t> ToByteArray(ByteOrder order = ByteOrder::BigEndian) const
{
if (Zero())
{
return {};
}

constexpr SizeT limbBytes = LimbBits / 8;
std::vector<uint8_t> bytes;
bytes.reserve(theInteger.size() * limbBytes);

for (DataT val : theInteger)
{
for (SizeT b = 0; b < limbBytes; ++b)
{
bytes.push_back(static_cast<uint8_t>((val >> (b * 8)) & 0xFF));
}
}

while (!bytes.empty() && bytes.back() == 0)
{
bytes.pop_back();
}

if (order == ByteOrder::BigEndian)
{
std::reverse(bytes.begin(), bytes.end());
}

return bytes;
}

static BigInteger FromByteArray(std::span<const uint8_t> bytes, bool negative, ByteOrder order = ByteOrder::BigEndian)
{
if (bytes.empty())
{
return BigInteger();
}

constexpr SizeT limbBytes = LimbBits / 8;
const size_t numLimbs = (bytes.size() + limbBytes - 1) / limbBytes;
std::vector<DataT> limbs(numLimbs, 0);

auto getByte = [&](size_t idx) -> uint8_t {
if (order == ByteOrder::BigEndian)
{
return bytes[bytes.size() - 1 - idx];
}
else
{
return bytes[idx];
}
};

for (size_t i = 0; i < numLimbs; ++i)
{
DataT limbVal = 0;
for (size_t b = 0; b < limbBytes; ++b)
{
size_t byteIdx = i * limbBytes + b;
if (byteIdx < bytes.size())
{
limbVal |= (static_cast<DataT>(getByte(byteIdx)) << (b * 8));
}
}
limbs[i] = limbVal;
}

return BigInteger(std::move(limbs), negative);
}

DataT operator[](const SizeT i) const
{
return theInteger[i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ namespace BigMath

ULong carry = 0;

for (Int j = aStart; j <= aEnd; j++)
for (SizeT j = aStart; j <= aEnd; j++)
{
ULong multiply = a[j];
multiply *= b;
Expand All @@ -127,7 +127,7 @@ namespace BigMath
carry = NextCarry(multiply, base);
}

Int j = aEnd + 1;
SizeT j = aEnd + 1;
while (carry > 0)
{
SetOrPush(a, j, LowDigit(carry, base));
Expand Down Expand Up @@ -240,7 +240,7 @@ namespace BigMath

ULong carry = 0;

for (Int j = 0; j < len; j++)
for (SizeT j = 0; j < len; j++)
{
ULong multiply = 0;
SizeT aPos = aStart + j;
Expand Down Expand Up @@ -341,13 +341,13 @@ namespace BigMath
SizeT jStart = rStart + (i - bStart);
for (SizeT j = aStart; j <= aEnd; j++)
{
SizeT k = jStart + (j - aStart);
SizeT kk = jStart + (j - aStart);
ULong multiply = a[j];
multiply *= b[i];
multiply += result[k];
multiply += result[kk];
multiply += carry;

result[k] = LowDigit(multiply, base);
result[kk] = LowDigit(multiply, base);
carry = NextCarry(multiply, base);
}
k = jStart + lenA;
Expand Down
3 changes: 1 addition & 2 deletions include/biginteger/algorithms/multiplication/NTTCore.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,7 @@ namespace BigMath
Int stride = n / outerLen;
Int stride4 = stride << 2;
Int numBlocks = n / outerLen;
Int omega4_off = n / 4;
auto body = [a, halflen, qlen4, qlen8, outerLen, stride, stride4, omega4_off, roots](Int bStart, Int bEnd) {
auto body = [a, halflen, qlen4, qlen8, outerLen, stride, stride4, roots](Int bStart, Int bEnd) {
for (Int b = bStart; b < bEnd; ++b)
{
Int i = b * outerLen;
Expand Down
2 changes: 1 addition & 1 deletion include/biginteger/common/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ namespace BigMath
return std::vector<DataT>{0};
}

inline SizeT FindNonZeroByte(std::vector<DataT> const &a, Int start = 0, Int end = -1)
inline Int FindNonZeroByte(std::vector<DataT> const &a, Int start = 0, Int end = -1)
{
Int i = (end == -1 ? (Int)a.size() : end + 1);
while (i > start && a[i - 1] == 0)
Expand Down
41 changes: 22 additions & 19 deletions src/algorithms/Addition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,30 +95,33 @@ namespace BigMath
aEnd = std::min(aEnd, (SizeT)(a.size() - 1));
bEnd = std::min(bEnd, (SizeT)(b.size() - 1));

Int size = std::max(Len(aStart, aEnd), Len(bStart, bEnd));
Int len = std::max(Len(aStart, aEnd), Len(bStart, bEnd));
if (len <= 0)
return;
SizeT size = (SizeT)len;

if (base == Base2_64)
{
ULong128 carry = 0;
for (Int i = 0; i < size; i++)
for (SizeT i = 0; i < size; i++)
{
ULong128 digitOps = carry;
Int aPos = i + aStart;
if (aPos <= aEnd && aPos < (Int)a.size())
SizeT aPos = i + aStart;
if (aPos <= aEnd && aPos < a.size())
digitOps += a[aPos];
Int bPos = i + bStart;
if (bPos <= bEnd && bPos < (Int)b.size())
SizeT bPos = i + bStart;
if (bPos <= bEnd && bPos < b.size())
digitOps += b[bPos];
Int rPos = rStart + i;
if (rPos < (Int)result.size())
SizeT rPos = rStart + i;
if (rPos < result.size())
result[rPos] = (DataT)(digitOps & 0xFFFFFFFFFFFFFFFFULL);
carry = digitOps >> 64;
}
// Propagate the final carry: a bare += can itself overflow the slot.
// (If the result window ends here the carry is dropped — this is a
// fixed-width window primitive; whole-vector AddTo grows `a` first.)
Int rPos = rStart + size;
while (carry > 0 && rPos < (Int)result.size())
SizeT rPos = rStart + size;
while (carry > 0 && rPos < result.size())
{
carry += result[rPos];
result[rPos] = (DataT)(carry & 0xFFFFFFFFFFFFFFFFULL);
Expand All @@ -129,28 +132,28 @@ namespace BigMath
}

Long carry = 0;
for (Int i = 0; i < size; i++)
for (SizeT i = 0; i < size; i++)
{
Long digitOps = 0;

Int aPos = i + aStart;
if (aPos <= aEnd && aPos < (Int)a.size())
SizeT aPos = i + aStart;
if (aPos <= aEnd && aPos < a.size())
digitOps = a[aPos];

digitOps += carry;

Int bPos = i + bStart;
if (bPos <= bEnd && bPos < (Int)b.size())
SizeT bPos = i + bStart;
if (bPos <= bEnd && bPos < b.size())
digitOps += b[bPos];

Int rPos = rStart + i;
if (rPos < (Int)result.size())
SizeT rPos = rStart + i;
if (rPos < result.size())
result[rPos] = (DataT)(digitOps % base);
carry = digitOps / base;
}
// Propagate the final carry (see the Base2_64 branch above).
Int rPos = rStart + size;
while (carry > 0 && rPos < (Int)result.size())
SizeT rPos = rStart + size;
while (carry > 0 && rPos < result.size())
{
Long digitOps = (Long)result[rPos] + carry;
result[rPos] = (DataT)(digitOps % base);
Expand Down
Loading
Loading