1285 const size_t n_input = scalars_span.size();
1287 return Curve::Group::point_at_infinity;
1296 const size_t num_threads_dispatch =
std::max<size_t>(1, std::min(n_input, max_threads));
1297 const size_t pts_per_thread = (n_input + num_threads_dispatch - 1) / num_threads_dispatch;
1299 return trivial_msm_threaded<Curve>(scalars_span, all_points);
1303 BB_ASSERT_GTE(all_points.size(), scalars_span.start_index + n_input);
1304 std::span<const AffineElement> input_points(&all_points[scalars_span.start_index], n_input);
1306 constexpr size_t FULL_NUM_BITS = ScalarField::modulus.get_msb() + 1;
1309 ScalarField* scalar_ptr =
const_cast<ScalarField*
>(&scalars_span[scalars_span.start_index]);
1320 const bool external_glv_provided = !external_glv_doubled.empty();
1329 const size_t n = use_glv ? (2 * n_input) : n_input;
1330 const size_t NUM_BITS = use_glv ?
size_t{ 128 } : FULL_NUM_BITS;
1333 "working scalar indices must fit in the 29-bit schedule payload");
1335 std::span<const AffineElement> points;
1336 const bool inline_glv_double = use_glv && !external_glv_provided;
1343 const bool dedup_active = dedup_hint;
1357 const size_t arena_total_bytes = compute_arena_bytes_for_msm<Curve>(n_input, external_glv_provided, dedup_active);
1358 round_parallel_detail::MsmArena arena(arena_total_bytes, external_arena);
1372 using round_parallel_detail::MSB_ZERO_SENTINEL;
1374 auto msb_per_scalar = arena.template alloc<uint8_t>(n);
1375 auto per_thread_msb_hist = arena.template alloc<std::array<uint32_t, 256>>(profile_threads);
1386 glv_scalars_storage = arena.template alloc<ScalarField>(n);
1387 if (inline_glv_double) {
1388 glv_points_storage = arena.template alloc<AffineElement>(n);
1400 const BaseField beta = inline_glv_double ? BaseField::cube_root_of_unity() : BaseField{};
1402 auto& th_hist = per_thread_msb_hist[chunk.
thread_index];
1403 for (
size_t i : chunk.
range(n_input)) {
1404 const ScalarField canonical = input_scalars[i].from_montgomery_form_reduced();
1405 const auto split = ScalarField::split_into_endomorphism_scalars(canonical);
1406 const auto& k1 = split.first;
1407 const auto& k2 = split.second;
1408 glv_scalars_storage[2 * i].data[0] = k1[0];
1409 glv_scalars_storage[2 * i].data[1] = k1[1];
1410 glv_scalars_storage[2 * i].data[2] = 0;
1411 glv_scalars_storage[2 * i].data[3] = 0;
1412 glv_scalars_storage[(2 * i) + 1].
data[0] = k2[0];
1413 glv_scalars_storage[(2 * i) + 1].
data[1] = k2[1];
1414 glv_scalars_storage[(2 * i) + 1].
data[2] = 0;
1415 glv_scalars_storage[(2 * i) + 1].
data[3] = 0;
1416 if (inline_glv_double) {
1417 glv_points_storage[2 * i] = input_points[i];
1418 glv_points_storage[(2 * i) + 1].x = input_points[i].x * beta;
1419 glv_points_storage[(2 * i) + 1].y = -input_points[i].y;
1421 round_parallel_detail::record_msb(
1422 round_parallel_detail::msb_of_2limb(k1[0], k1[1]), msb_per_scalar[2 * i], th_hist);
1423 round_parallel_detail::record_msb(
1424 round_parallel_detail::msb_of_2limb(k2[0], k2[1]), msb_per_scalar[(2 * i) + 1], th_hist);
1428 inline_glv_double ? std::span<const AffineElement>(glv_points_storage.data(), n) : external_glv_doubled;
1429 scalars = glv_scalars_storage;
1433 auto& th_hist = per_thread_msb_hist[chunk.
thread_index];
1434 for (
size_t i : chunk.
range(n_input)) {
1435 input_scalars[i].self_from_montgomery_form_reduced();
1436 round_parallel_detail::record_msb(
1437 round_parallel_detail::msb_of_4limb(input_scalars[i].
data), msb_per_scalar[i], th_hist);
1440 scalars = input_scalars;
1441 points = input_points;
1445 for (
size_t t = 0; t < profile_threads; ++t) {
1446 for (
size_t b = 0;
b < 256; ++
b) {
1447 msb_hist[
b] += per_thread_msb_hist[t][
b];
1450 const size_t n_active_early = n -
static_cast<size_t>(msb_hist[0]);
1459 const size_t threads_for_dispatch =
std::max<size_t>(1, std::min(n_active_early, max_threads_dispatch));
1460 const size_t pts_per_thread = (n_active_early + threads_for_dispatch - 1) / threads_for_dispatch;
1463 for (
size_t i : chunk.
range(n)) {
1464 scalars[i].self_to_montgomery_form();
1467 std::span<const ScalarField> scalars_const(scalars.data(), n);
1469 return trivial_msm_threaded<Curve>(ps, points);
1480 size_t effective_num_bits = 0;
1481 for (
size_t bin = 256; bin > 1;) {
1483 if (msb_hist[bin] != 0) {
1484 effective_num_bits = bin;
1488 if (effective_num_bits == 0 || effective_num_bits > NUM_BITS) {
1489 effective_num_bits = NUM_BITS;
1491 const size_t window_bits =
1493 const size_t num_buckets = (
size_t{ 1 } << (window_bits - 1)) + 1;
1511 "window schedule exceeds compile-time max window count");
1523 const size_t max_threads_for_min_batch =
std::max<size_t>(1, n / MIN_BATCH_CAPACITY);
1524 const size_t num_threads = std::min(desired_threads, max_threads_for_min_batch);
1539 size_t B_eff = num_buckets;
1540 for (
size_t w = 0; w < sched.num_windows; ++w) {
1541 B_eff =
std::max(B_eff,
static_cast<size_t>(sched.num_buckets[w]));
1544 const size_t worker_total_for_budget = num_threads;
1546 const size_t bucket_partials_per_window_max =
1548 const size_t per_window_bytes_lo = round_parallel_detail::compute_per_window_bytes<Curve>(
1549 num_threads, B_eff, n, dense_stride_est, worker_total_for_budget);
1551 const size_t global_max_overflow_per_window_for_budget =
1554 const size_t phase_one_prologue_bytes =
1558 const size_t phase_a_cluster_members_cap = phase_a_caps.members_cap;
1559 const size_t phase_a_cluster_offsets_cap = phase_a_caps.offsets_cap;
1565 SUBCHUNK_ENTRIES_CAP,
1566 global_max_overflow_per_window_for_budget,
1568 phase_a_cluster_members_cap,
1569 phase_a_cluster_offsets_cap,
1574 const size_t fixed_overhead = (worker_union_bytes_for_budget * worker_total_for_budget) +
1576 + (
size_t{ 8 } * (num_threads + 1))
1577 + phase_one_prologue_bytes;
1580 const size_t available_budget =
1581 (BATCH_MEM_BUDGET > fixed_overhead) ? (BATCH_MEM_BUDGET - fixed_overhead) :
size_t{ 0 };
1582 const size_t windows_per_batch =
1593 const size_t global_max_chunk_len = (n + num_threads - 1) / num_threads;
1594 const size_t global_max_overflow_per_window =
1595 (global_max_chunk_len + SUBCHUNK_ENTRIES_CAP - 1) / SUBCHUNK_ENTRIES_CAP;
1596 const size_t chunk_capacity =
std::max(SUBCHUNK_ENTRIES_CAP, 2 * global_max_overflow_per_window);
1604 const size_t worker_total = num_threads;
1608 phase_a_scratch.resize(worker_total);
1651 const size_t bytes_P_prefix = arena.cursor;
// Round `off` up to the next multiple of `align`.
// The mask form `(off + align - 1) & ~(align - 1)` is only correct when `align`
// is a power of two; callers in this routine pass alignof(...) values or
// WORKER_SLAB_ALIGN.
1656 auto align_up = [](
size_t off,
size_t align) ->
size_t {
return (off + align - 1) & ~(align - 1); };
// Layout-only helper: reserve `bytes` at the first `align`-aligned offset at or
// after `off`, then advance `off` to the end of that reservation.
// Nothing is allocated here. It only accumulates a byte size (see the
// bytes_P_extra_layout uses below).
1657 auto layout_add = [&](
size_t& off,
size_t bytes,
size_t align) { off = align_up(off, align) + bytes; };
1663 global_max_overflow_per_window,
1665 phase_a_cluster_members_cap,
1666 phase_a_cluster_offsets_cap,
1676 size_t bytes_P_extra_layout = 0;
1677 layout_add(bytes_P_extra_layout,
sizeof(
Element) * VAR_WINDOW_WINDOW_SUMS_CAP,
alignof(
Element));
1679 layout_add(bytes_P_extra_layout,
sizeof(uint32_t) * n,
alignof(uint32_t));
1680 layout_add(bytes_P_extra_layout,
1682 alignof(AffineElement));
1690 const size_t arena_base_misalign =
static_cast<size_t>(arena.base_addr & (WORKER_SLAB_ALIGN - 1));
1691 const size_t bytes_P_min = align_up(bytes_P_prefix,
alignof(
Element)) + bytes_P_extra_layout;
1692 const size_t bytes_P = align_up(bytes_P_min + arena_base_misalign, WORKER_SLAB_ALIGN) - arena_base_misalign;
1695 const size_t bytes_W = per_worker_bytes * worker_total;
1701 const size_t bytes_S_total = arena.capacity - bytes_P - bytes_W;
1706 size_t zone_P_cursor = bytes_P_prefix;
1707 size_t zone_S_cursor = 0;
1708 auto zone_P_alloc = [&]<
typename T>(
size_t count) ->
std::span<T> {
1709 return arena.template bump_alloc<T>(count, zone_P_cursor, bytes_P, 0);
1711 auto zone_S_alloc = [&]<
typename T>(
size_t count) ->
std::span<T> {
1712 return arena.template bump_alloc<T>(count, zone_S_cursor, bytes_S_total, bytes_P + bytes_W);
1722 for (
size_t t = 0; t < worker_total; ++t) {
1724 const size_t slab_base = t * per_worker_bytes;
1725 auto& s = thread_scratch[t];
1728 size_t ts_fixed_cur = 0;
1729 auto ts_fixed_alloc = [&]<
typename T>(
size_t count) ->
std::span<T> {
1730 return arena.template bump_alloc<T>(count, ts_fixed_cur, per_worker_union_bytes, bytes_P + slab_base);
1732 s.curr_pts = ts_fixed_alloc.template operator()<AffineElement>(chunk_capacity);
1733 s.curr_buckets = ts_fixed_alloc.template operator()<uint32_t>(chunk_capacity);
1734 s.points_to_add = ts_fixed_alloc.template operator()<AffineElement>(2 * BATCH_CAPACITY);
1735 s.inversion_scratch = ts_fixed_alloc.template operator()<BaseField>(BATCH_CAPACITY);
1736 s.pair_dest = ts_fixed_alloc.template operator()<uint32_t>(BATCH_CAPACITY);
1737 s.overflow_slots = ts_fixed_alloc.template operator()<uint32_t>(global_max_overflow_per_window);
1738 s.overflow_pts = ts_fixed_alloc.template operator()<AffineElement>(global_max_overflow_per_window);
1745 auto pa_alloc = [&]<
typename T>(
size_t count) ->
std::span<T> {
1746 return arena.template bump_alloc<T>(count, pa_cur, per_worker_union_bytes, bytes_P + slab_base);
1748 auto& ps = phase_a_scratch[t];
1750 ps.cluster_members = pa_alloc.template operator()<uint32_t>(phase_a_cluster_members_cap);
1751 ps.cluster_offsets = pa_alloc.template operator()<uint32_t>(phase_a_cluster_offsets_cap);
1752 ps.dirty_slots = pa_alloc.template operator()<uint16_t>(PWAL::PHASE_A_DIRTY_SLOTS_CAP);
1753 ps.bucket_rep = pa_alloc.template operator()<uint32_t>(PWAL::PHASE_A_BUCKET_REP_CAP);
1755 ps.chunk_pts = pa_alloc.template operator()<AffineElement>(PWAL::PHASE_A_CHUNK_CAP);
1756 ps.chunk_ids = pa_alloc.template operator()<uint32_t>(PWAL::PHASE_A_CHUNK_CAP);
1762 size_t ts_tail_cur = per_worker_union_bytes;
1763 auto ts_tail_alloc = [&]<
typename T>(
size_t count) ->
std::span<T> {
1764 return arena.template bump_alloc<T>(count, ts_tail_cur, per_worker_bytes, bytes_P + slab_base);
1766 const size_t dense_total = windows_per_batch * dense_stride_est;
1767 const size_t dense_pair_max = dense_total / 2;
1768 s.dense_buckets = ts_tail_alloc.template operator()<AffineElement>(dense_total);
1769 s.is_present = ts_tail_alloc.template operator()<uint8_t>(dense_total);
1771 s.affine_bucket_indices = ts_tail_alloc.template operator()<uint32_t>(dense_pair_max);
1772 s.affine_bucket_inversion_scratch = ts_tail_alloc.template operator()<BaseField>(dense_pair_max);
1776 s.affine_bucket_stride = dense_stride_est;
1780 const size_t schedule_total = windows_per_batch * n;
1781 auto schedule = zone_S_alloc.template operator()<uint32_t>(schedule_total);
1808 static_assert(
alignof(
Element) <= 32,
"HIST slot O layout assumes alignof(Element) <= 32");
1810 "HIST slot O layout assumes alignof(ChunkOutput) <= 32");
// Round `off` up to the next multiple of `a`. This is the same formula as
// `align_up`. The mask form requires `a` to be a power of two; here it is only
// called with alignof(...) values.
1812 auto align_up_local = [](
size_t off,
size_t a) ->
size_t {
return (off +
a - 1) & ~(
a - 1); };
1815 const size_t hist_h_bytes_total = (
size_t{ 4 } * windows_per_batch * num_threads * B_eff);
1819 size_t o_layout_cur = 0;
1821 const size_t off_chunk_outputs = o_layout_cur;
1823 o_layout_cur = align_up_local(o_layout_cur,
alignof(
typename Curve::Element));
1824 const size_t off_window_partial_sums = o_layout_cur;
1825 o_layout_cur +=
sizeof(
typename Curve::Element) * num_threads * windows_per_batch;
1826 const size_t hist_o_bytes_total = o_layout_cur;
1828 const size_t hist_slot_bytes_total =
std::max(hist_h_bytes_total, hist_o_bytes_total);
1834 const size_t hist_slot_cells = (hist_slot_bytes_total +
sizeof(AffineElement) - 1) /
sizeof(AffineElement);
1835 auto hist_slot_cells_span = zone_S_alloc.template operator()<AffineElement>(hist_slot_cells);
1837 std::byte*
const hist_slot_bytes =
reinterpret_cast<std::byte*
>(hist_slot_cells_span.data());
1851 auto digit_cursors =
1852 std::span<uint32_t>{
reinterpret_cast<uint32_t*
>(hist_slot_bytes), windows_per_batch * num_threads * B_eff };
1864 windows_per_batch * num_threads
1868 reinterpret_cast<typename
Curve::Element*
>(hist_slot_bytes + off_window_partial_sums),
1869 num_threads * windows_per_batch
1887 static_assert(
alignof(AffineElement) == 64,
"DENSE slot D layout assumes alignof(AffineElement) == 64");
1888 const size_t bp_total = windows_per_batch * bucket_partials_per_window_max;
1889 size_t d_layout_cur = 0;
1890 const size_t off_dense = d_layout_cur;
1891 d_layout_cur +=
sizeof(AffineElement) * bp_total;
1892 const size_t off_present = d_layout_cur;
1893 d_layout_cur +=
sizeof(uint8_t) * bp_total;
1894 const size_t dense_slot_bytes_total = d_layout_cur;
1895 const size_t dense_slot_cells = (dense_slot_bytes_total +
sizeof(AffineElement) - 1) /
sizeof(AffineElement);
1898 auto dense_slot_cells_span = zone_S_alloc.template operator()<AffineElement>(dense_slot_cells);
1900 std::byte*
const dense_slot_bytes =
reinterpret_cast<std::byte*
>(dense_slot_cells_span.data());
1903 auto bucket_partials_dense =
1906 auto bucket_partials_present =
1907 std::span<uint8_t>{
reinterpret_cast<uint8_t*
>(dense_slot_bytes + off_present), bp_total };
1910 auto bucket_start_all = zone_S_alloc.template operator()<
size_t>(windows_per_batch * (B_eff + 1));
1911 auto chunk_start_all = zone_S_alloc.template operator()<
size_t>(windows_per_batch * (num_threads + 1));
1919 auto chunk_bucket_lo_all = zone_S_alloc.template operator()<
size_t>(windows_per_batch * (num_threads + 1));
1920 auto chunk_bucket_hi_all = zone_S_alloc.template operator()<
size_t>(windows_per_batch * num_threads);
1925 auto bucket_partials_offsets = zone_S_alloc.template operator()<
size_t>((num_threads * windows_per_batch) + 1);
1932 auto rebalanced_bucket_lo_partition = zone_S_alloc.template operator()<
size_t>(num_threads + 1);
1933 auto orig_thread_lo = zone_S_alloc.template operator()<
size_t>(windows_per_batch * num_threads);
1934 auto orig_thread_hi = zone_S_alloc.template operator()<
size_t>(windows_per_batch * num_threads);
1937 auto window_sums = zone_P_alloc.template operator()<
typename Curve::Element>(VAR_WINDOW_WINDOW_SUMS_CAP);
1938 std::fill_n(window_sums.begin(), VAR_WINDOW_WINDOW_SUMS_CAP, Curve::Group::point_at_infinity);
1946 dedup_state.
redirect_lookup = zone_P_alloc.template operator()<uint32_t>(n);
1952 for (
size_t i : chunk.
range(n)) {
1960 constexpr uint32_t BUCKET_MASK = (uint32_t{ 1 } << 31) - 1;
1968 bool phase_a_done =
false;
1970 auto run_batch = [&](
size_t batch_start,
size_t windows_in_batch,
size_t B_R)
noexcept {
1973 const size_t bucket_stride = B_eff;
1977 constexpr size_t SCALAR_UINT64_LIMBS =
sizeof(ScalarField) /
sizeof(uint64_t);
1985 constexpr size_t SCALAR_U32_LIMBS =
sizeof(ScalarField) /
sizeof(uint32_t);
1986 for (
size_t w = 0; w < windows_in_batch; ++w) {
1987 const size_t global_w = batch_start + w;
1988 const size_t window_bits_w = sched.window_bits_per_window[global_w];
1989 per_window_bits[w] =
static_cast<uint8_t
>(window_bits_w);
1991 sched.bit_base[global_w], window_bits_w, SCALAR_UINT64_LIMBS);
1993 sched.bit_base[global_w], window_bits_w, SCALAR_U32_LIMBS);
1995 const uint32_t lo_mask = slice_params_u32[w].lo_mask;
1996 const uint32_t hi_mask = slice_params_u32[w].hi_mask;
1997 const uint32_t val_mask = (uint32_t{ 1 } <<
static_cast<uint32_t
>(window_bits_w)) - 1;
2003 constexpr size_t SIMD_BATCH = 64;
2004 static_assert(SIMD_BATCH % 4 == 0,
"SIMD_BATCH must be divisible by 4");
2005 constexpr size_t LIMBS_PER_SCALAR =
sizeof(ScalarField) /
sizeof(uint32_t);
2006 const auto* scalars_u32 =
reinterpret_cast<const uint32_t*
>(scalars.data());
2008 auto fill_packed_digit_buffer = [&](
size_t w,
size_t i, uint32_t* packed_buf)
noexcept {
2009 const auto& sp32 = slice_params_u32[w];
2010 const uint32_t window_bits_w =
static_cast<uint32_t
>(per_window_bits[w]);
2012 for (
size_t k = 0; k < SIMD_BATCH; k += 4) {
2015 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2016 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2017 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2018 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2023 val_mask_vectors[w],
2027 for (
size_t k = 0; k < SIMD_BATCH; k += 4) {
2030 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2031 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2032 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2033 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2038 val_mask_vectors[w],
2042 for (
size_t k = 0; k < SIMD_BATCH; k += 4) {
2045 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2046 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2047 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2048 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2056 val_mask_vectors[w],
2065 const bool phase_a_done_at_batch_start = phase_a_done;
2066 const bool dedup_known_for_batch =
2067 dedup_active && phase_a_done_at_batch_start && dedup_state.
n_dedup_extras != 0;
2072 auto stage1_digit_extract = [&]<
bool DedupKnown>(
size_t tid)
noexcept {
2073 [[maybe_unused]]
const uint32_t*
const rl_data = dedup_state.
redirect_lookup.data();
2074 for (
size_t w = 0; w < windows_in_batch; ++w) {
2075 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2076 std::memset(my_counts, 0, B_R *
sizeof(uint32_t));
2078 const size_t start = tid * n / num_threads;
2079 const size_t end = (tid + 1) * n / num_threads;
2085 auto compute_include_mask = [&](
size_t block_start)
noexcept -> uint64_t {
2086 uint64_t include_mask = 0;
2087 for (
size_t k = 0; k < SIMD_BATCH; ++k) {
2088 const size_t scalar_idx = block_start + k;
2089 const uint8_t m = msb_per_scalar[scalar_idx];
2090 bool include = (m != MSB_ZERO_SENTINEL);
2091 if constexpr (DedupKnown) {
2093 const uint32_t patch = rl_data[scalar_idx];
2098 include_mask |=
static_cast<uint64_t
>(include) << k;
2100 return include_mask;
2104 while (i + SIMD_BATCH <= end) {
2105 const uint64_t include_mask = compute_include_mask(i);
2106 if (include_mask == 0) {
2110 const bool all_included = include_mask == ~uint64_t{ 0 };
2111 for (
size_t w = 0; w < windows_in_batch; ++w) {
2112 fill_packed_digit_buffer(w, i, packed_buf.data());
2113 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2115 for (
size_t k = 0; k < SIMD_BATCH; ++k) {
2116 ++my_counts[packed_buf[k] & BUCKET_MASK];
2119 uint64_t scatter_mask = include_mask;
2120 for (
size_t k = 0; k < SIMD_BATCH; ++k) {
2121 if ((scatter_mask & uint64_t{ 1 }) != 0) {
2122 ++my_counts[packed_buf[k] & BUCKET_MASK];
2133 for (; i < end; ++i) {
2134 const uint8_t m = msb_per_scalar[i];
2135 if (m == MSB_ZERO_SENTINEL) {
2138 if constexpr (DedupKnown) {
2139 const uint32_t patch = rl_data[i];
2145 for (
size_t w = 0; w < windows_in_batch; ++w) {
2146 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2148 const uint32_t window_bits_w =
static_cast<uint32_t
>(per_window_bits[w]);
2149 const uint32_t packed =
2159 ++my_counts[packed & BUCKET_MASK];
2163 if (dedup_known_for_batch) {
2164 bb::parallel_for(num_threads, [&](
size_t tid) { stage1_digit_extract.template operator()<
true>(tid); });
2166 bb::parallel_for(num_threads, [&](
size_t tid) { stage1_digit_extract.template operator()<
false>(tid); });
2180 const size_t d_start = tid * B_R / num_threads;
2181 const size_t d_end = (tid + 1) * B_R / num_threads;
2182 for (
size_t w = 0; w < windows_in_batch; ++w) {
2183 size_t*
const bucket_start_w = bucket_start_all.data() + (w * (bucket_stride + 1));
2184 for (
size_t d = d_start; d < d_end; ++d) {
2188 uint32_t running = 0;
2189 for (
size_t t = 0; t < num_threads; ++t) {
2190 const size_t k = (((w * num_threads) + t) * bucket_stride) + d;
2191 const uint32_t cnt = digit_cursors[k];
2192 digit_cursors[k] = running;
2195 bucket_start_w[d + 1] = running;
2205 auto build_bucket_offsets_for_window = [&](
size_t w)
noexcept {
2206 size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2207 bucket_start[0] = 0;
2208 bucket_start[1] = 0;
2209 for (
size_t d = 1; d < B_R; ++d) {
2210 bucket_start[d + 1] += bucket_start[d];
2213 const size_t offset_threads = std::min(num_threads, windows_in_batch);
2214 if (offset_threads <= 1) {
2215 for (
size_t w = 0; w < windows_in_batch; ++w) {
2216 build_bucket_offsets_for_window(w);
2220 for (
size_t w = tid; w < windows_in_batch; w += offset_threads) {
2221 build_bucket_offsets_for_window(w);
2244 auto stage4_emit = [&]<
bool DedupKnown>(
size_t tid)
noexcept {
2245 [[maybe_unused]]
const uint32_t*
const rl_data = dedup_state.
redirect_lookup.data();
2246 const size_t start = tid * n / num_threads;
2247 const size_t end = (tid + 1) * n / num_threads;
2251 for (
size_t w = 0; w < windows_in_batch; ++w) {
2252 cursors[w] = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2253 bucket_starts[w] = bucket_start_all.data() + (w * (bucket_stride + 1));
2254 schedules[w] = schedule.data() + (w * n);
2258 constexpr size_t STAGE4_SCALAR_TILE = 2048;
2262 for (
size_t tile_start = start; tile_start < end; tile_start += STAGE4_SCALAR_TILE) {
2263 const size_t tile_end = std::min(end, tile_start + STAGE4_SCALAR_TILE);
2264 const size_t tile_len = tile_end - tile_start;
2265 for (
size_t j = 0; j < tile_len; ++j) {
2266 const size_t scalar_idx = tile_start + j;
2267 const uint8_t m = msb_per_scalar[scalar_idx];
2268 bool include = (m != MSB_ZERO_SENTINEL);
2269 if constexpr (DedupKnown) {
2270 uint32_t out_base =
static_cast<uint32_t
>(scalar_idx);
2272 const uint32_t patch = rl_data[scalar_idx];
2278 out_base_tile[j] = out_base;
2280 active_tile[j] =
static_cast<uint8_t
>(include);
2283 for (
size_t w = 0; w < windows_in_batch; ++w) {
2284 uint32_t* my_cursor = cursors[w];
2285 const size_t* bucket_start = bucket_starts[w];
2286 uint32_t* sched_w = schedules[w];
2287 size_t i = tile_start;
2288 while (i + SIMD_BATCH <= tile_end) {
2289 const size_t rel = i - tile_start;
2290 uint64_t include_mask = 0;
2291 for (
size_t k = 0; k < SIMD_BATCH; ++k) {
2292 include_mask |=
static_cast<uint64_t
>(active_tile[rel + k]) << k;
2294 if (include_mask == 0) {
2298 fill_packed_digit_buffer(w, i, packed_buf.data());
2299 uint64_t scatter_mask = include_mask;
2300 for (
size_t k = 0; k < SIMD_BATCH; ++k) {
2301 if ((scatter_mask & uint64_t{ 1 }) != 0) {
2302 const uint32_t packed = packed_buf[k];
2303 const uint32_t bucket_idx = packed & BUCKET_MASK;
2304 if (bucket_idx != 0) {
2305 const uint32_t idx =
2306 static_cast<uint32_t
>(bucket_start[bucket_idx]) + my_cursor[bucket_idx]++;
2308 if constexpr (DedupKnown) {
2309 out |= out_base_tile[rel + k];
2311 out |=
static_cast<uint32_t
>(i + k);
2320 for (; i < tile_end; ++i) {
2321 const size_t rel = i - tile_start;
2322 if (active_tile[rel] == 0) {
2335 static_cast<uint32_t
>(per_window_bits[w]));
2336 const uint32_t bucket_idx = packed & BUCKET_MASK;
2337 if (bucket_idx != 0) {
2338 const uint32_t idx =
2339 static_cast<uint32_t
>(bucket_start[bucket_idx]) + my_cursor[bucket_idx]++;
2341 if constexpr (DedupKnown) {
2342 out |= out_base_tile[rel];
2344 out |=
static_cast<uint32_t
>(i);
2353 if (dedup_known_for_batch) {
2354 bb::parallel_for(num_threads, [&](
size_t tid) { stage4_emit.template operator()<
true>(tid); });
2356 bb::parallel_for(num_threads, [&](
size_t tid) { stage4_emit.template operator()<
false>(tid); });
2371 if (dedup_active && windows_in_batch > 0 && !phase_a_done) {
2373 uint32_t* sched_w0 = schedule.data();
2383 const uint32_t cids_per_thread =
2395 const size_t*
const w0_bucket_start = bucket_start_all.data();
2396 std::atomic<size_t> dedup_cluster_count{ 0 };
2398 const size_t b_lo = 1 + ((tid * (B_R - 1)) / num_threads);
2399 const size_t b_hi = 1 + (((tid + 1) * (B_R - 1)) / num_threads);
2400 const uint32_t cid_lo =
static_cast<uint32_t
>(tid) * cids_per_thread;
2401 const uint32_t cid_max = cid_lo + cids_per_thread;
2402 const size_t local_clusters = round_parallel_detail::dedup_phase_a_worker_hash<Curve>(
2407 std::span<const ScalarField>(scalars.data(), n),
2411 msb_per_scalar.data(),
2415 phase_a_scratch[tid]);
2416 if (local_clusters != 0) {
2422 phase_a_done =
true;
2432 auto partition_chunks_for_window = [&](
size_t w)
noexcept {
2433 const size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2434 const size_t*
const bucket_start_end = bucket_start + B_R + 1;
2435 size_t* chunk_start = chunk_start_all.data() + (w * (num_threads + 1));
2436 size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2437 size_t* chunk_bucket_hi = chunk_bucket_hi_all.data() + (w * num_threads);
2438 const size_t m = bucket_start[B_R];
2439 const size_t* search_begin = bucket_start + 1;
2441 chunk_start[0] = lo;
2442 for (
size_t t = 0; t < num_threads; ++t) {
2443 const size_t hi = ((t + 1) == num_threads) ? m : (((t + 1) * m) / num_threads);
2444 chunk_start[t + 1] = hi;
2446 const size_t*
const lo_it =
std::upper_bound(search_begin, bucket_start_end, lo);
2447 const size_t lo_bucket =
static_cast<size_t>(lo_it - bucket_start - 1);
2448 const size_t*
const hi_it =
std::upper_bound(lo_it, bucket_start_end, hi - 1);
2449 const size_t hi_bucket =
static_cast<size_t>(hi_it - bucket_start - 1);
2450 chunk_bucket_lo[t] = lo_bucket;
2451 chunk_bucket_hi[t] = hi_bucket;
2452 search_begin = hi_it;
2454 chunk_bucket_lo[t] = B_R;
2455 chunk_bucket_hi[t] = 0;
2459 chunk_bucket_lo[num_threads] = B_R;
2462 bool chunk_partition_done =
false;
2463 if (dedup_active && windows_in_batch > 0 && phase_a_done && !phase_a_done_at_batch_start) {
2466 const size_t bs_stride = bucket_stride + 1;
2467 const size_t br = B_R;
2468 const size_t cap_R = n;
2469 bb::parallel_for(num_threads, [&, rl_data, bs_stride, br, cap_R](
size_t tid)
noexcept {
2470 for (
size_t w = tid; w < windows_in_batch; w += num_threads) {
2471 uint32_t* sched_w = schedule.data() + (w * cap_R);
2472 size_t* bucket_start_w = bucket_start_all.data() + (w * bs_stride);
2473 round_parallel_detail::dedup_patch_schedule_window<Curve>(sched_w, bucket_start_w, br, rl_data);
2474 partition_chunks_for_window(w);
2477 chunk_partition_done =
true;
2486 if (!chunk_partition_done) {
2487 for (
size_t w = 0; w < windows_in_batch; ++w) {
2488 partition_chunks_for_window(w);
2504 auto next_pow2 = [](
size_t x) ->
size_t {
2515 size_t max_chunk_len = 0;
2516 for (
size_t t = 0; t < num_threads; ++t) {
2517 for (
size_t w = 0; w < windows_in_batch; ++w) {
2518 const size_t* chunk_start = chunk_start_all.data() + (w * (num_threads + 1));
2519 const size_t entries_in_chunk = chunk_start[t + 1] - chunk_start[t];
2520 if (entries_in_chunk == 0) {
2523 max_chunk_len =
std::max(max_chunk_len, entries_in_chunk);
2535 size_t global_stride = 0;
2540 const size_t active_digits = (B_R > 0) ? (B_R - 1) : 0;
2541 for (
size_t t = 0; t <= num_threads; ++t) {
2542 rebalanced_bucket_lo_partition[t] = 1 + (t * active_digits) / num_threads;
2544 rebalanced_bucket_lo_partition[num_threads] = B_R;
2545 size_t max_buckets_per_task = 0;
2546 for (
size_t t = 0; t + 1 <= num_threads; ++t) {
2547 const size_t hi_d = (t + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[t + 1] - 1);
2548 const size_t lo_d = rebalanced_bucket_lo_partition[t];
2550 max_buckets_per_task =
std::max(max_buckets_per_task, hi_d - lo_d + 1);
2553 global_stride = next_pow2(max_buckets_per_task);
2558 for (
size_t w = 0; w < windows_in_batch; ++w) {
2559 const size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2560 const size_t* chunk_bucket_hi = chunk_bucket_hi_all.data() + (w * num_threads);
2561 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2562 for (
size_t tprime = 0; tprime < num_threads; ++tprime) {
2563 const size_t lo_d = rebalanced_bucket_lo_partition[tprime];
2565 (tprime + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[tprime + 1] - 1);
2566 size_t lo_orig = num_threads;
2568 for (
size_t t = 0; t < num_threads; ++t) {
2569 const size_t entries = chunk_start_w[t + 1] - chunk_start_w[t];
2573 const size_t cl = chunk_bucket_lo[t];
2574 const size_t ch = chunk_bucket_hi[t];
2575 if (ch < lo_d || cl > hi_d) {
2578 if (lo_orig == num_threads) {
2583 orig_thread_lo[(w * num_threads) + tprime] = lo_orig;
2584 orig_thread_hi[(w * num_threads) + tprime] = hi_orig;
2592 size_t bucket_partials_cursor = 0;
2593 for (
size_t t = 0; t < num_threads; ++t) {
2594 for (
size_t w = 0; w < windows_in_batch; ++w) {
2595 bucket_partials_offsets[(t * windows_in_batch) + w] = bucket_partials_cursor;
2596 const size_t* chunk_bucket_lo_w = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2597 const size_t* chunk_bucket_hi_w = chunk_bucket_hi_all.data() + (w * num_threads);
2598 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2599 const size_t entries = chunk_start_w[t + 1] - chunk_start_w[t];
2601 bucket_partials_cursor += chunk_bucket_hi_w[t] - chunk_bucket_lo_w[t] + 1;
2605 bucket_partials_offsets[num_threads * windows_in_batch] = bucket_partials_cursor;
2606 const size_t bucket_partials_total = bucket_partials_cursor;
2607 BB_ASSERT_LTE(bucket_partials_total, bucket_partials_dense.size());
2608 std::memset(bucket_partials_present.data(), 0, bucket_partials_total);
2613 for (
size_t t = 0; t < worker_total; ++t) {
2614 thread_scratch[t].affine_bucket_stride = global_stride;
2623 auto bucket_partials_per_thread_lambda = [&](
size_t tid) {
2624 auto& s = thread_scratch[tid];
2625 for (
size_t w = 0; w < windows_in_batch; ++w) {
2626 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2627 const size_t cs_lo = chunk_start_w[tid];
2628 const size_t cs_hi = chunk_start_w[tid + 1];
2629 if (cs_lo == cs_hi) {
2632 const uint32_t* sched_w = schedule.data() + (w * n);
2633 const size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2634 AffineElement* dst_dense =
2635 bucket_partials_dense.data() + bucket_partials_offsets[(tid * windows_in_batch) + w];
2636 uint8_t* dst_present =
2637 bucket_partials_present.data() + bucket_partials_offsets[(tid * windows_in_batch) + w];
2638 const size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2639 const uint32_t my_lo =
static_cast<uint32_t
>(chunk_bucket_lo[tid]);
2640 const size_t my_hi = chunk_bucket_hi_all[(w * num_threads) + tid];
2641 size_t bucket_cursor = my_lo;
2643 for (
size_t pos = cs_lo; pos < cs_hi;) {
2644 const size_t end = std::min(pos + SUBCHUNK_ENTRIES_CAP, cs_hi);
2645 reduce_chunk<Curve>(s,
2653 std::span<const AffineElement>(dedup_state.
extra_points));
2654 const size_t len = s.result_len;
2655 for (
size_t k = 0; k <
len; ++k) {
2656 const uint32_t d = s.curr_buckets[k];
2657 const size_t slot = d - my_lo;
2658 if (dst_present[
slot]) {
2659 s.overflow_slots[s.overflow_len] =
static_cast<uint32_t
>(
slot);
2660 s.overflow_pts[s.overflow_len] = s.curr_pts[k];
2663 dst_dense[
slot] = s.curr_pts[k];
2664 dst_present[
slot] = 1;
2669 merge_overflow<Curve>(s, dst_dense);
2681 auto bucket_reduce_cross_thread_lambda = [&](
size_t tprime) {
2682 auto& s = thread_scratch[tprime];
2683 Element* my_partials = window_partial_sums.data() + (tprime * windows_per_batch);
2684 for (
size_t w = 0; w < windows_in_batch; ++w) {
2685 my_partials[w] = Curve::Group::point_at_infinity;
2688 const size_t stride = s.affine_bucket_stride;
2689 std::memset(s.is_present.data(), 0, windows_in_batch * stride);
2691 const size_t lo_d = rebalanced_bucket_lo_partition[tprime];
2693 (tprime + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[tprime + 1] - 1);
2694 const uint32_t lo_d_u =
static_cast<uint32_t
>(lo_d);
2695 const uint32_t hi_d_u =
static_cast<uint32_t
>(hi_d);
2697 bool any_nonempty =
false;
2698 for (
size_t w = 0; w < windows_in_batch; ++w) {
2699 auto&
info = s.chunk_infos[w];
2700 auto& out = chunk_outputs[(w * num_threads) + tprime];
2705 info.buckets_padded = 0;
2709 const size_t orig_lo = orig_thread_lo[(w * num_threads) + tprime];
2710 const size_t orig_hi = orig_thread_hi[(w * num_threads) + tprime];
2711 if (orig_lo == num_threads) {
2715 info.buckets_padded = 0;
2719 const size_t base = w * stride;
2720 bool has_data =
false;
2727 const size_t* chunk_bucket_lo_w = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2728 const size_t* chunk_bucket_hi_w = chunk_bucket_hi_all.data() + (w * num_threads);
2729 for (
size_t t = orig_lo; t <= orig_hi; ++t) {
2730 const size_t cl = chunk_bucket_lo_w[t];
2731 const size_t ch = chunk_bucket_hi_w[t];
2734 if (d_lo_clip > d_hi_clip) {
2737 const AffineElement* src_dense =
2738 bucket_partials_dense.data() + bucket_partials_offsets[(t * windows_in_batch) + w];
2739 const uint8_t* src_present =
2740 bucket_partials_present.data() + bucket_partials_offsets[(t * windows_in_batch) + w];
2741 for (
size_t d = d_lo_clip; d <= d_hi_clip; ++d) {
2742 const size_t src_slot = d - cl;
2743 if (src_present[src_slot] == 0) {
2746 const size_t dst_slot = base + (d - lo_d);
2747 if (s.is_present[dst_slot] == 0) {
2748 s.dense_buckets[dst_slot] = src_dense[src_slot];
2749 s.is_present[dst_slot] = 1;
2756 acc +=
Element(src_dense[src_slot]);
2757 s.dense_buckets[dst_slot] = AffineElement(acc);
2766 info.buckets_padded = 0;
2770 any_nonempty =
true;
2771 const size_t M = hi_d - lo_d + 1;
2772 const uint32_t buckets_padded =
2773 (M == 1) ? 1 : (uint32_t{ 1 } << (32 - __builtin_clz(
static_cast<uint32_t
>(M - 1))));
2777 info.buckets_padded = buckets_padded;
2783 if (!any_nonempty) {
2787 round_parallel_detail::recursive_affine_bucket_reduce_strided<Curve>(
2788 s, s.chunk_infos.data(), windows_in_batch, chunk_outputs.data() + tprime, num_threads);
2790 for (
size_t w = 0; w < windows_in_batch; ++w) {
2791 auto& out = chunk_outputs[(w * num_threads) + tprime];
2792 if (out.empty == 0) {
2793 my_partials[w] = round_parallel_detail::chunk_contribution<Curve>(out);
2806 const size_t reduce_threads = std::min(num_threads, windows_in_batch);
2808 const size_t lo = rid * windows_in_batch / reduce_threads;
2809 const size_t hi = (rid + 1) * windows_in_batch / reduce_threads;
2810 for (
size_t w = lo; w < hi; ++w) {
2811 Element sum = Curve::Group::point_at_infinity;
2812 for (
size_t tid = 0; tid < num_threads; ++tid) {
2813 sum += window_partial_sums[(tid * windows_per_batch) + w];
2815 window_sums[batch_start + w] =
sum;
2823 const size_t B_R = (
size_t{ 1 } << (window_bits - 1)) + 1;
2824 for (
size_t batch_start = 0; batch_start < sched.num_windows; batch_start += windows_per_batch) {
2825 const size_t windows_in_batch = std::min(windows_per_batch, sched.num_windows - batch_start);
2826 run_batch(batch_start, windows_in_batch, B_R);
2832 Element result = (sched.num_windows == 0) ? Curve::Group::point_at_infinity : window_sums[sched.num_windows - 1];
2833 for (
size_t w_rev = sched.num_windows - 1; w_rev > 0; --w_rev) {
2834 const size_t window_bits_w = sched.window_bits_per_window[w_rev - 1];
2835 for (
size_t d = 0; d < window_bits_w; ++d) {
2838 result += window_sums[w_rev - 1];
2845 for (
size_t i : chunk.
range(n_input)) {
2846 input_scalars[i].self_to_montgomery_form();