Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
scalar_multiplication_fast.cpp
Go to the documentation of this file.
2
13
#include <algorithm>
#include <atomic>
#include <bit>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <span>
#include <type_traits>
#include <vector>
23
24#ifdef __wasm_simd128__
25#include <wasm_simd128.h>
26#endif
27
29
31{
32#ifdef __wasm__
33 if (n_input <= (size_t{ 1 } << 11)) {
34 return 1;
35 }
36 if (n_input <= (size_t{ 1 } << 15)) {
37 return 2;
38 }
39 return 4;
40#else
41 static_cast<void>(n_input);
42 return 4;
43#endif
44}
45
46namespace round_parallel_detail {
47
48// Anonymous namespace gives all TU-private helpers in `round_parallel_detail` internal
49// linkage (clang-tidy `misc-use-anonymous-namespace`). It is briefly closed and reopened
50// around `pippenger_round_parallel_jacobian_fast`, which has external linkage via
51// `extern template` declarations in the header.
52namespace {
53
/**
 * @brief Bulk-copy one 64-byte affine point (BN254 / Grumpkin layout: 8 × uint64_t) from `src` to `dst`.
 *
 * @details On wasm, V8 TurboFan compiles the default struct copy to 8 i64 loads/stores; explicit
 * v128 loads/stores halve that and roughly double throughput on random-gather access.
 * On native, std::memcpy of a constant-size struct already lowers to 4 × movdqu.
 *
 * All four 16-byte lanes are loaded before any is stored, so the result is well defined even if
 * `dst` and `src` are the same object on the SIMD path.
 *
 * @tparam AffineElement Point type; must be exactly 64 bytes and trivially copyable (both enforced
 *         at compile time).
 * @param dst Destination point, fully overwritten.
 * @param src Source point, read only.
 */
template <typename AffineElement>
[[gnu::always_inline]] inline void copy_affine64(AffineElement& dst, const AffineElement& src) noexcept
{
    static_assert(sizeof(AffineElement) == 64, "copy_affine64 requires 64-byte affine point");
    static_assert(std::is_trivially_copyable_v<AffineElement>,
                  "AffineElement must be trivially copyable for memcpy / SIMD bulk copy "
                  "(also required by the bulk std::memcpy of reduce_chunk output into "
                  "ThreadScratch::window_pts in recursive_affine_bucket_reduce_strided's caller)");
#ifdef __wasm_simd128__
    const auto* s = reinterpret_cast<const v128_t*>(&src);
    auto* d = reinterpret_cast<v128_t*>(&dst);
    const v128_t a = wasm_v128_load(s + 0);
    const v128_t b = wasm_v128_load(s + 1);
    const v128_t c = wasm_v128_load(s + 2);
    const v128_t e = wasm_v128_load(s + 3);
    wasm_v128_store(d + 0, a);
    wasm_v128_store(d + 1, b);
    wasm_v128_store(d + 2, c);
    wasm_v128_store(d + 3, e);
#else
    std::memcpy(&dst, &src, sizeof(AffineElement));
#endif
}
81
82// Constantine signed-Booth window recoder (scalar + SIMD x4 paths) lives in
83// pippenger_constantine.hpp.
84
85// `choose_window_bits` and `build_var_window_schedule` are defined inline in
86// `pippenger_arena_layout.hpp` so the test suite can build identical schedules.
87// `VAR_WINDOW_MAX_WINDOWS` and `VariableWindowSchedule` likewise live there.
88
// Value stored in `msb_per_scalar[i]` when scalar i is zero, i.e. it has no most-significant bit.
// A uint8_t holds every valid msb position (0..253) as well as this sentinel. The matching
// `msb_hist` histogram reserves bin 0 for the zero count, so callers index it with `msb + 1`
// (an msb of -1, meaning a zero scalar, lands in bin 0).
inline constexpr uint8_t MSB_ZERO_SENTINEL = 255;
93
94// Batched-affine drain trigger. `tree_reduce_in_place` accumulates same-bucket pair
95// candidates into the per-thread `points_to_add` / `pair_dest` scratch and drains via a
96// single inversion + N-pair add when the queue hits this size. Sizing trade-off:
97// - higher = larger inversion amortisation = lower per-pair cost,
98// - lower = smaller scratch / less L1 pressure but more drain calls.
99// 256 was chosen empirically: keeps `points_to_add` (256 × 64 B = 16 KB) inside L1, is
100// well above the ~32-pair amortisation breakeven, and is the value the per-OS-thread
101// scratch buffers (`points_to_add`, `inversion_scratch`, `pair_dest`) are sized for.
102//
103// Deliberately a compile-time constant rather than a per-call parameter: the only sites
104// that ever passed a different value were chunks shorter than 256, where the early-drain
105// branch never fires anyway (the end-of-loop drain catches the residue). Keeping it
106// constexpr lets the compiler turn the per-iter `if (pair_count >= BATCH_CAPACITY)` into
107// a compare-against-immediate and fold the drain-trigger condition into the loop shape.
108// `BATCH_CAPACITY` is defined in `pippenger_arena_layout.hpp` so the layout struct can
109// reference it without depending on this TU.
110
/**
 * @brief Position of the most significant set bit of the 128-bit value `hi:lo`.
 *
 * @param lo Low 64 bits.
 * @param hi High 64 bits.
 * @return Bit index in [0, 127], or -1 when both limbs are zero.
 */
inline int msb_of_2limb(uint64_t lo, uint64_t hi) noexcept
{
    // __builtin_clzll is undefined for 0, so every call is guarded by a non-zero check.
    if (hi != 0) {
        return 64 + 63 - __builtin_clzll(hi);
    }
    if (lo != 0) {
        return 63 - __builtin_clzll(lo);
    }
    return -1;
}
121
/**
 * @brief Position of the most significant set bit of a 256-bit little-endian limb array.
 *
 * Accepts the raw `uint64_t[4]` `.data` of `uint256_t` / field elements directly.
 *
 * @param d Four 64-bit limbs, least significant limb first.
 * @return Bit index in [0, 255], or -1 when all limbs are zero.
 */
inline int msb_of_4limb(const uint64_t (&d)[4]) noexcept // NOLINT(cppcoreguidelines-avoid-c-arrays)
{
    // Scan from the most significant limb down; the first non-zero limb decides the answer.
    // __builtin_clzll is undefined for 0, hence the non-zero guard.
    for (int limb = 3; limb >= 0; --limb) {
        if (d[limb] != 0) {
            return (limb * 64) + 63 - __builtin_clzll(d[limb]);
        }
    }
    return -1;
}
139
140inline void record_msb(int msb, uint8_t& dst, std::array<uint32_t, 256>& th_hist) noexcept
141{
142 dst = (msb < 0) ? MSB_ZERO_SENTINEL : static_cast<uint8_t>(msb);
143 ++th_hist[static_cast<size_t>(msb) + 1];
144}
145
149// `AffineBucketChunkInfo` is defined in `pippenger_arena_layout.hpp` (included above).
150
// Scratch state passed by reference (`ThreadScratch<Curve>& s`) to drain_batch, tree_reduce_in_place,
// merge_overflow, reduce_chunk and recursive_affine_bucket_reduce_strided in this file.
// NOTE(review): this listing is missing several member declarations (the source numbering skips
// 167, 171-172, 182, 193-195, 197-198 and 201). Other blocks in this file read or write
// `curr_pts`, `points_to_add`, `inversion_scratch`, `overflow_pts`, `dense_buckets`, `is_present`,
// `affine_bucket_pairs`, `affine_bucket_inversion_scratch` and `affine_bucket_stride`, none of
// which are declared in the visible lines. Recover them from the real file before editing.
159template <typename Curve> struct ThreadScratch {
160 using AffineElement = typename Curve::AffineElement;
161 using Element = typename Curve::Element;
162 using BaseField = typename Curve::BaseField;
163
164 // reduce_chunk's tree-reduce buffer. Per level the inner loop walks with a read cursor
165 // `i` and a write cursor `next_len ≤ i`, compacting in-place; the next level re-enters
166 // the same buffer without a swap.
// Bucket id of each live entry; tree_reduce_in_place compares neighbouring ids to find pairs.
168 std::span<uint32_t> curr_buckets;
169
170 // reduce_chunk's batch-affine scratch.
// For each queued pair, the compacted index its sum is written back to (see drain_batch).
173 std::span<uint32_t> pair_dest;
174
// Number of live entries left by the most recent tree_reduce_in_place / reduce_chunk call.
175 size_t result_len = 0;
176
177 // Stage 6a seam-overflow buffer: when a sub-chunk emits a partial for a slot whose
178 // dense bucket entry is already populated (i.e. the digit's run was split across two
179 // sub-chunks), the partial is deferred here and merged at end-of-window via a single
180 // Montgomery-batched tree reduce. Reset to length 0 between windows.
181 std::span<uint32_t> overflow_slots;
// Count of deferred partials; merge_overflow consumes them and resets this to 0.
183 size_t overflow_len = 0;
184
185 // Recursive affine bucket reduction scratch (cross-window batched, sparse-aware).
186 // `dense_buckets` holds W chunks worth of dense AffineElement arrays back-to-back.
187 // Layout: dense_buckets[w * affine_bucket_stride + i] for window w and 0-indexed slot i.
188 // `is_present` is a parallel uint8_t array marking non-identity slots (0 = empty, 1 = present).
189 // `affine_bucket_pairs` is the scratch buffer for the real-pairs list (single pass: filtered
190 // inline as candidates are generated, no intermediate candidate buffer).
191 // `affine_bucket_indices` is the scratch index buffer for the doubling kernel.
192 // `affine_bucket_inversion_scratch` is reused for the indexed batch-affine kernels.
196 std::span<uint32_t> affine_bucket_indices;
199 // Per-window metadata consumed by recursive_affine_bucket_reduce_strided (lo, hi, buckets_padded,
200 // empty per window). Filled in the lambda before the call.
202};
203
// Bump allocator over a single byte buffer. The buffer is either caller-provided
// (`external_arena`) or held by `local_owner`. Allocation only moves a cursor forward;
// nothing in the visible code frees or rewinds individual allocations.
204struct MsmArena {
205 std::unique_ptr<std::byte[]> local_owner; // NOLINT(cppcoreguidelines-avoid-c-arrays)
206 std::byte* data = nullptr;
207 uintptr_t base_addr = 0;
208 size_t capacity = 0;
209 size_t cursor = 0;
210
// Uses `external_arena` when it is non-empty and holds at least `required_bytes` (capacity is
// then the full external size). Otherwise falls back to `local_owner` with capacity
// `required_bytes`; an external arena that is too small is ignored rather than reported.
211 MsmArena(size_t required_bytes, std::span<std::byte> external_arena)
212 {
213 if (!external_arena.empty() && required_bytes <= external_arena.size()) {
214 data = external_arena.data();
215 capacity = external_arena.size();
216 } else {
217 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
// NOTE(review): source line 218 is missing from this listing. It presumably allocates
// `local_owner` with `required_bytes` bytes; confirm against the real file.
219 data = local_owner.get();
220 capacity = required_bytes;
221 }
222 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
223 base_addr = reinterpret_cast<uintptr_t>(data);
224 }
225
// Allocates `count` objects of T from the arena-wide cursor, bounded by the full capacity.
226 template <typename T> std::span<T> alloc(size_t count) { return bump_alloc<T>(count, cursor, capacity, 0); }
227
// Carves `count` objects of T out of the region that starts `base_offset` bytes into the arena.
// `local_cursor` is the byte offset already used within that region and is advanced past the new
// allocation; `bound` is the region's size in bytes. The start address is rounded up to
// alignof(T) using the real address, so alignment holds regardless of the buffer's own alignment.
// The returned span points at raw arena bytes: no T is constructed here.
// NOTE(review): `count * sizeof(T)` is not checked for overflow before the bound assertion.
228 template <typename T> std::span<T> bump_alloc(size_t count, size_t& local_cursor, size_t bound, size_t base_offset)
229 {
230 const size_t align = alignof(T);
231 const uintptr_t cur_addr = base_addr + base_offset + local_cursor;
232 const uintptr_t aligned_addr = (cur_addr + align - 1) & ~(uintptr_t{ align } - 1);
233 const size_t aligned_local = static_cast<size_t>(aligned_addr - (base_addr + base_offset));
234 const size_t bytes = count * sizeof(T);
235 BB_ASSERT_LTE(aligned_local + bytes, bound);
236 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
237 T* p = reinterpret_cast<T*>(data + base_offset + aligned_local);
238 local_cursor = aligned_local + bytes;
239 return std::span<T>{ p, count };
240 }
241};
242
243template <typename Curve> inline void drain_batch(ThreadScratch<Curve>& s, size_t pair_count) noexcept
244{
245 if (pair_count == 0) {
246 return;
247 }
248 bb::group_elements::batch_affine_add_interleaved<typename Curve::AffineElement, typename Curve::BaseField>(
249 s.points_to_add.data(), 2 * pair_count, s.inversion_scratch.data());
250 // In-place compaction: each `pair_dest[i]` is the `next_len` value at the moment the
251 // pair was queued, which is < the read cursor `i_outer` and < the current `next_len`
252 // — so writing back into curr_pts at `pair_dest[i]` lands on a slot that is already
253 // past the read cursor. See reduce_chunk for the full invariant.
254 for (size_t i = 0; i < pair_count; ++i) {
255 s.curr_pts[s.pair_dest[i]] = s.points_to_add[pair_count + i];
256 }
257}
258
274template <typename Curve> void tree_reduce_in_place(ThreadScratch<Curve>& s, size_t initial_len) noexcept
275{
276 size_t curr_len = initial_len;
277
278 while (true) {
279 size_t i = 0;
280 size_t next_len = 0;
281 size_t pair_count = 0;
282 bool made_pair = false;
283
284 while (i < curr_len) {
285 if (i + 1 < curr_len && s.curr_buckets[i] == s.curr_buckets[i + 1]) {
286 const size_t slot = 2 * pair_count;
287 s.points_to_add[slot] = s.curr_pts[i];
288 s.points_to_add[slot + 1] = s.curr_pts[i + 1];
289 s.curr_buckets[next_len] = s.curr_buckets[i];
290 s.pair_dest[pair_count] = static_cast<uint32_t>(next_len);
291 ++next_len;
292 ++pair_count;
293 i += 2;
294 made_pair = true;
295
296 if (pair_count >= BATCH_CAPACITY) {
297 drain_batch<Curve>(s, pair_count);
298 pair_count = 0;
299 }
300 } else {
301 s.curr_pts[next_len] = s.curr_pts[i];
302 s.curr_buckets[next_len] = s.curr_buckets[i];
303 ++next_len;
304 ++i;
305 }
306 }
307
308 drain_batch<Curve>(s, pair_count);
309
310 if (!made_pair) {
311 break;
312 }
313
314 curr_len = next_len;
315 }
316
317 s.result_len = curr_len;
318}
319
336template <typename Curve>
337void merge_overflow(ThreadScratch<Curve>& s, typename Curve::AffineElement* dst_dense) noexcept
338{
339 if (s.overflow_len == 0) {
340 return;
341 }
342
343 size_t merge_len = 0;
344 size_t i = 0;
345 while (i < s.overflow_len) {
346 const uint32_t slot = s.overflow_slots[i];
347 s.curr_buckets[merge_len] = slot;
348 s.curr_pts[merge_len] = dst_dense[slot];
349 ++merge_len;
350 while (i < s.overflow_len && s.overflow_slots[i] == slot) {
351 s.curr_buckets[merge_len] = slot;
352 s.curr_pts[merge_len] = s.overflow_pts[i];
353 ++merge_len;
354 ++i;
355 }
356 }
357
358 tree_reduce_in_place<Curve>(s, merge_len);
359
360 for (size_t k = 0; k < s.result_len; ++k) {
361 dst_dense[s.curr_buckets[k]] = s.curr_pts[k];
362 }
363
364 s.overflow_len = 0;
365}
366
// Loads the schedule entries in [chunk_lo, chunk_hi) into s.curr_pts / s.curr_buckets, tagging
// each with the bucket whose [bucket_start[b], bucket_start[b + 1]) range contains it, then calls
// tree_reduce_in_place. The surviving entry count ends up in s.result_len (0 for an empty chunk).
//   schedule            - packed entries: index bits (SCHEDULE_INDEX_MASK) plus the
//                         SCHEDULE_SIGN_BIT, DEDUP_SKIP_BIT and DEDUP_REDIRECT_BIT flags.
//   bucket_start        - start position of each bucket; bucket_start[bucket + 1] is read, so
//                         one entry past the last bucket visited must exist.
//   bucket_cursor       - in: first bucket to consider; out: the bucket the walk stopped on.
//   chunk_bucket_hi     - last bucket index (inclusive) the walk may visit.
//   dedup_extra_points  - source of points for entries with DEDUP_REDIRECT_BIT set.
371template <typename Curve>
372void reduce_chunk(ThreadScratch<Curve>& s,
373 const uint32_t* schedule,
374 const size_t* bucket_start,
375 size_t chunk_lo,
376 size_t chunk_hi,
377 size_t& bucket_cursor,
378 size_t chunk_bucket_hi,
// NOTE(review): source line 379 is missing from this listing. It presumably declares the
// `points` parameter that the body indexes as `points[raw_idx]`; recover it from the real file.
380 std::span<const typename Curve::AffineElement> dedup_extra_points = {}) noexcept
381{
382 const size_t chunk_len = chunk_hi - chunk_lo;
383 if (chunk_len == 0) {
384 s.result_len = 0;
385 return;
386 }
387
388 BB_ASSERT_LTE(chunk_len, s.curr_pts.size());
389 static_assert(BATCH_CAPACITY <= 4096, "BATCH_CAPACITY must fit in pair_dest scratch");
390
391 // Compact entries while loading: dedup non-rep entries (DEDUP_SKIP_BIT set in the
392 // schedule entry) carry no contribution — their points are already accumulated
393 // into the cluster's combined `extra_points[cid]` emitted at the rep's slot. Skip
394 // them to avoid double-counting and to shrink the tree-reduce input.
395 size_t valid_len = 0;
396 size_t bucket = bucket_cursor;
397 size_t pos = chunk_lo;
398 while (bucket <= chunk_bucket_hi && pos < chunk_hi) {
// Intersect this bucket's range with the unread part of the chunk; an empty intersection
// means the bucket contributes nothing here, so move on to the next one.
399 const size_t run_lo = std::max(pos, bucket_start[bucket]);
400 const size_t run_hi = std::min(chunk_hi, bucket_start[bucket + 1]);
401 if (run_lo >= run_hi) {
402 ++bucket;
403 continue;
404 }
405
406 const uint32_t bucket_u32 = static_cast<uint32_t>(bucket);
407 for (size_t i = run_lo; i < run_hi; ++i) {
408 const uint32_t e = schedule[i];
409 if ((e & DEDUP_SKIP_BIT) != 0) {
410 continue; // non-rep: skip, don't consume a curr_pts slot
411 }
412 const uint32_t raw_idx = e & SCHEDULE_INDEX_MASK;
413 const bool neg = (e & SCHEDULE_SIGN_BIT) != 0;
414 s.curr_buckets[valid_len] = bucket_u32;
415 // SIMD-widened gather: 4 × v128.load on WASM (2× faster than the
416 // default 8 × i64.load struct copy on V8 TurboFan); 4 × movdqu on
417 // native (already optimal). The conditional negation runs after the
418 // copy because Fq::operator-() is a modular subtract, not a bit flip,
419 // so it can't be folded into the SIMD load lanes.
420 auto& dst_pt = s.curr_pts[valid_len];
421 // Dedup redirect: if the redirect bit is set, fetch from the dedup
422 // extra-points buffer (combined point for a cluster of duplicate scalars)
423 // instead of the original points span. The branch is always-not-taken when
424 // dedup is inactive (`dedup_extra_points` empty) and predictably-mostly-taken-or-not
425 // when active, since cluster-rep scheduling is uniform per MSM_fast.
426 if ((e & DEDUP_REDIRECT_BIT) != 0) {
427 copy_affine64(dst_pt, dedup_extra_points[raw_idx]);
428 } else {
429 copy_affine64(dst_pt, points[raw_idx]);
430 }
431 if (neg) {
432 dst_pt.y = -dst_pt.y;
433 }
434 ++valid_len;
435 }
436 pos = run_hi;
// Advance only when the chunk continues; if the chunk ended inside this bucket the cursor
// stays on it.
437 if (pos < chunk_hi) {
438 ++bucket;
439 }
440 }
441 bucket_cursor = bucket;
442
443 tree_reduce_in_place<Curve>(s, valid_len);
444}
445
446// `ChunkOutput<Curve>` (Stage 6 per-chunk bucket-reduce output) is defined in
447// `pippenger_arena_layout.hpp` so the test suite can size the Zone S slot the
448// same way the live allocator does.
449
450// `AffineBucketChunkInfo` is defined in `pippenger_arena_layout.hpp` (forward declared
451// above at line ~674 for ThreadScratch). It describes one chunk's contribution to the
452// cross-window recursive affine bucket reduction (lo/hi digit bounds, buckets_padded,
453// empty flag).
454
// Decides how the candidate addition `buckets[dst_idx] += buckets[src_idx]` is handled.
// Cases that need no batched addition are resolved immediately:
//   - src absent: nothing to do.
//   - dst absent: src is copied into dst and dst is marked present.
//   - equal x-coordinates: dst is doubled (equal y) or cleared to infinity and marked absent
//     (different y).
// Only the remaining case appends {dst_idx, src_idx} to `real_pairs` and bumps `real_count`,
// leaving the addition to the caller's batched kernel.
474template <typename Curve>
475[[gnu::always_inline]] inline void try_filter_pair(typename Curve::AffineElement* buckets,
476 uint8_t* is_present,
477 uint32_t dst_idx,
478 uint32_t src_idx,
// NOTE(review): source line 479 is missing from this listing. It presumably declares the
// `real_pairs` output parameter written at the end of the body; its element type is not visible here.
480 size_t& real_count) noexcept
481{
482 using Element = typename Curve::Element;
483 using AffineElement = typename Curve::AffineElement;
484
485 if (is_present[src_idx] == 0) {
486 return; // src is identity → no-op
487 }
488 if (is_present[dst_idx] == 0) {
489 buckets[dst_idx] = buckets[src_idx]; // dst was identity → just copy
490 is_present[dst_idx] = 1;
491 return;
492 }
493 // Edge case: dst.x == src.x. Since both points are on-curve, this means either
494 // dst == src (doubling case) or dst == -src (inverse case, result is identity).
495 // batch_affine_add_indexed_impl would invert zero here, so handle out-of-band.
496 if (buckets[dst_idx].x == buckets[src_idx].x) {
497 if (buckets[dst_idx].y == buckets[src_idx].y) {
498 // dst == src → result is 2 * dst.
499 Element doubled = Element(buckets[dst_idx]);
500 doubled.self_dbl();
501 buckets[dst_idx] = AffineElement{ doubled };
502 } else {
503 // dst == -src → result is identity.
504 buckets[dst_idx].self_set_infinity();
505 is_present[dst_idx] = 0;
506 }
507 return;
508 }
509 real_pairs[real_count++] = { dst_idx, src_idx };
510}
511
/**
 * @brief Append `idx` to `real_indices` when the slot is marked present; otherwise do nothing.
 *
 * @param is_present   Presence flags, one byte per slot (0 = empty, non-zero = present).
 * @param idx          Slot index to test.
 * @param real_indices Output list; must have room for one more entry at `real_count`.
 * @param real_count   Number of entries already in `real_indices`; incremented on append.
 */
[[gnu::always_inline]] inline void try_filter_idx(const uint8_t* is_present,
                                                  uint32_t idx,
                                                  uint32_t* real_indices,
                                                  size_t& real_count) noexcept
{
    if (is_present[idx] != 0) {
        real_indices[real_count++] = idx;
    }
}
526
564template <typename Curve>
565void recursive_affine_bucket_reduce_strided(ThreadScratch<Curve>& s,
566 const AffineBucketChunkInfo* chunk_infos,
567 size_t windows_in_batch,
568 ChunkOutput<Curve>* outputs_base,
569 size_t output_stride) noexcept
570{
571 using AffineElement = typename Curve::AffineElement;
572 using Element = typename Curve::Element;
573
574 auto out_at = [outputs_base, output_stride](size_t w) -> ChunkOutput<Curve>& {
575 return outputs_base[w * output_stride];
576 };
577
578 if (windows_in_batch == 0) {
579 return;
580 }
581
582 // Stride is the caller's pre-sized layout width (`s.affine_bucket_stride`, set via
583 // `ensure_affine_bucket_capacity`). The densification step in the caller scattered buckets at
584 // `w * s.affine_bucket_stride + i`, so we MUST use the same value for our own indexing — any
585 // re-derivation that disagrees with the layout would index neighbouring windows. The
586 // pre-size already enforces `stride ≥ max_w(buckets_padded_w)` AND `stride ≥ 2` AND
587 // `stride is a power of two`, so the trivial-stride fast path and the 4-phase math
588 // both stay valid here. Per-window buckets_padded controls how many slots each window walks
589 // and is bounded by `stride` — verified below in debug.
590 const size_t stride = s.affine_bucket_stride;
591 bool any_nonempty = false;
592 for (size_t w = 0; w < windows_in_batch; ++w) {
593 if (chunk_infos[w].empty == 0) {
594 any_nonempty = true;
595 BB_ASSERT_LTE(chunk_infos[w].buckets_padded, stride);
596 }
597 }
598 if (!any_nonempty) {
599 for (size_t w = 0; w < windows_in_batch; ++w) {
600 out_at(w).R = Curve::Group::point_at_infinity;
601 out_at(w).L = Curve::Group::point_at_infinity;
602 }
603 return;
604 }
605
606 AffineElement* const buckets = s.dense_buckets.data();
607 uint8_t* const is_present = s.is_present.data();
608
609 // Pick L0 (the leaf-partition size). c0 = floor(log2(stride) / 2)
610 // gives L0 ≈ sqrt(stride) — balances Phase A batch size (W·D) vs Phase A iter count
611 // (L0 - 1). Both L0 and D = stride / L0 must be powers of two.
612 BB_ASSERT_GT(stride, size_t{ 0 });
613 const size_t c_log = static_cast<size_t>(std::countr_zero(stride));
614 BB_ASSERT_EQ(static_cast<size_t>(1) << c_log, stride);
615 // Trivial-stride fast paths. The 4-phase algorithm requires c_log ≥ 2 (so we can pick
616 // c0 ∈ [1, c_log - 1]) — fall back to direct computation for stride ∈ {1, 2}.
617 if (stride <= 2) {
618 for (size_t w = 0; w < windows_in_batch; ++w) {
619 if (chunk_infos[w].empty != 0) {
620 out_at(w).R = Curve::Group::point_at_infinity;
621 out_at(w).L = Curve::Group::point_at_infinity;
622 continue;
623 }
624 // Walk the (up to two) populated slots directly.
625 const size_t base = w * stride;
626 Element R = Curve::Group::point_at_infinity;
627 Element L = Curve::Group::point_at_infinity;
628 for (size_t i = 0; i < chunk_infos[w].buckets_padded; ++i) {
629 if (is_present[base + i] == 0) {
630 continue;
631 }
632 R += Element(buckets[base + i]);
633 L += Element(buckets[base + i]); // weight 1
634 if (i == 1) {
635 L += Element(buckets[base + i]); // weight 2 for i=1
636 }
637 }
638 out_at(w).R = R;
639 out_at(w).L = L;
640 }
641 return;
642 }
643
644 // Choose c0 = floor(c_log / 2), clamped so that 1 ≤ c0 ≤ c_log - 1.
645 size_t c0 = c_log / 2;
646 if (c0 == 0) {
647 c0 = 1;
648 }
649 if (c0 >= c_log) {
650 c0 = c_log - 1;
651 }
652 const size_t L0 = static_cast<size_t>(1) << c0;
653 const size_t D = stride >> c0; // == stride / L0
654 BB_ASSERT_EQ(L0 * D, stride);
655 BB_ASSERT_GTE(L0, size_t{ 2 });
656 BB_ASSERT_GTE(D, size_t{ 2 });
657
658 auto* const reals = s.affine_bucket_pairs.data();
659 auto* const dbl_reals = s.affine_bucket_indices.data();
660 auto* const inv_scratch = s.affine_bucket_inversion_scratch.data();
661
662 // Phase A: per-sub-partition running-sum (suffix sums).
663 // For each window w and each sub-partition d, walk slots from L0-1 down to 1 within the
664 // sub-partition, accumulating buckets[w*stride + d*L0 + l - 1] += buckets[... l]. All
665 // (w, d, l) triples for a fixed l share one batch-affine inversion (up to windows_in_batch
666 // · D pairs). Short windows (my_M_w < L0) are treated as a single sub-partition of length
667 // my_M_w to skip dead candidates; effective per-(w, d) length is min(L0, my_M_w - d·L0).
668 {
669 for (size_t l = L0 - 1; l >= 1; --l) {
670 size_t real_count = 0;
671 for (size_t w = 0; w < windows_in_batch; ++w) {
672 if (chunk_infos[w].empty != 0) {
673 continue;
674 }
675 const size_t my_M_w = chunk_infos[w].buckets_padded;
676 const size_t base = w * stride;
677 if (my_M_w < L0) {
678 // Short window: single sub-partition of effective length `my_M_w`.
679 if (l >= my_M_w) {
680 continue; // l is in the empty-padding region, skip
681 }
682 const uint32_t src = static_cast<uint32_t>(base + l);
683 const uint32_t dst = static_cast<uint32_t>(base + l - 1);
684 try_filter_pair<Curve>(buckets, is_present, dst, src, reals, real_count);
685 } else {
686 const size_t my_D = my_M_w >> c0; // ≥ 1
687 for (size_t d = 0; d < my_D; ++d) {
688 const uint32_t src = static_cast<uint32_t>(base + (d * L0) + l);
689 const uint32_t dst = static_cast<uint32_t>(base + (d * L0) + l - 1);
690 try_filter_pair<Curve>(buckets, is_present, dst, src, reals, real_count);
691 }
692 }
693 }
694 if (real_count > 0) {
695 bb::group_elements::batch_affine_add_indexed_impl<typename Curve::AffineElement,
696 typename Curve::BaseField>(
697 buckets, reals, real_count, inv_scratch);
698 }
699 }
700 }
701
702 // After Phase A, each window's slot 0 holds the simple sum of its sub-partition 0,
703 // and slot d*L0 (d ≥ 1) holds the simple sum of sub-partition d. The other slots within
704 // each sub-partition hold suffix sums that Phase D will combine.
705
706 // Phase B: log-recombine sub-partition simple sums into slot 0.
707 // For L1 = L0, 2*L0, 4*L0, ..., stride/2: pair (slot 2d*L1, slot (2d+1)*L1).
708 {
709 size_t L1 = L0;
710 while (L1 < stride) {
711 size_t real_count = 0;
712 const size_t step = 2 * L1;
713 for (size_t w = 0; w < windows_in_batch; ++w) {
714 if (chunk_infos[w].empty != 0) {
715 continue;
716 }
717 const size_t my_M = chunk_infos[w].buckets_padded;
718 if (step > my_M) {
719 continue;
720 }
721 const size_t base = w * stride;
722 const size_t num_pairs_w = my_M / step;
723 for (size_t d = 0; d < num_pairs_w; ++d) {
724 const uint32_t dst = static_cast<uint32_t>(base + ((2 * d) * L1));
725 const uint32_t src = static_cast<uint32_t>(base + (((2 * d) + 1) * L1));
726 try_filter_pair<Curve>(buckets, is_present, dst, src, reals, real_count);
727 }
728 }
729 if (real_count > 0) {
730 bb::group_elements::batch_affine_add_indexed_impl<typename Curve::AffineElement,
731 typename Curve::BaseField>(
732 buckets, reals, real_count, inv_scratch);
733 }
734 L1 *= 2;
735 }
736 }
737
738 // After Phase B, each window's slot 0 holds Σ_d B_{c,d} = R_c. Save R_c into outputs
739 // before Phase D's tree-add overwrites slot 0.
740 for (size_t w = 0; w < windows_in_batch; ++w) {
741 if (chunk_infos[w].empty != 0) {
742 out_at(w).R = Curve::Group::point_at_infinity;
743 continue;
744 }
745 const AffineElement& slot0 = buckets[w * stride];
746 if (is_present[w * stride] == 0) {
747 out_at(w).R = Curve::Group::point_at_infinity;
748 } else {
749 out_at(w).R = Element(slot0);
750 }
751 }
752
753 // Phase C: doublings.
754 // The candidate index list for the initial pass is constant across all c0 iters —
755 // every slot d*L0 for d ∈ [1, my_D - 1] in every non-empty window. Build the empty-
756 // filtered list once and chain c0 doublings on it instead of filtering c0 times.
757 // Subsequent levels (L1 = 2*L0, 4*L0, ...) do one doubling per level on level-specific
758 // index sets handled separately below.
759 {
760 size_t real_count = 0;
761 for (size_t w = 0; w < windows_in_batch; ++w) {
762 if (chunk_infos[w].empty != 0) {
763 continue;
764 }
765 const size_t my_M_w = chunk_infos[w].buckets_padded;
766 const size_t my_D = (my_M_w >= L0) ? (my_M_w >> c0) : size_t{ 0 };
767 const size_t base = w * stride;
768 for (size_t d = 1; d < my_D; ++d) {
769 try_filter_idx(is_present, static_cast<uint32_t>(base + (d * L0)), dbl_reals, real_count);
770 }
771 }
772 // c0 chained doublings on the same real list.
773 if (real_count > 0) {
774 for (size_t j = 0; j < c0; ++j) {
775 bb::group_elements::batch_affine_double_indexed_impl<typename Curve::AffineElement,
776 typename Curve::BaseField>(
777 buckets, dbl_reals, real_count, inv_scratch);
778 }
779 }
780 }
781 // Successive: at L1 = 2*L0, 4*L0, ..., stride/2: every d ≥ 1 in the sub-partition
782 // grid of size `stride / L1` gets one more doubling.
783 {
784 size_t L1 = 2 * L0;
785 while (L1 < stride) {
786 size_t real_count = 0;
787 for (size_t w = 0; w < windows_in_batch; ++w) {
788 if (chunk_infos[w].empty != 0) {
789 continue;
790 }
791 const size_t my_M = chunk_infos[w].buckets_padded;
792 if (L1 >= my_M) {
793 continue; // this window has no sub-partitions at this hierarchy
794 }
795 const size_t my_D1 = my_M / L1;
796 const size_t base = w * stride;
797 for (size_t d = 1; d < my_D1; ++d) {
798 try_filter_idx(is_present, static_cast<uint32_t>(base + (d * L1)), dbl_reals, real_count);
799 }
800 }
801 if (real_count > 0) {
802 bb::group_elements::batch_affine_double_indexed_impl<typename Curve::AffineElement,
803 typename Curve::BaseField>(
804 buckets, dbl_reals, real_count, inv_scratch);
805 }
806 L1 *= 2;
807 }
808 }
809
810 // Phase D: flat tree-add over the buckets_padded slots. For m = 1, 2, 4, ...,
811 // buckets_padded/2: pair (slot pos, slot pos+m) for pos = 0, 2m, 4m, ...
812 // Once the level's candidate count drops below BATCH_AFFINE_BREAKEVEN, the per-batch
813 // inversion overhead exceeds the projective per-add cost; bail and finish in Jacobian.
814 constexpr size_t BATCH_AFFINE_BREAKEVEN = 32;
815 size_t m = 1;
816 while (m < stride) {
817 // Live-slot count after this iter: stride / (2m) per window worst-case.
818 // Decision: would this iter's batch be too small? Estimate as
819 // `windows_in_batch * stride / (2m)` (upper bound on candidates).
820 const size_t est_cands_this_iter = windows_in_batch * (stride / (2 * m));
821 if (est_cands_this_iter < BATCH_AFFINE_BREAKEVEN) {
822 break;
823 }
824 size_t real_count = 0;
825 const size_t step = 2 * m;
826 for (size_t w = 0; w < windows_in_batch; ++w) {
827 if (chunk_infos[w].empty != 0) {
828 continue;
829 }
830 const size_t my_M = chunk_infos[w].buckets_padded;
831 if (m >= my_M) {
832 continue;
833 }
834 const size_t base = w * stride;
835 for (size_t pos = 0; pos + m < my_M; pos += step) {
836 try_filter_pair<Curve>(buckets,
838 static_cast<uint32_t>(base + pos),
839 static_cast<uint32_t>(base + pos + m),
840 reals,
841 real_count);
842 }
843 }
844 if (real_count > 0) {
845 bb::group_elements::batch_affine_add_indexed_impl<typename Curve::AffineElement, typename Curve::BaseField>(
846 buckets, reals, real_count, inv_scratch);
847 }
848 m *= 2;
849 }
850
851 // Write L_c. After Phase D's loop, `m` is the level NOT performed (or `stride` if all
852 // levels ran). The "live" slots — those holding cumulative tree-sums of consecutive m
853 // original buckets each — are {0, m, 2m, 3m, ...} ∩ [0, my_M):
854 // - loop completed (m == stride): only slot 0 is live; it holds the final L.
855 // - loop broke at level m: sum the live slots in Jacobian (live_step = m).
856 // - loop broke at m == 1: every original bucket is still live, sum them all.
857 // The Jacobian sum recovers what the unfinished levels would have computed in the
858 // batch-affine inner loop.
859 for (size_t w = 0; w < windows_in_batch; ++w) {
860 if (chunk_infos[w].empty != 0) {
861 out_at(w).L = Curve::Group::point_at_infinity;
862 continue;
863 }
864 const size_t base = w * stride;
865 const size_t my_M = chunk_infos[w].buckets_padded;
866 Element L = Curve::Group::point_at_infinity;
867 const size_t live_step = m; // distance between live slots after the affine phase
868 for (size_t pos = 0; pos < my_M; pos += live_step) {
869 if (is_present[base + pos] != 0) {
870 L += Element(buckets[base + pos]);
871 }
872 }
873 out_at(w).L = L;
874 }
875}
876
892template <typename Curve>
893[[gnu::always_inline]] inline typename Curve::Element chunk_contribution(const ChunkOutput<Curve>& chunk) noexcept
894{
895 using Element = typename Curve::Element;
896 if (chunk.empty != 0) {
897 return Curve::Group::point_at_infinity;
898 }
899 const uint32_t k = chunk.lo - 1;
900 Element acc = chunk.L;
901 if (k != 0) {
902 Element p = chunk.R;
903 uint32_t kk = k;
904 while (kk != 0) {
905 if ((kk & 1U) != 0) {
906 acc += p;
907 }
908 kk >>= 1;
909 if (kk != 0) {
910 p.self_dbl();
911 }
912 }
913 }
914 return acc;
915}
916
917} // namespace
918// `pippenger_round_parallel_jacobian_fast` has external linkage via the `extern template`
919// declarations in the header (used by the batched driver). Defined at namespace scope.
920
940template <typename Curve>
944 size_t min_pts_per_thread_override) noexcept
945{
946 using Element = typename Curve::Element;
947 using ScalarField = typename Curve::ScalarField;
948 using BaseField = typename Curve::BaseField;
949
950 const size_t n = scalars.size();
951 if (n == 0) {
952 return Curve::Group::point_at_infinity;
953 }
954
955 constexpr size_t NUM_BITS = ScalarField::modulus.get_msb() + 1;
956
957 // Cost-model window-size selection (mirrors MSM_fast<Curve>::get_optimal_log_num_buckets,
958 // with BUCKET_ACCUMULATION_COST = 5 = J-J-add-equiv-muls / J-A-add-equiv-muls ≈ 16/11
959 // rounded up). We do NOT delegate to the public method — keeping it self-contained
960 // avoids dragging the AffineAddition / AFFINE_TRICK_THRESHOLD machinery in here.
961 constexpr size_t BUCKET_ACCUMULATION_COST = 5;
962 constexpr uint32_t MAX_C = 18;
963 auto cost = [n](uint32_t bits) -> size_t {
964 size_t rounds = (NUM_BITS + bits - 1) / bits;
965 size_t buckets = size_t{ 1 } << bits;
966 return rounds * (n + buckets * BUCKET_ACCUMULATION_COST);
967 };
968 uint32_t window_bits = 1;
969 size_t best_cost = cost(1);
970 for (uint32_t b = 2; b <= MAX_C; ++b) {
971 const size_t this_cost = cost(b);
972 if (this_cost < best_cost) {
973 best_cost = this_cost;
974 window_bits = b;
975 }
976 }
977 const size_t num_buckets = size_t{ 1 } << window_bits;
978 const uint32_t num_rounds = static_cast<uint32_t>((NUM_BITS + window_bits - 1) / window_bits);
979 const uint32_t last_round_bits =
980 static_cast<uint32_t>(NUM_BITS - (static_cast<size_t>(num_rounds - 1) * window_bits));
981
982 // Each thread owns a num_buckets-sized scratch slice and runs num_rounds passes; below
983 // ~256 points per thread the parallel_for wakeup + per-call bucket reset dominate.
984 // wasm is forced single-threaded — its barrier cost is much higher than native.
985#ifdef __wasm__
986 constexpr size_t MIN_PTS_PER_THREAD_DEFAULT = SIZE_MAX;
987#else
988 constexpr size_t MIN_PTS_PER_THREAD_DEFAULT = 256;
989#endif
990 const size_t MIN_PTS_PER_THREAD =
991 (min_pts_per_thread_override == 0) ? MIN_PTS_PER_THREAD_DEFAULT : min_pts_per_thread_override;
992 const size_t max_threads = get_num_cpus();
993 size_t num_threads = std::min(std::max<size_t>(1, n / MIN_PTS_PER_THREAD), max_threads);
994 if (num_threads == 0) {
995 num_threads = 1;
996 }
997
998 // Allocate the per-thread bucket + presence scratch ONCE, indexed by tid inside the
999 // parallel_for. Allocating inside the lambda body would re-malloc on every call (and
1000 // on WASM the malloc cost is non-trivial relative to the arithmetic work at small n).
1001 std::vector<Element> per_thread_results(num_threads);
1002 std::vector<Element> all_buckets(num_threads * num_buckets);
1003 std::vector<uint8_t> all_present(num_threads * num_buckets);
1004
1005 auto thread_body = [&](size_t tid) {
1006 const size_t lo = (tid * n) / num_threads;
1007 const size_t hi = ((tid + 1) * n) / num_threads;
1008
1009 Element* const buckets = all_buckets.data() + (tid * num_buckets);
1010 uint8_t* const present = all_present.data() + (tid * num_buckets);
1011
1012 Element result = Curve::Group::point_at_infinity;
1013
1014 for (uint32_t round = 0; round < num_rounds; ++round) {
1015 std::memset(present, 0, num_buckets);
1016
1017 const size_t hi_bit = NUM_BITS - (static_cast<size_t>(round) * window_bits);
1018 const size_t lo_bit = (hi_bit < window_bits) ? size_t{ 0 } : (hi_bit - window_bits);
1019 const size_t actual_size = hi_bit - lo_bit;
1020 const size_t start_limb = lo_bit >> 6;
1021 const size_t end_limb = hi_bit >> 6;
1022 const size_t lo_off = lo_bit & 63;
1023 const size_t lo_bits = (64 - lo_off < actual_size) ? (64 - lo_off) : actual_size;
1024 const size_t hi_bits = actual_size - lo_bits;
1025 const uint64_t lo_mask = (lo_bits == 64) ? ~uint64_t{ 0 } : ((uint64_t{ 1 } << lo_bits) - 1);
1026 const uint64_t hi_mask = (hi_bits == 0) ? uint64_t{ 0 } : ((uint64_t{ 1 } << hi_bits) - 1);
1027
1028 for (size_t i = lo; i < hi; ++i) {
1029 const uint64_t s_lo = (scalars[i].data[start_limb] >> lo_off) & lo_mask;
1030 const uint64_t s_hi = (start_limb != end_limb) ? (scalars[i].data[end_limb] & hi_mask) : uint64_t{ 0 };
1031 const uint32_t slice = static_cast<uint32_t>(s_lo | (s_hi << lo_bits));
1032 if (slice == 0) {
1033 continue;
1034 }
1035 if (present[slice] == 0) {
1036 buckets[slice].x = points[i].x;
1037 buckets[slice].y = points[i].y;
1038 buckets[slice].z = BaseField::one();
1039 present[slice] = 1;
1040 } else {
1041 buckets[slice] += points[i];
1042 }
1043 }
1044
1045 // Running suffix sum over populated buckets only.
1046 // acc = Σ_{j ≥ i, present[j]} bucket[j]
1047 // bucket_sum = Σ_{i in [first_pop_low, top]} acc(i) = Σ_k k * bucket[k]
1048 // Bucket 0 carries no contribution and is never added.
1049 std::ptrdiff_t top = static_cast<std::ptrdiff_t>(num_buckets) - 1;
1050 while (top >= 1 && present[static_cast<size_t>(top)] == 0) {
1051 --top;
1052 }
1053 Element bucket_sum = Curve::Group::point_at_infinity;
1054 if (top >= 1) {
1055 Element acc = buckets[static_cast<size_t>(top)];
1056 bucket_sum = acc;
1057 for (std::ptrdiff_t i = top - 1; i >= 1; --i) {
1058 if (present[static_cast<size_t>(i)] != 0) {
1059 acc += buckets[static_cast<size_t>(i)];
1060 }
1061 bucket_sum += acc;
1062 }
1063 }
1064
1065 const uint32_t doublings = (round == num_rounds - 1) ? last_round_bits : window_bits;
1066 for (uint32_t d = 0; d < doublings; ++d) {
1067 result.self_dbl();
1068 }
1069 result += bucket_sum;
1070 }
1071
1072 per_thread_results[tid] = result;
1073 };
1074
1075 if (num_threads == 1) {
1076 thread_body(0);
1077 } else {
1078 bb::parallel_for(num_threads, thread_body);
1079 }
1080
1081 Element total = per_thread_results[0];
1082 for (size_t t = 1; t < num_threads; ++t) {
1083 total += per_thread_results[t];
1084 }
1085 return total;
1086}
1087
1088// PerWorkerArenaLayout (and its dependencies BATCH_CAPACITY, DEDUP_MAX_CHUNK_MEMBERS,
1089// AffineBucketChunkInfo) lives in `pippenger_arena_layout.hpp`. Used by the sizer
1090// below, the live allocator in `pippenger_round_parallel`, and the arena-layout
1091// regression test.
1092} // namespace round_parallel_detail
1093
1112
1113 // Compute the exact arena bytes a single MSM_fast of `n_input` points will need.
1114 // Mirrors the inline budget calculation inside `pippenger_round_parallel`.
1115 // Returns 0 when N is small enough that we'll fall back to the Jacobian fast path
1116 // (no affine arena needed). Exposed (declared in `scalar_multiplication_fast.hpp`)
1117 // so the test suite can exercise the same sizer the live allocator uses.
//
// Parameters:
//   n_input               — number of input points; the working count `n` is 2 * n_input when
//                           GLV is in effect, n_input otherwise.
//   external_glv_provided — true forces use_glv and makes inline_glv_double false for the
//                           Phase 1 prologue sizing.
//   dedup_active          — true adds dedup_bytes to every candidate total and is forwarded to
//                           the per-worker layout.
// Returns: 0 on either early-out (n_input < 4, or fewer threads can be fed than
//   min_threads_allowed); otherwise the largest
//   `fixed_overhead + wpb * per_window_bytes + 32768 + dedup_bytes` over the window widths
//   examined at the end of the function.
1118 template <typename Curve>
1119 size_t compute_arena_bytes_for_msm(size_t n_input, bool external_glv_provided, bool dedup_active) noexcept
1120 {
1121 using ScalarField = typename Curve::ScalarField;
1122 constexpr size_t FULL_NUM_BITS = ScalarField::modulus.get_msb() + 1;
1123
1124 if (n_input < 4) {
1125 return 0; // trivial path
1126 }
1127
// GLV is used when the caller supplied a doubled buffer or n_input is at or below the
// threshold; it doubles the working count and fixes the bit budget at 128.
1128 const bool use_glv = external_glv_provided || (n_input <= round_parallel_detail::GLV_SMALL_N_THRESHOLD);
1129 const size_t n = use_glv ? 2 * n_input : n_input;
1130 const size_t NUM_BITS = use_glv ? size_t{ 128 } : FULL_NUM_BITS;
1131 BB_ASSERT_LTE(n,
1133 "working scalar indices must fit in the 29-bit schedule payload");
1134
1139
1140 // window-bits selection uses the ideal per-window oversubscription factor (not the dispatch lmul).
1141 const size_t num_logical_threads_for_c = bb::get_num_cpus() * window_bits_tuning_oversub_factor(n_input);
1142 const size_t window_bits =
1143 round_parallel_detail::choose_window_bits(n, NUM_BITS, n_input, num_logical_threads_for_c);
1144 const size_t num_windows = (NUM_BITS + 2 + window_bits - 1) / window_bits;
1145 const size_t num_buckets = (size_t{ 1 } << (window_bits - 1)) + 1;
1146
// Thread-count gate: n / MIN_BATCH_CAPACITY is how many threads can each receive a full
// minimum batch; the affine path is sized only if that reaches
// ceil(desired_threads / MIN_AFFINE_THREAD_RATIO) (at least 1).
1147 const size_t desired_threads = std::max<size_t>(1, bb::get_num_cpus());
1148 const size_t max_threads_for_min_batch = n / MIN_BATCH_CAPACITY;
1149 const size_t min_threads_allowed =
1150 std::max<size_t>(1, (desired_threads + MIN_AFFINE_THREAD_RATIO - 1) / MIN_AFFINE_THREAD_RATIO);
1151
1152 if (max_threads_for_min_batch < min_threads_allowed) {
1153 return 0; // jacobian-fast fallback, no affine arena
1154 }
1155
// Past the gate max_threads_for_min_batch >= min_threads_allowed >= 1, so the inner
// max(1, ...) never changes the value here.
1156 const size_t num_threads = std::min(desired_threads, std::max<size_t>(1, max_threads_for_min_batch));
1157
1158 // num_threads sizes the per-task arrays; worker_total sizes the per-OS-thread scratch
1159 // (FIFO-shared by every task that lands on that OS thread).
1160 const size_t worker_total_for_budget = num_threads;
1161 const size_t dense_stride_est = round_parallel_detail::compute_dense_stride(num_buckets, num_threads);
1162
1163 // Pre-schedule conservative per-window cost: uses `num_buckets` (= 2^(c-1)+1) as the
1164 // B upper bound. The lambda below recomputes once the actual schedule is built.
1165 const size_t per_window_bytes = round_parallel_detail::compute_per_window_bytes<Curve>(
1166 num_threads, num_buckets, n, dense_stride_est, worker_total_for_budget);
1167
1168 const size_t global_max_overflow_per_window =
1169 round_parallel_detail::compute_global_max_overflow_per_window(n, num_threads, SUBCHUNK_ENTRIES_CAP);
1170
1171 const bool inline_glv_double = use_glv && !external_glv_provided;
1172 const size_t profile_threads = std::max<size_t>(1, bb::get_num_cpus());
1173 const size_t phase_one_prologue_bytes =
1174 round_parallel_detail::compute_phase_one_prologue_bytes(n, use_glv, inline_glv_double, profile_threads);
1175
1176 const auto phase_a_caps = round_parallel_detail::compute_phase_a_caps(n, num_threads);
1177 const size_t phase_a_cluster_members_cap = phase_a_caps.members_cap;
1178 const size_t phase_a_cluster_offsets_cap = phase_a_caps.offsets_cap;
1179
1180 // Zone W per-worker UNION via the canonical layout walk. Stage 6a, Stage 6b, and
1181 // Phase A overlay the same per-worker bytes; the struct returns the max-of-layouts
1182 // (the Stage 6 wpb-dependent tail is added below once `windows_per_batch` is known).
1183 // Passing `windows_per_batch = 0` here skips the tail — we only need the union bytes
1184 // for the fixed_overhead → wpb solve.
1185 const round_parallel_detail::PerWorkerArenaLayout<Curve> union_layout(/*chunk_capacity=*/SUBCHUNK_ENTRIES_CAP,
1186 global_max_overflow_per_window,
1187 dedup_active,
1188 phase_a_cluster_members_cap,
1189 phase_a_cluster_offsets_cap,
1190 /*windows_per_batch=*/0,
1191 /*dense_stride_est=*/0);
1192 const size_t worker_union_bytes = union_layout.per_worker_union_bytes;
1193
// Bytes that do not scale with windows_per_batch: per-worker union slabs plus the three
// fixed-size terms annotated on the lines below.
1194 const size_t fixed_overhead = (worker_union_bytes * worker_total_for_budget) +
1195 (size_t{ 96 } * round_parallel_detail::VAR_WINDOW_MAX_WINDOWS) // window_sums_storage
1196 + (size_t{ 8 } * (num_threads + 1)) // rebalanced_bucket_lo_partition
1197 + phase_one_prologue_bytes;
1198
1199 // wpb fallback when fixed_overhead has eaten the BATCH_MEM_BUDGET headroom: the inline
1200 // `solve_wpb` in `pippenger_round_parallel` returns `W_R` (the whole region) — running
1201 // every window in a single batch — when `available_budget == 0`. Previously the sizer
1202 // returned `wpb = 1` and relied on a `worst_case_arena = BATCH_MEM_BUDGET + 32K` floor;
1203 // that floor failed for large num_threads where fixed_overhead alone exceeds the budget.
1204 const size_t available_budget_outer =
1205 (BATCH_MEM_BUDGET > fixed_overhead) ? (BATCH_MEM_BUDGET - fixed_overhead) : size_t{ 0 };
1206 const size_t windows_per_batch =
1207 round_parallel_detail::solve_wpb(per_window_bytes, available_budget_outer, num_windows);
1208 // Dedup state lives in the arena (allocated post-Phase-1, retained through Stage 6a).
1209 // Worst-case sizes: redirect_lookup is one uint32 per working scalar (4n bytes);
1210 // extra_points is the fixed DEDUP_MAX_CLUSTERS cap (≈1 MB) regardless of n.
1211 const size_t dedup_bytes = dedup_active ? ((size_t{ 4 } * n) + (size_t{ sizeof(typename Curve::AffineElement) } *
1213 : size_t{ 0 };
// Arena bytes for a hypothetical layout with window width `wb` over `bit_budget` bits.
// Only the per-window term and the window count depend on (bit_budget, wb); fixed_overhead,
// dedup_bytes, num_threads and n are captured unchanged from the enclosing scope.
1214 auto arena_bytes_for_window_layout = [&](size_t bit_budget, size_t wb) {
1215 const auto layout_sched = round_parallel_detail::build_var_window_schedule(bit_budget, wb);
1216 // Uniform schedule: the widest window's bucket count is the per-window cap.
1217 const size_t B_eff_layout = (size_t{ 1 } << (wb - 1)) + 1;
1218 const size_t dense_stride_layout = round_parallel_detail::compute_dense_stride(B_eff_layout, num_threads);
1219 const size_t per_window_bytes_layout = round_parallel_detail::compute_per_window_bytes<Curve>(
1220 num_threads, B_eff_layout, n, dense_stride_layout, worker_total_for_budget);
1221
// Same expression as available_budget_outer; fixed_overhead does not depend on wb.
1222 const size_t available_budget =
1223 (BATCH_MEM_BUDGET > fixed_overhead) ? (BATCH_MEM_BUDGET - fixed_overhead) : size_t{ 0 };
1224 const size_t wpb = round_parallel_detail::solve_wpb(
1225 per_window_bytes_layout, available_budget, static_cast<size_t>(layout_sched.num_windows));
1226 return fixed_overhead + (wpb * per_window_bytes_layout) + 32768 + dedup_bytes;
1227 };
1228
1229 // Tight return: the arena holds `fixed_overhead + wpb · per_window_bytes` of typed
1230 // buffers plus a 32 KiB alignment pad and the dedup state (when active). Sizing
1231 // tightly — rather than padding up to BATCH_MEM_BUDGET — matters for many-MSM_fast flows
1232 // (e.g. PerMsmChonk's 256 separate per-circuit MSMs) where every per-MSM_fast
1233 // `make_unique_for_overwrite<std::byte[]>` mmap/munmaps the buffer above glibc's
1234 // M_MMAP_THRESHOLD; a 32 MiB floor here would tax every MSM_fast with the page-fault
1235 // first-touch cost regardless of how much of the arena the small MSM_fast actually uses.
1236 size_t arena_bytes = fixed_overhead + (windows_per_batch * per_window_bytes) + 32768 + dedup_bytes;
1237
1238 // The live pipeline chooses window_bits from the *effective* (nonzero) scalar count and the
1239 // observed bit budget after Phase 1: c = choose_window_bits(n_active, effective_num_bits) with
1240 // n_active <= n and effective_num_bits <= NUM_BITS. Fewer active points => smaller c => more
1241 // windows => a larger arena (most sharply once fixed_overhead has eaten the batch budget and
1242 // every window runs in a single batch). Size for the worst reachable c so the bound holds for
1243 // any scalar density, with no extra scalar scan.
1244 //
1245 // For a fixed c, bit_budget = NUM_BITS maximizes the window count (effective_num_bits <=
1246 // NUM_BITS) and 2^(c-1)+1 caps B_eff, so arena_bytes_for_window_layout(NUM_BITS, c) dominates
1247 // every live (effective_num_bits, c) layout. The reachable c span is [2, c_max]: choose is
1248 // non-decreasing in the point count (n_active <= n bounds it above), but the ceil() in the round
1249 // count makes it non-monotonic in the bit budget by ±1, so c_max is the max over bit budgets,
1250 // not simply choose(n, NUM_BITS).
1251 size_t c_max_reachable = window_bits;
1252 for (size_t bit_budget = 1; bit_budget <= NUM_BITS; ++bit_budget) {
1253 c_max_reachable = std::max(c_max_reachable,
1255 n, bit_budget, n_input, num_logical_threads_for_c)));
1256 }
// Take the largest arena over every width from 2 up to that bound, always evaluated at the
// full NUM_BITS budget.
1257 for (size_t wb = 2; wb <= c_max_reachable; ++wb) {
1258 arena_bytes = std::max(arena_bytes, arena_bytes_for_window_layout(NUM_BITS, wb));
1259 }
1260 return arena_bytes;
1261}
1262
1263// Round-parallel Pippenger MSM_fast.
1264// `external_glv_doubled` — optional caller-supplied [P_0, φP_0, …, P_{n-1}, φP_{n-1}]
1265// buffer (length 2·n_input). When non-empty, forces use_glv=true and skips the
1266// internal doubling pass. The interleaved layout means longer-prefix aliasing
1267// (length 2·Nmax) is valid for any n ≤ Nmax with no copy.
1268// `external_arena` — optional caller-supplied scratch buffer ≥ this MSM_fast's required
1269// bytes. When empty, allocate per-MSM_fast via make_unique_for_overwrite and free at
1270// return. The batched driver supplies a single arena sized to the largest member.
1271template <typename Curve>
1272// NOLINTNEXTLINE(readability-function-size, readability-function-cognitive-complexity,
1273// google-readability-function-size)
1276 bool dedup_hint,
1278 std::span<std::byte> external_arena) noexcept
1279{
1280 using Element = typename Curve::Element;
1281 using AffineElement = typename Curve::AffineElement;
1282 using ScalarField = typename Curve::ScalarField;
1283 using BaseField = typename Curve::BaseField;
1284
1285 const size_t n_input = scalars_span.size();
1286 if (n_input == 0) {
1287 return Curve::Group::point_at_infinity;
1288 }
1289
1290 // Bail to trivial_msm_threaded when each worker would own fewer than
1291 // MIN_PTS_PER_THREAD_FOR_PIPPENGER points — pippenger_fast's per-window scaffolding loses
1292 // to straus_msm at this density. Caller-supplied GLV doubling is wasted at this size,
1293 // but the overhead is negligible.
1294 {
1295 const size_t max_threads = bb::get_num_cpus();
1296 const size_t num_threads_dispatch = std::max<size_t>(1, std::min(n_input, max_threads));
1297 const size_t pts_per_thread = (n_input + num_threads_dispatch - 1) / num_threads_dispatch;
1298 if (pts_per_thread < MIN_PTS_PER_THREAD_FOR_PIPPENGER) {
1299 return trivial_msm_threaded<Curve>(scalars_span, all_points);
1300 }
1301 }
1302
1303 BB_ASSERT_GTE(all_points.size(), scalars_span.start_index + n_input);
1304 std::span<const AffineElement> input_points(&all_points[scalars_span.start_index], n_input);
1305
1306 constexpr size_t FULL_NUM_BITS = ScalarField::modulus.get_msb() + 1;
1307
1308 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1309 ScalarField* scalar_ptr = const_cast<ScalarField*>(&scalars_span[scalars_span.start_index]);
1310 std::span<ScalarField> input_scalars(scalar_ptr, n_input);
1311
1312 // GLV: split k ≡ k1 − k2·λ (mod r), giving 2n pairs at NUM_BITS=128. Halves num_windows;
1313 // costs an extra n point doubles. Applied only below GLV_SMALL_N_THRESHOLD where the
1314 // win-on-windows beats the lose-on-doubled-scan, OR forced on by the batched dispatcher
1315 // supplying `external_glv_doubled` (it amortises the doubling across the whole batch).
1316 // Empirical crossover (best-of-3 sweep at HC=16, P ∈ {4, 8, 16}): wasmtime keeps GLV up
1317 // to n=2^16; native to n=2^13 (clang's branchless bias-decode is fast enough that the 2×
1318 // point-count cost dominates above that). Threshold is platform-conditional in the
1319 // hoisted GLV_SMALL_N_THRESHOLD declaration.
1320 const bool external_glv_provided = !external_glv_doubled.empty();
1321 const bool use_glv = external_glv_provided || n_input <= round_parallel_detail::GLV_SMALL_N_THRESHOLD;
1322
1323 // Stage 6 splits into 6a (per-thread bucket partials over the contiguous-by-schedule-
1324 // index partition) and 6b (cross-thread bucket reduction over a uniform-width digit
1325 // slice). Small MSMs short-circuit to trivial_msm_threaded above this point.
1326
1327 // n is the working scalar/point count (GLV doubles it); NUM_BITS is the post-recoding
1328 // window-bit budget (128 for GLV, FULL_NUM_BITS otherwise).
1329 const size_t n = use_glv ? (2 * n_input) : n_input;
1330 const size_t NUM_BITS = use_glv ? size_t{ 128 } : FULL_NUM_BITS;
1331 BB_ASSERT_LTE(n,
1333 "working scalar indices must fit in the 29-bit schedule payload");
1334 std::span<ScalarField> scalars;
1335 std::span<const AffineElement> points;
1336 const bool inline_glv_double = use_glv && !external_glv_provided;
1337
1338 // Activation gate: caller-supplied hint opts this MSM_fast into the dedup pre-pass.
1339 // Hint-driven so polynomials with low duplicate density (PC counters, range checks)
1340 // skip the O(n) tagging cost. The small-n bail above (pts_per_thread <
1341 // MIN_PTS_PER_THREAD_FOR_PIPPENGER) already shed every case where dedup wouldn't fit
1342 // — n ≥ MIN_PTS_PER_THREAD_FOR_PIPPENGER * 1 = 24 here.
1343 const bool dedup_active = dedup_hint;
1344
1345 // ---------------------------------------------------------------------------------------
1346 // Arena setup (pre-Phase-1).
1347 //
1348 // The per-MSM_fast arena is allocated BEFORE Phase 1 so the Phase 1 prologue (msb_per_scalar,
1349 // glv_*_storage, per_thread_msb_hist) lives inside the arena instead of on the heap.
1350 // Once Phase 1 finishes and the window schedule is known (T, B_eff, dense_stride, wpb),
1351 // we partition the remaining capacity into three named zones
1352 // (Zone P / Zone W / Zone S) — see the "Arena zone layout" block after the wpb solve.
1353 //
1354 // We size the buffer using `compute_arena_bytes_for_msm`, whose conservative bound
1355 // dominates the inline-tight (P + W + S) sum for any wpb we choose below.
1356 // ---------------------------------------------------------------------------------------
1357 const size_t arena_total_bytes = compute_arena_bytes_for_msm<Curve>(n_input, external_glv_provided, dedup_active);
1358 round_parallel_detail::MsmArena arena(arena_total_bytes, external_arena);
1359
1360 // ---------------------------------------------------------------------------------------
1361 // Phase 1 — convert scalars from Montgomery, optionally GLV-split, populate msb buffer.
1362 // The msb_per_scalar buffer feeds max-msb num_windows selection;
1363 // per-thread msb_hist counts (bin 0 = zero, bin k+1 = msb == k) feed the n_active gate
1364 // and the active-scalar gate.
1365 //
1366 // When dedup is active the per-scalar dedup work (hash + linear-probe shared atomic
1367 // table, per-thread dup_pair recording) is fused into the same per-thread loop so
1368 // scalars stay hot in L1 between from-Mont and the hash. The post-pass (sort, cluster
1369 // build, chunked tree-reduce, redirect_lookup) runs sequentially after the parallel_for
1370 // — see `dedup_finalize_parallel`.
1371 // ---------------------------------------------------------------------------------------
1372 using round_parallel_detail::MSB_ZERO_SENTINEL;
1373 const size_t profile_threads = std::max<size_t>(1, bb::get_num_cpus());
1374 auto msb_per_scalar = arena.template alloc<uint8_t>(n);
1375 auto per_thread_msb_hist = arena.template alloc<std::array<uint32_t, 256>>(profile_threads);
1376 // MsmArena::alloc returns uninitialised memory; the histograms must be zero-initialised so
1377 // record_msb's increments land on a clean slate.
1378 std::fill_n(per_thread_msb_hist.data(), profile_threads, std::array<uint32_t, 256>{});
1379
1380 // GLV storage (optional). `glv_scalars_storage` is the GLV-split working scalar buffer;
1381 // `glv_points_storage` is the inline-doubled point buffer (skipped when the caller
1382 // supplied an external doubled buffer). Both span empty when `use_glv` is false.
1383 std::span<ScalarField> glv_scalars_storage;
1384 std::span<AffineElement> glv_points_storage;
1385 if (use_glv) {
1386 glv_scalars_storage = arena.template alloc<ScalarField>(n);
1387 if (inline_glv_double) {
1388 glv_points_storage = arena.template alloc<AffineElement>(n);
1389 } else {
1390 BB_ASSERT_EQ(external_glv_doubled.size(), n);
1391 }
1392 }
1393
1394 if (use_glv) {
1395 // Convert each input scalar from-Mont into a stack local, GLV-split it, store both
1396 // 128-bit halves and their msb into the profile buffer. input_scalars is read-only on
1397 // this path so the user's buffer is preserved (no Montgomery restore needed). Inline
1398 // path additionally GLV-doubles the points in the same parallel pass; external path
1399 // aliases the caller-supplied doubled buffer.
1400 const BaseField beta = inline_glv_double ? BaseField::cube_root_of_unity() : BaseField{};
1401 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
1402 auto& th_hist = per_thread_msb_hist[chunk.thread_index];
1403 for (size_t i : chunk.range(n_input)) {
1404 const ScalarField canonical = input_scalars[i].from_montgomery_form_reduced();
1405 const auto split = ScalarField::split_into_endomorphism_scalars(canonical);
1406 const auto& k1 = split.first;
1407 const auto& k2 = split.second;
1408 glv_scalars_storage[2 * i].data[0] = k1[0];
1409 glv_scalars_storage[2 * i].data[1] = k1[1];
1410 glv_scalars_storage[2 * i].data[2] = 0;
1411 glv_scalars_storage[2 * i].data[3] = 0;
1412 glv_scalars_storage[(2 * i) + 1].data[0] = k2[0];
1413 glv_scalars_storage[(2 * i) + 1].data[1] = k2[1];
1414 glv_scalars_storage[(2 * i) + 1].data[2] = 0;
1415 glv_scalars_storage[(2 * i) + 1].data[3] = 0;
1416 if (inline_glv_double) {
1417 glv_points_storage[2 * i] = input_points[i];
1418 glv_points_storage[(2 * i) + 1].x = input_points[i].x * beta;
1419 glv_points_storage[(2 * i) + 1].y = -input_points[i].y;
1420 }
1421 round_parallel_detail::record_msb(
1422 round_parallel_detail::msb_of_2limb(k1[0], k1[1]), msb_per_scalar[2 * i], th_hist);
1423 round_parallel_detail::record_msb(
1424 round_parallel_detail::msb_of_2limb(k2[0], k2[1]), msb_per_scalar[(2 * i) + 1], th_hist);
1425 }
1426 });
1427 points =
1428 inline_glv_double ? std::span<const AffineElement>(glv_points_storage.data(), n) : external_glv_doubled;
1429 scalars = glv_scalars_storage;
1430 } else {
1431 // Non-GLV path: in-place from-Mont (later restored in the Stage-7 epilogue).
1432 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
1433 auto& th_hist = per_thread_msb_hist[chunk.thread_index];
1434 for (size_t i : chunk.range(n_input)) {
1435 input_scalars[i].self_from_montgomery_form_reduced();
1436 round_parallel_detail::record_msb(
1437 round_parallel_detail::msb_of_4limb(input_scalars[i].data), msb_per_scalar[i], th_hist);
1438 }
1439 });
1440 scalars = input_scalars;
1441 points = input_points;
1442 }
1443
1444 std::array<uint64_t, 256> msb_hist{};
1445 for (size_t t = 0; t < profile_threads; ++t) {
1446 for (size_t b = 0; b < 256; ++b) {
1447 msb_hist[b] += per_thread_msb_hist[t][b];
1448 }
1449 }
1450 const size_t n_active_early = n - static_cast<size_t>(msb_hist[0]);
1451
1452 // ---------------------------------------------------------------------------------------
1453 // Phase 2 — bail to trivial_msm_threaded when n_active is too small to amortise pippenger_fast's
1454 // per-window scaffolding. trivial_msm_threaded -> straus_msm wants Montgomery scalars, so
1455 // re-Mont-form them in parallel before dispatching.
1456 // ---------------------------------------------------------------------------------------
1457 {
1458 const size_t max_threads_dispatch = bb::get_num_cpus();
1459 const size_t threads_for_dispatch = std::max<size_t>(1, std::min(n_active_early, max_threads_dispatch));
1460 const size_t pts_per_thread = (n_active_early + threads_for_dispatch - 1) / threads_for_dispatch;
1461 if (pts_per_thread < MIN_PTS_PER_THREAD_FOR_PIPPENGER) {
1462 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
1463 for (size_t i : chunk.range(n)) {
1464 scalars[i].self_to_montgomery_form();
1465 }
1466 });
1467 std::span<const ScalarField> scalars_const(scalars.data(), n);
1468 PolynomialSpan<const ScalarField> ps(0, scalars_const);
1469 return trivial_msm_threaded<Curve>(ps, points);
1470 }
1471 }
1472
1473 // ---------------------------------------------------------------------------------------
1474 // Phase 3 — pick the window layout, build the schedule, run the pipeline, sum into the result.
1475 // ---------------------------------------------------------------------------------------
1476 const size_t num_logical_threads_for_c = bb::get_num_cpus() * window_bits_tuning_oversub_factor(n_input);
1477
1478 // Shrink the bit budget to the highest non-empty msb_hist bin so num_windows is determined
1479 // by the actual data, not the conservative GLV / FULL_NUM_BITS bound.
1480 size_t effective_num_bits = 0;
1481 for (size_t bin = 256; bin > 1;) {
1482 --bin;
1483 if (msb_hist[bin] != 0) {
1484 effective_num_bits = bin;
1485 break;
1486 }
1487 }
1488 if (effective_num_bits == 0 || effective_num_bits > NUM_BITS) {
1489 effective_num_bits = NUM_BITS;
1490 }
1491 const size_t window_bits =
1492 round_parallel_detail::choose_window_bits(n, effective_num_bits, n_input, num_logical_threads_for_c);
1493 const size_t num_buckets = (size_t{ 1 } << (window_bits - 1)) + 1;
1494
1495 // Schedule-based dedup state. The two arrays are allocated from the per-MSM_fast arena
1496 // after Phase 1 completes.
1497 // Until then, both spans are empty.
1498 // Lifetimes:
1499 // redirect_lookup — written by Phase A; read by Stage 4b's dedup_patch_schedule per batch
1500 // extra_points — written by Phase A; read by Stage 6a's reduce_chunk per batch
1501 // Both must survive until the last Stage 6a, so they sit in the arena (which is freed
1502 // when this function returns).
1504
1505 // Variable-window split was removed from the production path after Chonk traces showed
1506 // it regressing this rewrite. Keep the schedule uniform and run one region over all
1507 // non-zero scalars.
1508 const auto sched = round_parallel_detail::build_var_window_schedule(effective_num_bits, window_bits);
1509 BB_ASSERT_LTE(sched.num_windows,
1511 "window schedule exceeds compile-time max window count");
1512
1517
1518 // Thread count: aim for `lmul × physical_cpus` logical tasks so the rpmsm pool can
1519 // FIFO-balance heterogeneous P/E cores; cap at `n / MIN_BATCH_CAPACITY` so each chunk
1520 // can saturate the batched-affine drains. `bb::get_num_cpus() <= 1` is the chonk
1521 // batch-verifier's signal that outer parallelism owns all cores — run sequentially.
1522 const size_t desired_threads = std::max<size_t>(1, bb::get_num_cpus());
1523 const size_t max_threads_for_min_batch = std::max<size_t>(1, n / MIN_BATCH_CAPACITY);
1524 const size_t num_threads = std::min(desired_threads, max_threads_for_min_batch);
1525
1526 // Stage 6's tree-reduce splits each thread's chunk into sub-chunks of at most
1527 // SUBCHUNK_ENTRIES_CAP entries before calling reduce_chunk, bounding per-thread scratch
1528 // independent of n. 2048 keeps level-0 saturated (≥ 4 BATCH_CAPACITY drains at typical
1529 // c=16) while the deepest level still hits BATCH_AFFINE_BREAKEVEN (~32 pairs); halving
1530 // breaks the deep levels and doubling wastes memory.
1531 // Pick windows_in_batch so per-MSM_fast working set fits in ~32 MB. Empirically 32 MB
1532 // performs as well as 128 MB on the WASM grid (the recursive affine bucket reduction
1533 // recovers most of the small-batch loss).
1534 // The per_window_bytes / fixed_overhead formulas below mirror this enum of allocations
1535 // exactly. Anyone adding an arena buffer must update both the alloc and the corresponding
1536 // term in those formulas, otherwise windows_per_batch drifts off the BATCH_MEM_BUDGET.
1537
1538 // Per-(w, t) slot stride must fit the widest schedule window.
1539 size_t B_eff = num_buckets;
1540 for (size_t w = 0; w < sched.num_windows; ++w) {
1541 B_eff = std::max(B_eff, static_cast<size_t>(sched.num_buckets[w]));
1542 }
1543
1544 const size_t worker_total_for_budget = num_threads;
1545 const size_t dense_stride_est = round_parallel_detail::compute_dense_stride(B_eff, num_threads);
1546 const size_t bucket_partials_per_window_max =
1548 const size_t per_window_bytes_lo = round_parallel_detail::compute_per_window_bytes<Curve>(
1549 num_threads, B_eff, n, dense_stride_est, worker_total_for_budget);
1550
1551 const size_t global_max_overflow_per_window_for_budget =
1552 round_parallel_detail::compute_global_max_overflow_per_window(n, num_threads, SUBCHUNK_ENTRIES_CAP);
1553
1554 const size_t phase_one_prologue_bytes =
1555 round_parallel_detail::compute_phase_one_prologue_bytes(n, use_glv, inline_glv_double, profile_threads);
1556
1557 const auto phase_a_caps = round_parallel_detail::compute_phase_a_caps(n, num_threads);
1558 const size_t phase_a_cluster_members_cap = phase_a_caps.members_cap;
1559 const size_t phase_a_cluster_offsets_cap = phase_a_caps.offsets_cap;
1560
1561 // Zone W per-worker UNION via the canonical layout walk. The wpb-dependent Stage 6
1562 // tail is added separately after `windows_per_batch` is solved; here we only need
1563 // the union bytes for the fixed_overhead → wpb budget.
1565 /*chunk_capacity=*/SUBCHUNK_ENTRIES_CAP,
1566 global_max_overflow_per_window_for_budget,
1567 dedup_active,
1568 phase_a_cluster_members_cap,
1569 phase_a_cluster_offsets_cap,
1570 /*windows_per_batch=*/0,
1571 /*dense_stride_est=*/0);
1572 const size_t worker_union_bytes_for_budget = budget_layout.per_worker_union_bytes;
1573
1574 const size_t fixed_overhead = (worker_union_bytes_for_budget * worker_total_for_budget) +
1575 (size_t{ 96 } * round_parallel_detail::VAR_WINDOW_MAX_WINDOWS) // window_sums_storage
1576 + (size_t{ 8 } * (num_threads + 1)) // rebalanced_bucket_lo_partition
1577 + phase_one_prologue_bytes;
1578
1579 // Solve `wpb · per_window_bytes ≤ BATCH_MEM_BUDGET − fixed_overhead`.
1580 const size_t available_budget =
1581 (BATCH_MEM_BUDGET > fixed_overhead) ? (BATCH_MEM_BUDGET - fixed_overhead) : size_t{ 0 };
1582 const size_t windows_per_batch =
1583 round_parallel_detail::solve_wpb(per_window_bytes_lo, available_budget, sched.num_windows);
1584
1585 // Per-thread chunk-capacity scratch sizing. A thread's per-window slice is split into
1586 // sub-chunks of at most SUBCHUNK_ENTRIES_CAP entries. Worst-case overflow per
1587 // (thread, window) is one partial per sub-chunk boundary that lands mid-run, bounded
1588 // above by `ceil(max_chunk_len / SUBCHUNK_ENTRIES_CAP)` where max_chunk_len ≤ n/T.
1589 // The Stage 6a end-of-window overflow merge runs tree_reduce on `2 × overflow` entries
1590 // (each affected slot contributes a dense head + ≥1 overflow entry). Tree-reduce
1591 // scratch must fit either a sub-chunk's reduce_chunk input (up to SUBCHUNK_ENTRIES_CAP)
1592 // or a full overflow merge — take the max.
1593 const size_t global_max_chunk_len = (n + num_threads - 1) / num_threads;
1594 const size_t global_max_overflow_per_window =
1595 (global_max_chunk_len + SUBCHUNK_ENTRIES_CAP - 1) / SUBCHUNK_ENTRIES_CAP;
1596 const size_t chunk_capacity = std::max(SUBCHUNK_ENTRIES_CAP, 2 * global_max_overflow_per_window);
1597
1598 // Per-OS-thread scratch. The rpmsm pool dispatches `num_threads` logical tasks across
1599 // `worker_total = num_threads = physical_cpus` OS threads. Tasks on the same
1600 // OS thread run sequentially (FIFO claim), so they share scratch — every field in
1601 // ThreadScratch is overwritten fresh at task start, never read across tasks. Indexing
1602 // by `worker_id` (rather than `tid`) keeps memory linear in physical_cpus instead of
1603 // num_threads = lmul × physical_cpus.
1604 const size_t worker_total = num_threads;
1605 std::vector<round_parallel_detail::ThreadScratch<Curve>> thread_scratch(worker_total);
1607 if (dedup_active) {
1608 phase_a_scratch.resize(worker_total);
1609 }
1610
1611 // ---------------------------------------------------------------------------------------
1612 // Arena zone layout — set up after Phase 1 and schedule selection (see
1613 // https://gist.github.com/AztecBot/7c5ef0581350f6fdb9711679552fd86f §1, §4, §5).
1614 //
1615 // [0 .. bytes_P) Zone P — whole-MSM_fast permanent
1616 // msb_per_scalar (already alloc'd above)
1617 // glv_scalars / glv_points (already alloc'd above)
1618 // per_thread_msb_hist (already alloc'd above)
1619 // window_sums (Stage 7 accumulator)
1620 // redirect_lookup, extra_points (dedup, if active)
1621 // [bytes_P .. bytes_P + bytes_W) Zone W — per-worker union slab × T
1622 // Stage 6a/6b ThreadScratch fields and PhaseA
1623 // scratch overlay the same per-worker bytes; the
1624 // wpb-dependent Stage 6 fields sit immediately
1625 // after the union. Stage 6a, Stage 6b, and Phase A
1626 // run in distinct parallel_for invocations and
1627 // never co-exist on a worker.
1628 // [bytes_P + bytes_W .. arena.capacity)
1629 // Zone S — per-batch swing region (schedule, HIST slot,
1630 // DENSE slot, partition metadata).
1631 // HIST slot overlays H ↔ O on one byte slab:
1632 // H (S1-S4): digit_cursors
1633 // O (S6b-S7): chunk_outputs/window_partial_sums
1634 // Slot per-window = max(H, O). At chonk this is
1635 // H-bound (~256 KiB/window).
1636 // DENSE slot is dedicated for D (S6a-S6b):
1637 // bucket_partials_dense / _present
1638 // (~135 KiB/window at chonk). The D-class was
1639 // moved out of the HIST slot to eliminate L1
1640 // cache aliasing on the Stage 6a scatter writes
1641 // (+1.29% regression observed when D was overlaid
1642 // at the HIST offset).
1643 //
1644 // wpb solve: BATCH_MEM_BUDGET - bytes_P - bytes_W_fixed - bytes_S_shared - 32 KiB pad,
1645 // divided by (bytes_S_per_window + bytes_W_per_wpb). per_window_bytes_shared accounts
1646 // for HIST + DENSE as two separate slots.
1647 // ---------------------------------------------------------------------------------------
1648
1649 // Freeze Zone P prefix at the post-Phase-1 cursor — everything allocated so far
1650 // (msb_per_scalar, glv storage, per_thread_msb_hist) is Zone P permanent state.
1651 const size_t bytes_P_prefix = arena.cursor;
1652
1653 // Per-worker fixed-bytes "union": ThreadScratch's wpb-independent fields overlay the
1654 // PhaseAScratch fields. Compute each layout's strict byte requirement (including the
1655 // alignment slop a bump cursor would consume), then take the max.
// align_up(off, align): smallest multiple of `align` that is >= `off`. The add-then-mask
// form only yields that result when `align` is a power of two.
1656 auto align_up = [](size_t off, size_t align) -> size_t { return (off + align - 1) & ~(align - 1); };
// layout_add(off, bytes, align): layout-only bookkeeping — pad `off` up to `align`, then
// advance it past a `bytes`-sized field. `off` is updated in place; nothing is allocated.
1657 auto layout_add = [&](size_t& off, size_t bytes, size_t align) { off = align_up(off, align) + bytes; };
1658
1659 // Per-worker layout via the canonical walk (single source of truth shared with
1660 // `compute_arena_bytes_for_msm`). Pre-wpb-solve usage there passes wpb=0; here we
1661 // pass the actual windows_per_batch so the Stage 6 wpb-dependent tail is included.
1662 const round_parallel_detail::PerWorkerArenaLayout<Curve> worker_layout(chunk_capacity,
1663 global_max_overflow_per_window,
1664 dedup_active,
1665 phase_a_cluster_members_cap,
1666 phase_a_cluster_offsets_cap,
1667 windows_per_batch,
1668 dense_stride_est);
1670 const size_t per_worker_union_bytes = worker_layout.per_worker_union_bytes;
1671 const size_t per_worker_bytes = worker_layout.per_worker_bytes;
1672
1673 // Zone P extra (post-decision permanent state): window_sums + dedup state. Sized
1674 // with the strict alignment a bump cursor would apply.
1675 constexpr size_t VAR_WINDOW_WINDOW_SUMS_CAP = round_parallel_detail::VAR_WINDOW_MAX_WINDOWS;
1676 size_t bytes_P_extra_layout = 0;
1677 layout_add(bytes_P_extra_layout, sizeof(Element) * VAR_WINDOW_WINDOW_SUMS_CAP, alignof(Element));
1678 if (dedup_active) {
1679 layout_add(bytes_P_extra_layout, sizeof(uint32_t) * n, alignof(uint32_t));
1680 layout_add(bytes_P_extra_layout,
1681 sizeof(AffineElement) * round_parallel_detail::DEDUP_MAX_CLUSTERS,
1682 alignof(AffineElement));
1683 }
1684
1685 // Zone sizes. The Zone W slab uses `MsmArena::bump_alloc` which aligns in ABSOLUTE address
1686 // space (the arena buffer base is only `__STDCPP_DEFAULT_NEW_ALIGNMENT__`-aligned, but
1687 // AffineElement is alignas(64)). To make the per-worker layout match the layout-only
1688 // calc (which assumes the slab starts on a 64-byte boundary), bias bytes_P so the
1689 // absolute address `arena.data + bytes_P` is 64-aligned.
1690 const size_t arena_base_misalign = static_cast<size_t>(arena.base_addr & (WORKER_SLAB_ALIGN - 1));
1691 const size_t bytes_P_min = align_up(bytes_P_prefix, alignof(Element)) + bytes_P_extra_layout;
1692 const size_t bytes_P = align_up(bytes_P_min + arena_base_misalign, WORKER_SLAB_ALIGN) - arena_base_misalign;
1693 // bytes_W: per_worker_bytes is already rounded to WORKER_SLAB_ALIGN, so consecutive
1694 // slabs stay aligned once the first slab is aligned.
1695 const size_t bytes_W = per_worker_bytes * worker_total;
1696
1697 // Sanity: zones must fit. The conservative `compute_arena_bytes_for_msm` upper bound
1698 // sized the buffer to `BATCH_MEM_BUDGET + 32K + dedup_bytes` at worst, which dominates
1699 // every reachable (P + W + S) sum at the inline-tight wpb chosen above.
1700 BB_ASSERT_LTE(bytes_P + bytes_W, arena.capacity);
1701 const size_t bytes_S_total = arena.capacity - bytes_P - bytes_W;
1702
1703 // Per-zone bump cursors. Zone P continues from `bytes_P_prefix`; Zones W and S start
1704 // fresh at their zone base. Zone P's bound is `bytes_P` so the bump cursor stays inside
1705 // its slot even if the extra slabs consume a little alignment padding.
1706 size_t zone_P_cursor = bytes_P_prefix;
1707 size_t zone_S_cursor = 0;
// zone_P_alloc<T>(count): returns a span of `count` T's obtained from `arena.bump_alloc`,
// passing `zone_P_cursor` as the bump cursor and `bytes_P` as the bound. The trailing 0 is
// presumably the zone's base offset inside the arena — confirm against
// MsmArena::bump_alloc. Being a templated lambda, call sites name the element type
// explicitly: `zone_P_alloc.template operator()<T>(count)`.
1708 auto zone_P_alloc = [&]<typename T>(size_t count) -> std::span<T> {
1709 return arena.template bump_alloc<T>(count, zone_P_cursor, bytes_P, 0);
1710 };
// zone_S_alloc<T>(count): same shape as zone_P_alloc, but uses `zone_S_cursor` (which
// starts at 0), the bound `bytes_S_total`, and `bytes_P + bytes_W` as the last argument
// (presumably the Zone S base offset, i.e. the first byte after Zones P and W).
1711 auto zone_S_alloc = [&]<typename T>(size_t count) -> std::span<T> {
1712 return arena.template bump_alloc<T>(count, zone_S_cursor, bytes_S_total, bytes_P + bytes_W);
1713 };
1714 // Zone W is carved into per-worker slabs directly via `MsmArena::bump_alloc` below — each
1715 // worker gets its own (cursor, bound) pair, so a single zone-wide allocator would not
1716 // capture the per-worker discipline.
1717 // The pre-Phase-1 `MsmArena::alloc` cursor is retired here — every subsequent allocation
1718 // routes through `zone_P_alloc`, the per-worker Zone W allocators, or `zone_S_alloc`.
1719
1720 // Zone W: per-worker union slab — Stage6a/6b ThreadScratch and PhaseA fields overlay the
1721 // same per-worker bytes, with the wpb-dependent Stage 6 fields immediately after.
1722 for (size_t t = 0; t < worker_total; ++t) {
1723 // Each worker's slab is a contiguous `per_worker_bytes` window inside Zone W.
1724 const size_t slab_base = t * per_worker_bytes;
1725 auto& s = thread_scratch[t];
1726
1727 // ThreadScratch fixed fields — first view into the union. Bound = union size.
1728 size_t ts_fixed_cur = 0;
1729 auto ts_fixed_alloc = [&]<typename T>(size_t count) -> std::span<T> {
1730 return arena.template bump_alloc<T>(count, ts_fixed_cur, per_worker_union_bytes, bytes_P + slab_base);
1731 };
1732 s.curr_pts = ts_fixed_alloc.template operator()<AffineElement>(chunk_capacity);
1733 s.curr_buckets = ts_fixed_alloc.template operator()<uint32_t>(chunk_capacity);
1734 s.points_to_add = ts_fixed_alloc.template operator()<AffineElement>(2 * BATCH_CAPACITY);
1735 s.inversion_scratch = ts_fixed_alloc.template operator()<BaseField>(BATCH_CAPACITY);
1736 s.pair_dest = ts_fixed_alloc.template operator()<uint32_t>(BATCH_CAPACITY);
1737 s.overflow_slots = ts_fixed_alloc.template operator()<uint32_t>(global_max_overflow_per_window);
1738 s.overflow_pts = ts_fixed_alloc.template operator()<AffineElement>(global_max_overflow_per_window);
1739
1740 // PhaseA fields — second view, overlays the SAME per-worker union bytes. PhaseA's
1741 // parallel_for never overlaps Stage 6a/6b on the same worker, so reusing the bytes is
1742 // safe; the union's size is max(ts_fixed_layout, pa_layout) by construction.
1743 if (dedup_active) {
1744 size_t pa_cur = 0;
1745 auto pa_alloc = [&]<typename T>(size_t count) -> std::span<T> {
1746 return arena.template bump_alloc<T>(count, pa_cur, per_worker_union_bytes, bytes_P + slab_base);
1747 };
1748 auto& ps = phase_a_scratch[t];
1750 ps.cluster_members = pa_alloc.template operator()<uint32_t>(phase_a_cluster_members_cap);
1751 ps.cluster_offsets = pa_alloc.template operator()<uint32_t>(phase_a_cluster_offsets_cap);
1752 ps.dirty_slots = pa_alloc.template operator()<uint16_t>(PWAL::PHASE_A_DIRTY_SLOTS_CAP);
1753 ps.bucket_rep = pa_alloc.template operator()<uint32_t>(PWAL::PHASE_A_BUCKET_REP_CAP);
1754 ps.staged = pa_alloc.template operator()<std::pair<uint32_t, uint32_t>>(PWAL::PHASE_A_STAGED_CAP);
1755 ps.chunk_pts = pa_alloc.template operator()<AffineElement>(PWAL::PHASE_A_CHUNK_CAP);
1756 ps.chunk_ids = pa_alloc.template operator()<uint32_t>(PWAL::PHASE_A_CHUNK_CAP);
1757 }
1758
1759 // Stage 6 wpb-dependent fields — tail of the per-worker slab, BEYOND the union. Bound
1760 // = full per-worker slab size; cursor starts at per_worker_union_bytes so we don't
1761 // overwrite the union region.
1762 size_t ts_tail_cur = per_worker_union_bytes;
1763 auto ts_tail_alloc = [&]<typename T>(size_t count) -> std::span<T> {
1764 return arena.template bump_alloc<T>(count, ts_tail_cur, per_worker_bytes, bytes_P + slab_base);
1765 };
1766 const size_t dense_total = windows_per_batch * dense_stride_est;
1767 const size_t dense_pair_max = dense_total / 2;
1768 s.dense_buckets = ts_tail_alloc.template operator()<AffineElement>(dense_total);
1769 s.is_present = ts_tail_alloc.template operator()<uint8_t>(dense_total);
1770 s.affine_bucket_pairs = ts_tail_alloc.template operator()<std::pair<uint32_t, uint32_t>>(dense_pair_max);
1771 s.affine_bucket_indices = ts_tail_alloc.template operator()<uint32_t>(dense_pair_max);
1772 s.affine_bucket_inversion_scratch = ts_tail_alloc.template operator()<BaseField>(dense_pair_max);
1773 s.chunk_infos =
1774 ts_tail_alloc.template operator()<round_parallel_detail::AffineBucketChunkInfo>(windows_per_batch);
1775 std::fill_n(s.chunk_infos.begin(), windows_per_batch, round_parallel_detail::AffineBucketChunkInfo{});
1776 s.affine_bucket_stride = dense_stride_est;
1777 }
1778
1779 // Zone S: per-batch swing region — schedule + HIST slot + DENSE slot + partition metadata.
1780 const size_t schedule_total = windows_per_batch * n;
1781 auto schedule = zone_S_alloc.template operator()<uint32_t>(schedule_total);
1782
1783 // ----- HIST slot ------------------------------------------------------------------
1784 // Single byte slab backing two non-coexisting lifetime classes:
1785 // Epoch H (Stages 1-4): digit_cursors.
1786 // Epoch O (Stages 6b-7): chunk_outputs, window_partial_sums.
1787 // H dies before O is born (Stage 4 cursor advance ends before Stage 6b first writes
1788 // chunk_outputs / window_partial_sums).
1789 //
1790 // D-class (bucket_partials_dense + bucket_partials_present) previously overlaid this
1791 // slot too, but a 10× interleaved WASM Chonk bench showed Stage 6a regressed +1.29%
1792 // (t=+58) because of L1 cache aliasing on the `dense[slot]/present[slot]` scatter
1793 // writes when D sat at the HIST-overlaid offset. D-class now has its own dedicated
1794 // Zone-S DENSE slot below — see "DENSE slot" comment block.
1795 //
1796 // Phase 4: `digit_cursors` is dual-role within epoch H. After Stage 1 it holds
1797 // per-(w, t) counts of digit d; Stage 2 walks each (w, d) column from t = 0..T-1
1798 // reading the count from slot k and writing back the exclusive prefix-sum offset
1799 // (the count is consumed into `running` BEFORE the slot is overwritten, so the
1800 // in-place transform is mathematically identical to the previous out-of-place
1801 // version). Stage 4 then advances each (w, t) slice as a per-thread cursor.
1802 // Strict aliasing: every access goes through a std::span<T> obtained by
1803 // reinterpret_cast<T*>(hist_slot.data() + offset)
1804 // which is well-defined because std::byte is allowed by [basic.lval] to alias any
1805 // POD type. All overlaid types (uint32_t, size_t, Element, ChunkOutput<Curve>) are
1806 // trivially copyable / standard layout so the two epochs do not require construction
1807 // or destruction calls when the role of the bytes changes.
1808 static_assert(alignof(Element) <= 32, "HIST slot O layout assumes alignof(Element) <= 32");
1809 static_assert(alignof(round_parallel_detail::ChunkOutput<Curve>) <= 32,
1810 "HIST slot O layout assumes alignof(ChunkOutput) <= 32");
1811
// align_up_local(off, a): round `off` up to the next multiple of `a`; the mask form
// requires `a` to be a power of two. Used below to place the O-epoch views inside the
// HIST slot.
// NOTE(review): same formula as the `align_up` lambda defined earlier in this function —
// the two look mergeable; confirm both live in the same scope before doing so.
1812 auto align_up_local = [](size_t off, size_t a) -> size_t { return (off + a - 1) & ~(a - 1); };
1813
1814 // Exact byte requirements for each epoch (matches the budget formula above).
1815 const size_t hist_h_bytes_total = (size_t{ 4 } * windows_per_batch * num_threads * B_eff); // digit_cursors
1816
1817 // O epoch layout — chunk_outputs first, then window_partial_sums. Both are alignof
1818 // <= 32; align each up to its own alignment.
1819 size_t o_layout_cur = 0;
1820 o_layout_cur = align_up_local(o_layout_cur, alignof(round_parallel_detail::ChunkOutput<Curve>));
1821 const size_t off_chunk_outputs = o_layout_cur;
1822 o_layout_cur += sizeof(round_parallel_detail::ChunkOutput<Curve>) * windows_per_batch * num_threads;
1823 o_layout_cur = align_up_local(o_layout_cur, alignof(typename Curve::Element));
1824 const size_t off_window_partial_sums = o_layout_cur;
1825 o_layout_cur += sizeof(typename Curve::Element) * num_threads * windows_per_batch;
1826 const size_t hist_o_bytes_total = o_layout_cur;
1827
1828 const size_t hist_slot_bytes_total = std::max(hist_h_bytes_total, hist_o_bytes_total);
1829 // Round up to AffineElement size so the bump allocator below treats the slot as a
1830 // whole number of 64-byte alignas(64) cells. Allocate via AffineElement to force the
1831 // slot base to be 64-byte aligned in absolute address space — sufficient for the
1832 // H-epoch uint32 digit_cursors span (alignof 4) and the O-epoch ChunkOutput/Element
1833 // spans (alignof ≤ 32).
1834 const size_t hist_slot_cells = (hist_slot_bytes_total + sizeof(AffineElement) - 1) / sizeof(AffineElement);
1835 auto hist_slot_cells_span = zone_S_alloc.template operator()<AffineElement>(hist_slot_cells);
1836 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1837 std::byte* const hist_slot_bytes = reinterpret_cast<std::byte*>(hist_slot_cells_span.data());
1838
1839 // H-epoch view — live S1..S4. `digit_cursors[(w*T + t) * stride + d]` holds three
1840 // distinct meanings depending on stage:
1841 // * After Stage 1: per-(w, t) count of digit d's occurrences in thread t's slice.
1842 // * After Stage 2: per-(w, t) exclusive prefix-sum offset (cursor base) for the
1843 // bucket-d run inside that window's schedule slot.
1844 // * After Stage 4: offset + count (final cursor end-state); dead from then on.
1845 // Stage 2 reads each (w, t, d) count from this buffer and writes the running prefix
1846 // sum back to the SAME slot before advancing `running`, so the count is preserved
1847 // long enough to feed the accumulator. Stage 4's `++` post-increment on each
1848 // thread's slice runs without atomics because each thread owns its (w, t, *) row
1849 // exclusively.
1850 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1851 auto digit_cursors =
1852 std::span<uint32_t>{ reinterpret_cast<uint32_t*>(hist_slot_bytes), windows_per_batch * num_threads * B_eff };
1853
1854 // O-epoch views — live S6b..S7. Backed by the SAME bytes as above; H contents are
1855 // dead by the time these are touched. ChunkOutput<Curve> and Curve::Element have
1856 // user-defined constructors so are not formally trivially_copyable, but they are
1857 // standard-layout PODs of fixed bytes (Element is alignas(32) over a fixed-width Fq
1858 // field array). The existing arena pre-Phase-3 already aliases them through std::byte
1859 // buffers via `make_unique_for_overwrite<std::byte[]>` + reinterpret_cast; the
1860 // std::byte aliasing rule in [basic.lval] applies regardless of trivial-copyability.
1862 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1863 reinterpret_cast<round_parallel_detail::ChunkOutput<Curve>*>(hist_slot_bytes + off_chunk_outputs),
1864 windows_per_batch * num_threads
1865 };
1866 auto window_partial_sums = std::span<typename Curve::Element>{
1867 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1868 reinterpret_cast<typename Curve::Element*>(hist_slot_bytes + off_window_partial_sums),
1869 num_threads * windows_per_batch
1870 };
1871 // window_partial_sums is reset to identity at the start of each Stage 6b worker
1872 // (`my_partials[w] = point_at_infinity` loop), so we deliberately do NOT initialise
1873 // it here. chunk_outputs is written unconditionally per (w, tprime) in Stage 6b
1874 // (the empty path sets `out.empty = 1`), so no pre-init is needed either.
1875 // ----- end HIST slot --------------------------------------------------------------
1876
1877 // ----- DENSE slot -----------------------------------------------------------------
1878 // Dedicated Zone-S slot for D-class (bucket_partials_dense + bucket_partials_present).
1879 // Lifetime is Stages 6a-6b only. Isolated from the HIST slot so Stage 6a's tight
1880 // scatter loop
1881 // `dst_dense[slot] = pt; dst_present[slot] = 1;`
1882 // does not L1-alias against the HIST slot's H/O bytes (the previous co-located
1883 // layout caused a +1.29% Stage 6a regression in WASM, t=+58 across 10× interleaved
1884 // runs). The dense ↔ present pair stays packed at fixed aligned offsets within this
1885 // slot — they MUST stay close because Stage 6a reads `present[slot]` then writes
1886 // `dense[slot]` / `present[slot]` in tandem in the inner loop.
1887 static_assert(alignof(AffineElement) == 64, "DENSE slot D layout assumes alignof(AffineElement) == 64");
1888 const size_t bp_total = windows_per_batch * bucket_partials_per_window_max;
1889 size_t d_layout_cur = 0;
1890 const size_t off_dense = d_layout_cur;
1891 d_layout_cur += sizeof(AffineElement) * bp_total; // bucket_partials_dense
1892 const size_t off_present = d_layout_cur;
1893 d_layout_cur += sizeof(uint8_t) * bp_total; // bucket_partials_present
1894 const size_t dense_slot_bytes_total = d_layout_cur;
1895 const size_t dense_slot_cells = (dense_slot_bytes_total + sizeof(AffineElement) - 1) / sizeof(AffineElement);
1896 // Allocate via AffineElement to force 64-byte alignment for the leading
1897 // bucket_partials_dense view.
1898 auto dense_slot_cells_span = zone_S_alloc.template operator()<AffineElement>(dense_slot_cells);
1899 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1900 std::byte* const dense_slot_bytes = reinterpret_cast<std::byte*>(dense_slot_cells_span.data());
1901
1902 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1903 auto bucket_partials_dense =
1904 std::span<AffineElement>{ reinterpret_cast<AffineElement*>(dense_slot_bytes + off_dense), bp_total };
1905 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1906 auto bucket_partials_present =
1907 std::span<uint8_t>{ reinterpret_cast<uint8_t*>(dense_slot_bytes + off_present), bp_total };
1908 // ----- end DENSE slot -------------------------------------------------------------
1909
1910 auto bucket_start_all = zone_S_alloc.template operator()<size_t>(windows_per_batch * (B_eff + 1));
1911 auto chunk_start_all = zone_S_alloc.template operator()<size_t>(windows_per_batch * (num_threads + 1));
1912 // chunk_bucket_lo_all[w*(T+1) + t] = bucket index of the first schedule entry in
1913 // chunk t of window w.
1914 // chunk_bucket_hi_all[w*T + t] = bucket index of the last schedule entry in chunk t.
1915 // Chunks are partitioned by schedule index (uniform t·m/T), not by bucket boundary, so
1916 // a bucket's run can straddle threads — both threads then carry a partial for that
1917 // shared bucket and Stage 7's chunk_contribution sum (Σ_d d · partial_d_in_t over t)
1918 // combines them without an explicit merge step.
1919 auto chunk_bucket_lo_all = zone_S_alloc.template operator()<size_t>(windows_per_batch * (num_threads + 1));
1920 auto chunk_bucket_hi_all = zone_S_alloc.template operator()<size_t>(windows_per_batch * num_threads);
1921
1922 // bucket_partials_offsets is the index table that maps (thread, window) -> slot
1923 // start in bucket_partials_dense/present. Lives S5..S6b alongside chunk_start_all,
1924 // and stays as its own Zone S allocation (separate from the DENSE slot).
1925 auto bucket_partials_offsets = zone_S_alloc.template operator()<size_t>((num_threads * windows_per_batch) + 1);
1926
1927 // Stage 6b rebalanced-task partition. The bucket range [1, num_buckets) is split evenly
1928 // across `num_threads` rebalanced tasks t'. The partition is uniform in num_buckets so
1929 // we store T+1 boundaries (not per-window). For each window we record the half-open
1930 // range of original threads whose chunk range intersects each task t' — usually 1-2
1931 // originals per task.
1932 auto rebalanced_bucket_lo_partition = zone_S_alloc.template operator()<size_t>(num_threads + 1);
1933 auto orig_thread_lo = zone_S_alloc.template operator()<size_t>(windows_per_batch * num_threads);
1934 auto orig_thread_hi = zone_S_alloc.template operator()<size_t>(windows_per_batch * num_threads);
1935
1936 // Zone P: window_sums (Stage 7 accumulator — survives the whole MSM_fast).
1937 auto window_sums = zone_P_alloc.template operator()<typename Curve::Element>(VAR_WINDOW_WINDOW_SUMS_CAP);
1938 std::fill_n(window_sums.begin(), VAR_WINDOW_WINDOW_SUMS_CAP, Curve::Group::point_at_infinity);
1939
1940 // Zone P: dedup state — written by Phase A and read through Stage 6a of every batch,
1941 // so it must outlive every batch.
1942 // - redirect_lookup: parallel-filled with DEDUP_INVALID_EXTRA below before Phase A reads it.
1943 // - extra_points: no init needed; Phase A writes per-thread cid ranges, and consumers
1944 // only read indices Phase A actually populated.
1945 if (dedup_active) {
1946 dedup_state.redirect_lookup = zone_P_alloc.template operator()<uint32_t>(n);
1947 dedup_state.extra_points =
1948 zone_P_alloc.template operator()<AffineElement>(round_parallel_detail::DEDUP_MAX_CLUSTERS);
1949 BB_BENCH_NAME("MSM_fast::dedup/redirect_invalid_fill");
1950 uint32_t* const rl = dedup_state.redirect_lookup.data();
1951 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
1952 for (size_t i : chunk.range(n)) {
1954 }
1955 });
1956 }
1957
1958 // BUCKET_MASK strips the sign bit off a packed (sign | bucket) digit produced by
1959 // get_constantine_packed_digit, leaving the unsigned bucket index.
1960 constexpr uint32_t BUCKET_MASK = (uint32_t{ 1 } << 31) - 1;
1961
1962 // Phase A runs at most once per MSM_fast (not per batch). Cluster membership is determined
1963 // by scalar value (memcmp) — independent of which window we walk — and bucket
1964 // adjacency holds in any window's sorted schedule because true duplicates land in the
1965 // same bucket of every window. So we Phase A on the very first batch's window-0
1966 // schedule, populate `dedup_state.{redirect_lookup, extra_points}` once, and reuse the
1967 // result for every subsequent batch.
1968 bool phase_a_done = false;
1969
1970 auto run_batch = [&](size_t batch_start, size_t windows_in_batch, size_t B_R) noexcept {
1971 // Per-(w, t) slot stride uses `B_eff` = max(num_buckets, B_lo, B_hi); each call
1972 // iterates only the region's first B_R entries. The arena was sized for B_eff per slot.
1973 const size_t bucket_stride = B_eff;
1974 // Per-window slice params. The final window can be narrower when the bit budget
1975 // does not divide evenly by the default window size; the Booth recoder must use
1976 // that narrower width or it encroaches on bits beyond the schedule.
1977 constexpr size_t SCALAR_UINT64_LIMBS = sizeof(ScalarField) / sizeof(uint64_t);
1984 std::array<uint8_t, 128> per_window_bits{};
1985 constexpr size_t SCALAR_U32_LIMBS = sizeof(ScalarField) / sizeof(uint32_t);
1986 for (size_t w = 0; w < windows_in_batch; ++w) {
1987 const size_t global_w = batch_start + w;
1988 const size_t window_bits_w = sched.window_bits_per_window[global_w];
1989 per_window_bits[w] = static_cast<uint8_t>(window_bits_w);
1991 sched.bit_base[global_w], window_bits_w, SCALAR_UINT64_LIMBS);
1993 sched.bit_base[global_w], window_bits_w, SCALAR_U32_LIMBS);
1994 slice_paths[w] = round_parallel_detail::classify_slice_path_u32(slice_params_u32[w]);
1995 const uint32_t lo_mask = slice_params_u32[w].lo_mask;
1996 const uint32_t hi_mask = slice_params_u32[w].hi_mask;
1997 const uint32_t val_mask = (uint32_t{ 1 } << static_cast<uint32_t>(window_bits_w)) - 1;
1998 lo_mask_vectors[w] = round_parallel_detail::SimdU32x4{ lo_mask, lo_mask, lo_mask, lo_mask };
1999 hi_mask_vectors[w] = round_parallel_detail::SimdU32x4{ hi_mask, hi_mask, hi_mask, hi_mask };
2000 val_mask_vectors[w] = round_parallel_detail::SimdU32x4{ val_mask, val_mask, val_mask, val_mask };
2001 }
2002
2003 constexpr size_t SIMD_BATCH = 64;
2004 static_assert(SIMD_BATCH % 4 == 0, "SIMD_BATCH must be divisible by 4");
2005 constexpr size_t LIMBS_PER_SCALAR = sizeof(ScalarField) / sizeof(uint32_t);
2006 const auto* scalars_u32 = reinterpret_cast<const uint32_t*>(scalars.data());
2008 auto fill_packed_digit_buffer = [&](size_t w, size_t i, uint32_t* packed_buf) noexcept {
2009 const auto& sp32 = slice_params_u32[w];
2010 const uint32_t window_bits_w = static_cast<uint32_t>(per_window_bits[w]);
2012 for (size_t k = 0; k < SIMD_BATCH; k += 4) {
2014 packed_buf + k,
2015 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2016 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2017 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2018 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2019 sp32.lo_limb,
2020 sp32.lo_off,
2021 lo_mask_vectors[w],
2022 one_v,
2023 val_mask_vectors[w],
2024 window_bits_w);
2025 }
2026 } else if (slice_paths[w] == round_parallel_detail::ConstantineSlicePath::Bottom) {
2027 for (size_t k = 0; k < SIMD_BATCH; k += 4) {
2029 packed_buf + k,
2030 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2031 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2032 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2033 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2034 sp32.hi_limb,
2035 sp32.lo_bits,
2036 hi_mask_vectors[w],
2037 one_v,
2038 val_mask_vectors[w],
2039 window_bits_w);
2040 }
2041 } else {
2042 for (size_t k = 0; k < SIMD_BATCH; k += 4) {
2044 packed_buf + k,
2045 scalars_u32 + ((i + k + 0) * LIMBS_PER_SCALAR),
2046 scalars_u32 + ((i + k + 1) * LIMBS_PER_SCALAR),
2047 scalars_u32 + ((i + k + 2) * LIMBS_PER_SCALAR),
2048 scalars_u32 + ((i + k + 3) * LIMBS_PER_SCALAR),
2049 sp32.lo_limb,
2050 sp32.hi_limb,
2051 sp32.lo_off,
2052 sp32.lo_bits,
2053 lo_mask_vectors[w],
2054 hi_mask_vectors[w],
2055 one_v,
2056 val_mask_vectors[w],
2057 window_bits_w);
2058 }
2059 }
2060 };
2061
2062 // Capture the dedup state before Stage 1. The first batch must build the ordinary
2063 // R14 schedule so Phase A can discover clusters, then patch+compact that batch.
2064 // Later batches can schedule cluster reps directly and omit non-reps up front.
2065 const bool phase_a_done_at_batch_start = phase_a_done;
2066 const bool dedup_known_for_batch =
2067 dedup_active && phase_a_done_at_batch_start && dedup_state.n_dedup_extras != 0;
2068
2069 // Stage 1 (digit extraction): per-thread per-window bucket histograms. Work is
2070 // scalar-blocked across the windows in this batch so scalars/msb/dedup metadata are
2071 // read once per block and reused while still hot.
2072 auto stage1_digit_extract = [&]<bool DedupKnown>(size_t tid) noexcept {
2073 [[maybe_unused]] const uint32_t* const rl_data = dedup_state.redirect_lookup.data();
2074 for (size_t w = 0; w < windows_in_batch; ++w) {
2075 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2076 std::memset(my_counts, 0, B_R * sizeof(uint32_t));
2077 }
2078 const size_t start = tid * n / num_threads;
2079 const size_t end = (tid + 1) * n / num_threads;
2080
2081 alignas(16) std::array<uint32_t, SIMD_BATCH> packed_buf{};
2082 // Pack the per-block filter into a uint64 bitmask. When every scalar in the block
2083 // is active (common in dense workloads), the inner scatter takes an all_included
2084 // fast path that drops the per-element predicate; mixed blocks bit-scan the mask.
2085 auto compute_include_mask = [&](size_t block_start) noexcept -> uint64_t {
2086 uint64_t include_mask = 0;
2087 for (size_t k = 0; k < SIMD_BATCH; ++k) {
2088 const size_t scalar_idx = block_start + k;
2089 const uint8_t m = msb_per_scalar[scalar_idx];
2090 bool include = (m != MSB_ZERO_SENTINEL);
2091 if constexpr (DedupKnown) {
2092 if (include) {
2093 const uint32_t patch = rl_data[scalar_idx];
2094 include = (patch == round_parallel_detail::DEDUP_INVALID_EXTRA ||
2096 }
2097 }
2098 include_mask |= static_cast<uint64_t>(include) << k;
2099 }
2100 return include_mask;
2101 };
2102
2103 size_t i = start;
2104 while (i + SIMD_BATCH <= end) {
2105 const uint64_t include_mask = compute_include_mask(i);
2106 if (include_mask == 0) {
2107 i += SIMD_BATCH;
2108 continue;
2109 }
2110 const bool all_included = include_mask == ~uint64_t{ 0 };
2111 for (size_t w = 0; w < windows_in_batch; ++w) {
2112 fill_packed_digit_buffer(w, i, packed_buf.data());
2113 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2114 if (all_included) {
2115 for (size_t k = 0; k < SIMD_BATCH; ++k) {
2116 ++my_counts[packed_buf[k] & BUCKET_MASK];
2117 }
2118 } else {
2119 uint64_t scatter_mask = include_mask;
2120 for (size_t k = 0; k < SIMD_BATCH; ++k) {
2121 if ((scatter_mask & uint64_t{ 1 }) != 0) {
2122 ++my_counts[packed_buf[k] & BUCKET_MASK];
2123 }
2124 scatter_mask >>= 1;
2125 }
2126 }
2127 }
2128 i += SIMD_BATCH;
2129 }
2130
2131 // Tail (0..SIMD_BATCH-1 scalars). Same scalar-major loop order; per-scalar
2132 // active check inlined since the block is short.
2133 for (; i < end; ++i) {
2134 const uint8_t m = msb_per_scalar[i];
2135 if (m == MSB_ZERO_SENTINEL) {
2136 continue;
2137 }
2138 if constexpr (DedupKnown) {
2139 const uint32_t patch = rl_data[i];
2142 continue;
2143 }
2144 }
2145 for (size_t w = 0; w < windows_in_batch; ++w) {
2146 uint32_t* my_counts = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2147 const round_parallel_detail::ConstantineSliceParams sp = slice_params[w];
2148 const uint32_t window_bits_w = static_cast<uint32_t>(per_window_bits[w]);
2149 const uint32_t packed =
2151 sp.lo_limb,
2152 sp.hi_limb,
2153 sp.lo_off,
2154 sp.lo_bits,
2155 sp.lo_mask,
2156 sp.hi_mask,
2158 window_bits_w);
2159 ++my_counts[packed & BUCKET_MASK];
2160 }
2161 }
2162 };
2163 if (dedup_known_for_batch) {
2164 bb::parallel_for(num_threads, [&](size_t tid) { stage1_digit_extract.template operator()<true>(tid); });
2165 } else {
2166 bb::parallel_for(num_threads, [&](size_t tid) { stage1_digit_extract.template operator()<false>(tid); });
2167 }
2168
2169 // Stage 2 (bucket histogram): per-window per-digit totals + per-thread within-digit
2170 // offsets. Parallelised over digit-chunks; each worker handles its slice of 2^window_bits
2171 // for all windows_in_batch windows. In-place exclusive prefix-sum: each slot
2172 // `digit_cursors[(w*T + t) * stride + d]` is read for its Stage 1 count and then
2173 // overwritten with the running prefix sum (== the cursor base Stage 4 needs). The
2174 // count must be read BEFORE the write or `running` would skip its contribution.
2175 // Phase 5: the per-digit total `running` is written directly into
2176 // `bucket_start_all[w][d+1]` (one cell past where Stage 3 will read), so Stage 3 can
2177 // prefix-sum in place without a separate `bucket_total_counts` buffer. The size_t
2178 // bucket_start cell widens the uint32_t total implicitly.
2179 bb::parallel_for(num_threads, [&](size_t tid) {
// Digit range owned by this task: [d_start, d_end). Consecutive tids tile [0, B_R)
// without overlap, so no two tasks touch the same digit column.
2180 const size_t d_start = tid * B_R / num_threads;
2181 const size_t d_end = (tid + 1) * B_R / num_threads;
2182 for (size_t w = 0; w < windows_in_batch; ++w) {
// Window w's bucket_start table; tables are laid out with a stride of bucket_stride + 1.
2183 size_t* const bucket_start_w = bucket_start_all.data() + (w * (bucket_stride + 1));
2184 for (size_t d = d_start; d < d_end; ++d) {
// Digit 0 is skipped: its per-thread counts are left as Stage 1 wrote them and no
// total is stored for it (Stage 3 sets bucket_start[1] to 0 itself).
2185 if (d == 0) {
2186 continue;
2187 }
// Walk the (w, d) column across threads t = 0..T-1, turning each count into the
// exclusive prefix sum of the counts of the threads before it.
2188 uint32_t running = 0;
2189 for (size_t t = 0; t < num_threads; ++t) {
2190 const size_t k = (((w * num_threads) + t) * bucket_stride) + d;
// Order matters: read the count, overwrite the slot with the offset, then accumulate.
2191 const uint32_t cnt = digit_cursors[k];
2192 digit_cursors[k] = running;
2193 running += cnt;
2194 }
// `running` is now digit d's total over all threads; store it one cell past d so
// Stage 3 can prefix-sum the table in place.
2195 bucket_start_w[d + 1] = running;
2196 }
2197 }
2198 });
2199
2200 // Stage 3 (bucket offsets / prefix sum): per-window serial prefix sum in place.
2201 // Stage 2 already deposited each digit's per-window total at bucket_start[d+1];
2202 // the loop accumulates the running prefix-sum without a separate counts buffer.
2203 {
2204 BB_BENCH_NAME("MSM_fast::Stage2_3_bucket_offsets");
2205 auto build_bucket_offsets_for_window = [&](size_t w) noexcept {
2206 size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2207 bucket_start[0] = 0;
2208 bucket_start[1] = 0;
2209 for (size_t d = 1; d < B_R; ++d) {
2210 bucket_start[d + 1] += bucket_start[d];
2211 }
2212 };
2213 const size_t offset_threads = std::min(num_threads, windows_in_batch);
2214 if (offset_threads <= 1) {
2215 for (size_t w = 0; w < windows_in_batch; ++w) {
2216 build_bucket_offsets_for_window(w);
2217 }
2218 } else {
2219 bb::parallel_for(offset_threads, [&](size_t tid) {
2220 for (size_t w = tid; w < windows_in_batch; w += offset_threads) {
2221 build_bucket_offsets_for_window(w);
2222 }
2223 });
2224 }
2225 }
2226
2227 // Stage 4 (digit scatter): scalar-cache-blocked, window-local scatter. Re-decodes each
2228 // (point, window) signed digit via the same Constantine carry-less recoder Stage 1 used.
2229 // Stage 4 stores only `sign | scalar_idx`; bucket magnitude is recovered later from
2230 // bucket_start ranges.
2231 // Stage 1 benefits from full scalar-major order because it only updates compact
2232 // per-window histograms. Stage 4 writes large bucket schedules, so full scalar-major
2233 // order opens too many cold write/cursor streams. Instead, process a scalar tile across
2234 // all windows: scalar/msb/dedup metadata are reused while the tile is cache-hot, but each
2235 // inner loop still scatters to one window's schedule at a time.
2236 //
2237 // First-batch Stage 4 is dedup-unaware: every scalar is emitted as
2238 // `sched_w[idx] = sign | scalar_idx`, then Phase A + patch/compact tags cluster
2239 // reps and removes non-reps. Later batches with known dedup state skip non-reps
2240 // here and emit redirect reps directly.
2241 // Splitting the dedup work out of this hot loop avoids a per-iteration
2242 // closure-indirection chain through `dedup_state.redirect_lookup[i]`
2243 // that the WASM JIT does not hoist (~13 ns/iter penalty observed).
2244 auto stage4_emit = [&]<bool DedupKnown>(size_t tid) noexcept {
2245 [[maybe_unused]] const uint32_t* const rl_data = dedup_state.redirect_lookup.data();
2246 const size_t start = tid * n / num_threads;
2247 const size_t end = (tid + 1) * n / num_threads;
2249 std::array<const size_t*, 128> bucket_starts{};
2250 std::array<uint32_t*, 128> schedules{};
2251 for (size_t w = 0; w < windows_in_batch; ++w) {
2252 cursors[w] = digit_cursors.data() + (((w * num_threads) + tid) * bucket_stride);
2253 bucket_starts[w] = bucket_start_all.data() + (w * (bucket_stride + 1));
2254 schedules[w] = schedule.data() + (w * n);
2255 }
2256
2257 alignas(16) std::array<uint32_t, SIMD_BATCH> packed_buf{};
2258 constexpr size_t STAGE4_SCALAR_TILE = 2048;
2260 [[maybe_unused]] std::array<uint32_t, STAGE4_SCALAR_TILE> out_base_tile{};
2261
2262 for (size_t tile_start = start; tile_start < end; tile_start += STAGE4_SCALAR_TILE) {
2263 const size_t tile_end = std::min(end, tile_start + STAGE4_SCALAR_TILE);
2264 const size_t tile_len = tile_end - tile_start;
2265 for (size_t j = 0; j < tile_len; ++j) {
2266 const size_t scalar_idx = tile_start + j;
2267 const uint8_t m = msb_per_scalar[scalar_idx];
2268 bool include = (m != MSB_ZERO_SENTINEL);
2269 if constexpr (DedupKnown) {
2270 uint32_t out_base = static_cast<uint32_t>(scalar_idx);
2271 if (include) {
2272 const uint32_t patch = rl_data[scalar_idx];
2274 include = (patch & round_parallel_detail::DEDUP_SKIP_BIT) == 0;
2275 out_base = patch;
2276 }
2277 }
2278 out_base_tile[j] = out_base;
2279 }
2280 active_tile[j] = static_cast<uint8_t>(include);
2281 }
2282
2283 for (size_t w = 0; w < windows_in_batch; ++w) {
2284 uint32_t* my_cursor = cursors[w];
2285 const size_t* bucket_start = bucket_starts[w];
2286 uint32_t* sched_w = schedules[w];
2287 size_t i = tile_start;
2288 while (i + SIMD_BATCH <= tile_end) {
2289 const size_t rel = i - tile_start;
2290 uint64_t include_mask = 0;
2291 for (size_t k = 0; k < SIMD_BATCH; ++k) {
2292 include_mask |= static_cast<uint64_t>(active_tile[rel + k]) << k;
2293 }
2294 if (include_mask == 0) {
2295 i += SIMD_BATCH;
2296 continue;
2297 }
2298 fill_packed_digit_buffer(w, i, packed_buf.data());
2299 uint64_t scatter_mask = include_mask;
2300 for (size_t k = 0; k < SIMD_BATCH; ++k) {
2301 if ((scatter_mask & uint64_t{ 1 }) != 0) {
2302 const uint32_t packed = packed_buf[k];
2303 const uint32_t bucket_idx = packed & BUCKET_MASK;
2304 if (bucket_idx != 0) {
2305 const uint32_t idx =
2306 static_cast<uint32_t>(bucket_start[bucket_idx]) + my_cursor[bucket_idx]++;
2307 uint32_t out = packed & round_parallel_detail::SCHEDULE_SIGN_BIT;
2308 if constexpr (DedupKnown) {
2309 out |= out_base_tile[rel + k];
2310 } else {
2311 out |= static_cast<uint32_t>(i + k);
2312 }
2313 sched_w[idx] = out;
2314 }
2315 }
2316 scatter_mask >>= 1;
2317 }
2318 i += SIMD_BATCH;
2319 }
2320 for (; i < tile_end; ++i) {
2321 const size_t rel = i - tile_start;
2322 if (active_tile[rel] == 0) {
2323 continue;
2324 }
2325 const round_parallel_detail::ConstantineSliceParams sp = slice_params[w];
2327 scalars[i].data,
2328 sp.lo_limb,
2329 sp.hi_limb,
2330 sp.lo_off,
2331 sp.lo_bits,
2332 sp.lo_mask,
2333 sp.hi_mask,
2335 static_cast<uint32_t>(per_window_bits[w]));
2336 const uint32_t bucket_idx = packed & BUCKET_MASK;
2337 if (bucket_idx != 0) {
2338 const uint32_t idx =
2339 static_cast<uint32_t>(bucket_start[bucket_idx]) + my_cursor[bucket_idx]++;
2340 uint32_t out = packed & round_parallel_detail::SCHEDULE_SIGN_BIT;
2341 if constexpr (DedupKnown) {
2342 out |= out_base_tile[rel];
2343 } else {
2344 out |= static_cast<uint32_t>(i);
2345 }
2346 sched_w[idx] = out;
2347 }
2348 }
2349 }
2350 }
2351 };
2352
2353 if (dedup_known_for_batch) {
2354 bb::parallel_for(num_threads, [&](size_t tid) { stage4_emit.template operator()<true>(tid); });
2355 } else {
2356 bb::parallel_for(num_threads, [&](size_t tid) { stage4_emit.template operator()<false>(tid); });
2357 }
2358
2359 // Phase A: schedule-based dedup detection on window 0. Each thread owns a
2360 // contiguous range of window 0's schedule. Detects duplicate clusters via
2361 // consecutive-pair check (same bucket + memcmp on full scalar value), tree-reduces
2362 // members into an aggregate, and publishes results into `dedup_state.extra_points`,
2363 // `dedup_state.redirect_lookup`, and zeroed `msb_per_scalar` entries for non-reps.
2364 // Per-thread cluster-id ranges keep writes disjoint — no atomics needed.
2365 // Phase A: schedule-based dedup detection. Runs at most ONCE per MSM_fast (gated on
2366 // `phase_a_done` from the enclosing function scope). Cluster membership is decided
2367 // by scalar value (memcmp), so any window's bucket-sorted schedule places duplicates
2368 // consecutively — Phase A on this first-batch's window-0 schedule produces the
2369 // correct redirect_lookup + extra_points for all subsequent batches. We deliberately
2370 // do not re-run Phase A per batch: the dedup_state is populated once and reused.
2371 if (dedup_active && windows_in_batch > 0 && !phase_a_done) {
2372 BB_BENCH_NAME("MSM_fast::PhaseA_dedup_detect");
2373 uint32_t* sched_w0 = schedule.data();
2374 // Pre-Phase-A bucket sort: Stage 4 emits each bucket's run in scalar-emit
2375 // order, so different-value scalars that happen to share a window-0 digit
2376 // (bucket collisions are common — c=11 → 2048 buckets vs 60-90k entries)
2377 // interleave with same-value entries and break Phase A's consecutive-pair
2378 // detection. Sorting each bucket's run by scalar value makes same-value
2379 // entries adjacent so the simple consecutive-pair walk finds every cluster.
2380 // Sort cost: per bucket of size K, ~K log K comparisons × 32-byte memcmp;
2381 // for typical K=44 this is ~500 cycles per bucket × 2048 buckets = ~1 ms
2382 // wall (parallelized across threads).
2383 const uint32_t cids_per_thread =
2384 static_cast<uint32_t>(round_parallel_detail::DEDUP_MAX_CLUSTERS / num_threads);
2385 // Hash-based per-bucket dedup detection: every thread owns a
2386 // contiguous bucket range of window-0's schedule and runs an
2387 // open-addressing hash table over that range's long-scalar entries.
2388 // O(K) per bucket, avoids the 32-byte memcmp comparator inside any
2389 // sort, and keeps thread balance uniform because short-scalar
2390 // entries (the source of mega-buckets like digit_0 = 1) are skipped.
2391 // Catches ~99.94 % of long-scalar duplicates against MSM_DUMP's
2392 // theoretical maximum (`dup_input_extras`).
2393 {
2394 BB_BENCH_NAME("MSM_fast::PhaseA_dedup_detect_hash");
2395 const size_t* const w0_bucket_start = bucket_start_all.data();
2396 std::atomic<size_t> dedup_cluster_count{ 0 };
2397 bb::parallel_for(num_threads, [&, w0_bucket_start](size_t tid) noexcept {
2398 const size_t b_lo = 1 + ((tid * (B_R - 1)) / num_threads);
2399 const size_t b_hi = 1 + (((tid + 1) * (B_R - 1)) / num_threads);
2400 const uint32_t cid_lo = static_cast<uint32_t>(tid) * cids_per_thread;
2401 const uint32_t cid_max = cid_lo + cids_per_thread;
2402 const size_t local_clusters = round_parallel_detail::dedup_phase_a_worker_hash<Curve>(
2403 sched_w0,
2404 w0_bucket_start,
2405 b_lo,
2406 b_hi,
2407 std::span<const ScalarField>(scalars.data(), n),
2408 points,
2410 std::span<uint32_t>(dedup_state.redirect_lookup),
2411 msb_per_scalar.data(),
2412 window_bits,
2413 cid_lo,
2414 cid_max,
2415 phase_a_scratch[tid]);
2416 if (local_clusters != 0) {
2417 dedup_cluster_count.fetch_add(local_clusters, std::memory_order_relaxed);
2418 }
2419 });
2420 dedup_state.n_dedup_extras = dedup_cluster_count.load(std::memory_order_relaxed);
2421 }
2422 phase_a_done = true;
2423 }
2424
2425 // Schedule patch post-pass: tags cluster-member entries with SKIP/REDIRECT bits.
2426 // Runs only for the batch that just ran Phase A: later batches with known dedup
2427 // state skip non-reps in Stage 1/4 and emit redirect reps directly.
2428 // Parallel by window (one window per worker) because each window's slice of the
2429 // schedule is disjoint. Hoisting `redirect_lookup.data()` to a raw pointer outside
2430 // the lambda + passing it by value into the inner function avoids the per-iter
2431 // closure-indirection chain that made the inline form 3× slower per iter on WASM.
2432 auto partition_chunks_for_window = [&](size_t w) noexcept {
2433 const size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2434 const size_t* const bucket_start_end = bucket_start + B_R + 1;
2435 size_t* chunk_start = chunk_start_all.data() + (w * (num_threads + 1));
2436 size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2437 size_t* chunk_bucket_hi = chunk_bucket_hi_all.data() + (w * num_threads);
2438 const size_t m = bucket_start[B_R];
2439 const size_t* search_begin = bucket_start + 1;
2440 size_t lo = 0;
2441 chunk_start[0] = lo;
2442 for (size_t t = 0; t < num_threads; ++t) {
2443 const size_t hi = ((t + 1) == num_threads) ? m : (((t + 1) * m) / num_threads);
2444 chunk_start[t + 1] = hi;
2445 if (lo < hi) {
2446 const size_t* const lo_it = std::upper_bound(search_begin, bucket_start_end, lo);
2447 const size_t lo_bucket = static_cast<size_t>(lo_it - bucket_start - 1);
2448 const size_t* const hi_it = std::upper_bound(lo_it, bucket_start_end, hi - 1);
2449 const size_t hi_bucket = static_cast<size_t>(hi_it - bucket_start - 1);
2450 chunk_bucket_lo[t] = lo_bucket;
2451 chunk_bucket_hi[t] = hi_bucket;
2452 search_begin = hi_it;
2453 } else {
2454 chunk_bucket_lo[t] = B_R;
2455 chunk_bucket_hi[t] = 0;
2456 }
2457 lo = hi;
2458 }
2459 chunk_bucket_lo[num_threads] = B_R;
2460 };
2461
2462 bool chunk_partition_done = false;
2463 if (dedup_active && windows_in_batch > 0 && phase_a_done && !phase_a_done_at_batch_start) {
2464 BB_BENCH_NAME("MSM_fast::dedup_patch_schedule");
2465 const uint32_t* const rl_data = dedup_state.redirect_lookup.data();
2466 const size_t bs_stride = bucket_stride + 1;
2467 const size_t br = B_R;
2468 const size_t cap_R = n;
2469 bb::parallel_for(num_threads, [&, rl_data, bs_stride, br, cap_R](size_t tid) noexcept {
2470 for (size_t w = tid; w < windows_in_batch; w += num_threads) {
2471 uint32_t* sched_w = schedule.data() + (w * cap_R);
2472 size_t* bucket_start_w = bucket_start_all.data() + (w * bs_stride);
2473 round_parallel_detail::dedup_patch_schedule_window<Curve>(sched_w, bucket_start_w, br, rl_data);
2474 partition_chunks_for_window(w);
2475 }
2476 });
2477 chunk_partition_done = true;
2478 }
2479
2480 // Per-window chunk partition at schedule-index granularity (chunk_start[t] = t·m/T).
2481 // Balances across threads regardless of bucket-distribution skew. When the partition
2482 // lands mid-bucket, both adjacent threads build their own partial into the boundary
2483 // bucket; chunk_contribution combines them in Stage 7.
2484 {
2485 BB_BENCH_NAME("MSM_fast::Stage5_chunk_partition");
2486 if (!chunk_partition_done) {
2487 for (size_t w = 0; w < windows_in_batch; ++w) {
2488 partition_chunks_for_window(w);
2489 }
2490 }
2491 }
2492
2493 // Stage 6 bucket accumulation per thread:
2494 // (1) For each window w: reduce_chunk emits a digit-sorted (point, digit) list,
2495 // which we densify into a per-window dense bucket array at
2496 // tid's affine bucket buffer + w * stride. Empty slots stay identity.
2497 // (2) Call recursive_affine_bucket_reduce_strided once across all windows_in_batch
2498 // chunks; it computes (R_w, L_w) for each non-empty chunk via batch-affine
2499 // arithmetic, amortising the inversion across windows at every phase step.
2500 // (3) chunk_contribution(out) folds L_w + (lo_w-1)·R_w into the thread's per-window
2501 // partial.
2502 // The Stage-6 scratch is pre-sized for every thread BEFORE entering the parallel_for
2503 // so the per-thread vector resizes don't race the heap allocator.
2504 auto next_pow2 = [](size_t x) -> size_t {
2505 if (x <= 1) {
2506 return 1;
2507 }
2508 size_t p = 1;
2509 while (p < x) {
2510 p <<= 1;
2511 }
2512 return p;
2513 };
2514 // Drives reduce_chunk's per-thread tree-reduce buffer sizing.
2515 size_t max_chunk_len = 0;
2516 for (size_t t = 0; t < num_threads; ++t) {
2517 for (size_t w = 0; w < windows_in_batch; ++w) {
2518 const size_t* chunk_start = chunk_start_all.data() + (w * (num_threads + 1));
2519 const size_t entries_in_chunk = chunk_start[t + 1] - chunk_start[t];
2520 if (entries_in_chunk == 0) {
2521 continue;
2522 }
2523 max_chunk_len = std::max(max_chunk_len, entries_in_chunk);
2524 }
2525 }
2526
2527 // global_stride drives the per-thread `dense_buckets` layout (sized via
2528 // `ensure_affine_bucket_capacity` below). Stage 6a writes its per-thread bucket
2529 // partials into `bucket_partials_dense` (a separate buffer packed via
2530 // `bucket_partials_offsets`, no power-of-two stride); Stage 6b copies them into
2531 // `s.dense_buckets` keyed by Stage 6b's uniform bucket-index slice of width
2532 // `buckets_per_task ≈ ⌈(num_buckets-1)/T⌉`. The recursive bucket-reduction
2533 // algorithm (phases A-D) operates on `s.dense_buckets` with power-of-two row
2534 // stride — that's where `next_pow2` matters.
2535 size_t global_stride = 0;
2536
2537 {
2538 // Stage 6b's bucket-balanced partition. Uniform across windows: each rebalanced
2539 // task t' owns active digits [d_lo'[t'], d_hi'[t']] where d_lo'[t'] = 1 + t · (B-1) / T.
2540 const size_t active_digits = (B_R > 0) ? (B_R - 1) : 0;
2541 for (size_t t = 0; t <= num_threads; ++t) {
2542 rebalanced_bucket_lo_partition[t] = 1 + (t * active_digits) / num_threads;
2543 }
2544 rebalanced_bucket_lo_partition[num_threads] = B_R;
2545 size_t max_buckets_per_task = 0;
2546 for (size_t t = 0; t + 1 <= num_threads; ++t) {
2547 const size_t hi_d = (t + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[t + 1] - 1);
2548 const size_t lo_d = rebalanced_bucket_lo_partition[t];
2549 if (hi_d >= lo_d) {
2550 max_buckets_per_task = std::max(max_buckets_per_task, hi_d - lo_d + 1);
2551 }
2552 }
2553 global_stride = next_pow2(max_buckets_per_task);
2554 global_stride = std::max<size_t>(global_stride, 2);
2555
2556 // Per-window orig-thread contributing ranges (O(W·T·T) total — only paid for
2557 // the rebalance path, where T is small enough that this is sub-µs).
2558 for (size_t w = 0; w < windows_in_batch; ++w) {
2559 const size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2560 const size_t* chunk_bucket_hi = chunk_bucket_hi_all.data() + (w * num_threads);
2561 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2562 for (size_t tprime = 0; tprime < num_threads; ++tprime) {
2563 const size_t lo_d = rebalanced_bucket_lo_partition[tprime];
2564 const size_t hi_d =
2565 (tprime + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[tprime + 1] - 1);
2566 size_t lo_orig = num_threads;
2567 size_t hi_orig = 0;
2568 for (size_t t = 0; t < num_threads; ++t) {
2569 const size_t entries = chunk_start_w[t + 1] - chunk_start_w[t];
2570 if (entries == 0) {
2571 continue;
2572 }
2573 const size_t cl = chunk_bucket_lo[t];
2574 const size_t ch = chunk_bucket_hi[t];
2575 if (ch < lo_d || cl > hi_d) {
2576 continue;
2577 }
2578 if (lo_orig == num_threads) {
2579 lo_orig = t;
2580 }
2581 hi_orig = t;
2582 }
2583 orig_thread_lo[(w * num_threads) + tprime] = lo_orig;
2584 orig_thread_hi[(w * num_threads) + tprime] = hi_orig;
2585 }
2586 }
2587
2588 // bucket_partials_dense / _present packed via bucket_partials_offsets — each
2589 // (thread, window) row holds exactly buckets_per_thread[t][w] AffineElements (no
2590 // padding). The arena pre-sized to `windows_per_batch · (num_buckets - 1 + T)`
2591 // (covers the T-1 boundary-bucket shares); only the actual prefix is touched.
2592 size_t bucket_partials_cursor = 0;
2593 for (size_t t = 0; t < num_threads; ++t) {
2594 for (size_t w = 0; w < windows_in_batch; ++w) {
2595 bucket_partials_offsets[(t * windows_in_batch) + w] = bucket_partials_cursor;
2596 const size_t* chunk_bucket_lo_w = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2597 const size_t* chunk_bucket_hi_w = chunk_bucket_hi_all.data() + (w * num_threads);
2598 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2599 const size_t entries = chunk_start_w[t + 1] - chunk_start_w[t];
2600 if (entries > 0) {
2601 bucket_partials_cursor += chunk_bucket_hi_w[t] - chunk_bucket_lo_w[t] + 1;
2602 }
2603 }
2604 }
2605 bucket_partials_offsets[num_threads * windows_in_batch] = bucket_partials_cursor;
2606 const size_t bucket_partials_total = bucket_partials_cursor;
2607 BB_ASSERT_LTE(bucket_partials_total, bucket_partials_dense.size());
2608 std::memset(bucket_partials_present.data(), 0, bucket_partials_total);
2609 }
2610
2611 // thread_scratch is worker-indexed (one slot per OS thread, FIFO-shared by tasks);
2612 // update the stride on each worker's slot.
2613 for (size_t t = 0; t < worker_total; ++t) {
2614 thread_scratch[t].affine_bucket_stride = global_stride;
2615 }
2616
2617 {
2618 // Stage 6a — per-thread bucket partials. Each thread `tid` reduces its schedule
2619 // slice via reduce_chunk and scatters the (digit, point) output directly into the
2620 // per-thread dense bucket buffer at slot `(digit - chunk_bucket_lo[tid])`. Stage
2621 // 6b then reads this buffer with O(1) slot lookup. `bucket_partials_present` is
2622 // pre-zeroed per batch.
2623 auto bucket_partials_per_thread_lambda = [&](size_t tid) {
2624 auto& s = thread_scratch[tid];
2625 for (size_t w = 0; w < windows_in_batch; ++w) {
2626 const size_t* chunk_start_w = chunk_start_all.data() + (w * (num_threads + 1));
2627 const size_t cs_lo = chunk_start_w[tid];
2628 const size_t cs_hi = chunk_start_w[tid + 1];
2629 if (cs_lo == cs_hi) {
2630 continue;
2631 }
2632 const uint32_t* sched_w = schedule.data() + (w * n);
2633 const size_t* bucket_start = bucket_start_all.data() + (w * (bucket_stride + 1));
2634 AffineElement* dst_dense =
2635 bucket_partials_dense.data() + bucket_partials_offsets[(tid * windows_in_batch) + w];
2636 uint8_t* dst_present =
2637 bucket_partials_present.data() + bucket_partials_offsets[(tid * windows_in_batch) + w];
2638 const size_t* chunk_bucket_lo = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2639 const uint32_t my_lo = static_cast<uint32_t>(chunk_bucket_lo[tid]);
2640 const size_t my_hi = chunk_bucket_hi_all[(w * num_threads) + tid];
2641 size_t bucket_cursor = my_lo;
2642
2643 for (size_t pos = cs_lo; pos < cs_hi;) {
2644 const size_t end = std::min(pos + SUBCHUNK_ENTRIES_CAP, cs_hi);
2645 reduce_chunk<Curve>(s,
2646 sched_w,
2647 bucket_start,
2648 pos,
2649 end,
2650 bucket_cursor,
2651 my_hi,
2652 points,
2653 std::span<const AffineElement>(dedup_state.extra_points));
2654 const size_t len = s.result_len;
2655 for (size_t k = 0; k < len; ++k) {
2656 const uint32_t d = s.curr_buckets[k];
2657 const size_t slot = d - my_lo;
2658 if (dst_present[slot]) {
2659 s.overflow_slots[s.overflow_len] = static_cast<uint32_t>(slot);
2660 s.overflow_pts[s.overflow_len] = s.curr_pts[k];
2661 ++s.overflow_len;
2662 } else {
2663 dst_dense[slot] = s.curr_pts[k];
2664 dst_present[slot] = 1;
2665 }
2666 }
2667 pos = end;
2668 }
2669 merge_overflow<Curve>(s, dst_dense);
2670 }
2671 };
2672
2673 // Stage 6b (cross-thread bucket reduction): each rebalanced task `tprime` owns a
2674 // uniform-width slice of the bucket-index space [d_lo'(tprime), d_hi'(tprime)].
2675 // For each window in the batch, walk the contributing original threads' Stage 6a
2676 // dense outputs (range [orig_thread_lo, orig_thread_hi]), filter to digits in
2677 // this task's slice, scatter into the task's local dense_buckets (with
2678 // projective-add accumulation on the at-most-2 boundary digits per pair of
2679 // contributing originals), then run recursive_affine_bucket_reduce_strided +
2680 // chunk_contribution on a guaranteed-equal buckets_padded across all tasks.
2681 auto bucket_reduce_cross_thread_lambda = [&](size_t tprime) {
2682 auto& s = thread_scratch[tprime];
2683 Element* my_partials = window_partial_sums.data() + (tprime * windows_per_batch);
2684 for (size_t w = 0; w < windows_in_batch; ++w) {
2685 my_partials[w] = Curve::Group::point_at_infinity;
2686 }
2687
2688 const size_t stride = s.affine_bucket_stride;
2689 std::memset(s.is_present.data(), 0, windows_in_batch * stride);
2690
2691 const size_t lo_d = rebalanced_bucket_lo_partition[tprime];
2692 const size_t hi_d =
2693 (tprime + 1 == num_threads) ? (B_R - 1) : (rebalanced_bucket_lo_partition[tprime + 1] - 1);
2694 const uint32_t lo_d_u = static_cast<uint32_t>(lo_d);
2695 const uint32_t hi_d_u = static_cast<uint32_t>(hi_d);
2696
2697 bool any_nonempty = false;
2698 for (size_t w = 0; w < windows_in_batch; ++w) {
2699 auto& info = s.chunk_infos[w];
2700 auto& out = chunk_outputs[(w * num_threads) + tprime];
2701 if (lo_d > hi_d) {
2702 info.empty = 1;
2703 info.lo = 0;
2704 info.hi = 0;
2705 info.buckets_padded = 0;
2706 out.empty = 1;
2707 continue;
2708 }
2709 const size_t orig_lo = orig_thread_lo[(w * num_threads) + tprime];
2710 const size_t orig_hi = orig_thread_hi[(w * num_threads) + tprime];
2711 if (orig_lo == num_threads) {
2712 info.empty = 1;
2713 info.lo = 0;
2714 info.hi = 0;
2715 info.buckets_padded = 0;
2716 out.empty = 1;
2717 continue;
2718 }
2719 const size_t base = w * stride;
2720 bool has_data = false;
2721
2722 // bucket_partials_dense holds per-(orig_t, w, slot) bucket points with
2723 // bucket_partials_present as the populated-slot bitmap. For each
2724 // contributing orig_t, intersect its [chunk_bucket_lo, chunk_bucket_hi]
2725 // range with this task's [lo_d, hi_d] slice and walk the intersection
2726 // only — no sorted scan, O(1) lookup per slot.
2727 const size_t* chunk_bucket_lo_w = chunk_bucket_lo_all.data() + (w * (num_threads + 1));
2728 const size_t* chunk_bucket_hi_w = chunk_bucket_hi_all.data() + (w * num_threads);
2729 for (size_t t = orig_lo; t <= orig_hi; ++t) {
2730 const size_t cl = chunk_bucket_lo_w[t];
2731 const size_t ch = chunk_bucket_hi_w[t];
2732 const size_t d_lo_clip = std::max<size_t>(lo_d, cl);
2733 const size_t d_hi_clip = std::min<size_t>(hi_d, ch);
2734 if (d_lo_clip > d_hi_clip) {
2735 continue;
2736 }
2737 const AffineElement* src_dense =
2738 bucket_partials_dense.data() + bucket_partials_offsets[(t * windows_in_batch) + w];
2739 const uint8_t* src_present =
2740 bucket_partials_present.data() + bucket_partials_offsets[(t * windows_in_batch) + w];
2741 for (size_t d = d_lo_clip; d <= d_hi_clip; ++d) {
2742 const size_t src_slot = d - cl;
2743 if (src_present[src_slot] == 0) {
2744 continue;
2745 }
2746 const size_t dst_slot = base + (d - lo_d);
2747 if (s.is_present[dst_slot] == 0) {
2748 s.dense_buckets[dst_slot] = src_dense[src_slot];
2749 s.is_present[dst_slot] = 1;
2750 } else {
2751 // Boundary digit shared between two consecutive originals
2752 // — projective add then re-normalise to affine. Under the
2753 // contiguous-by-schedule-index partition there are at most
2754 // W boundary points per task.
2755 Element acc = Element(s.dense_buckets[dst_slot]);
2756 acc += Element(src_dense[src_slot]);
2757 s.dense_buckets[dst_slot] = AffineElement(acc);
2758 }
2759 has_data = true;
2760 }
2761 }
2762 if (!has_data) {
2763 info.empty = 1;
2764 info.lo = 0;
2765 info.hi = 0;
2766 info.buckets_padded = 0;
2767 out.empty = 1;
2768 continue;
2769 }
2770 any_nonempty = true;
2771 const size_t M = hi_d - lo_d + 1;
2772 const uint32_t buckets_padded =
2773 (M == 1) ? 1 : (uint32_t{ 1 } << (32 - __builtin_clz(static_cast<uint32_t>(M - 1))));
2774 info.empty = 0;
2775 info.lo = lo_d_u;
2776 info.hi = hi_d_u;
2777 info.buckets_padded = buckets_padded;
2778 out.empty = 0;
2779 out.lo = lo_d_u;
2780 out.hi = hi_d_u;
2781 }
2782
2783 if (!any_nonempty) {
2784 return;
2785 }
2786
2787 round_parallel_detail::recursive_affine_bucket_reduce_strided<Curve>(
2788 s, s.chunk_infos.data(), windows_in_batch, chunk_outputs.data() + tprime, num_threads);
2789
2790 for (size_t w = 0; w < windows_in_batch; ++w) {
2791 auto& out = chunk_outputs[(w * num_threads) + tprime];
2792 if (out.empty == 0) {
2793 my_partials[w] = round_parallel_detail::chunk_contribution<Curve>(out);
2794 }
2795 }
2796 };
2797
2798 bb::parallel_for(num_threads, bucket_partials_per_thread_lambda);
2799 bb::parallel_for(num_threads, bucket_reduce_cross_thread_lambda);
2800 }
2801
2802 // Stage 7 (cross-window combine): per-window reduce of `num_threads` per-thread partials.
2803 // (Algebraic identity: `Σ_t (L_t + (lo_t − 1) · R_t) = window's bucket sum`,
2804 // with the per-chunk contributions already accumulated above.)
2805 {
2806 const size_t reduce_threads = std::min(num_threads, windows_in_batch);
2807 bb::parallel_for(reduce_threads, [&](size_t rid) {
2808 const size_t lo = rid * windows_in_batch / reduce_threads;
2809 const size_t hi = (rid + 1) * windows_in_batch / reduce_threads;
2810 for (size_t w = lo; w < hi; ++w) {
2811 Element sum = Curve::Group::point_at_infinity;
2812 for (size_t tid = 0; tid < num_threads; ++tid) {
2813 sum += window_partial_sums[(tid * windows_per_batch) + w];
2814 }
2815 window_sums[batch_start + w] = sum;
2816 }
2817 });
2818 }
2819 };
2820
2821 // Uniform-schedule dispatch over all windows.
2822 {
2823 const size_t B_R = (size_t{ 1 } << (window_bits - 1)) + 1;
2824 for (size_t batch_start = 0; batch_start < sched.num_windows; batch_start += windows_per_batch) {
2825 const size_t windows_in_batch = std::min(windows_per_batch, sched.num_windows - batch_start);
2826 run_batch(batch_start, windows_in_batch, B_R);
2827 }
2828 }
2829
2830 // Stage 7 horner: walk high-to-low, doubling by `window_bits_per_window[w]` between adjacent windows.
2831 // Init from the top window to skip a wasted doubling on identity.
2832 Element result = (sched.num_windows == 0) ? Curve::Group::point_at_infinity : window_sums[sched.num_windows - 1];
2833 for (size_t w_rev = sched.num_windows - 1; w_rev > 0; --w_rev) {
2834 const size_t window_bits_w = sched.window_bits_per_window[w_rev - 1];
2835 for (size_t d = 0; d < window_bits_w; ++d) {
2836 result.self_dbl();
2837 }
2838 result += window_sums[w_rev - 1];
2839 }
2840
2841 // GLV path leaves input_scalars untouched (it reads via from_montgomery_form_reduced into
2842 // a temporary). Non-GLV path mutated in place above and must restore.
2843 if (!use_glv) {
2844 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
2845 for (size_t i : chunk.range(n_input)) {
2846 input_scalars[i].self_to_montgomery_form();
2847 }
2848 });
2849 }
2850
2851 return result;
2852}
2853
// Forwarding wrapper: hands `scalars`, `points` and `dedup_hint` to
// pippenger_round_parallel<Curve> unchanged and returns its result.
// No edge-case handling is added at this level.
// NOTE(review): the declarator lines (return type, function name, leading parameters)
// are not visible in this listing — confirm the exact signature against the header.
2854template <typename Curve>
2857 bool dedup_hint) noexcept
2858{
2859 return pippenger_round_parallel<Curve>(scalars, points, dedup_hint);
2860}
2861
// Edge-case-aware MSM entry point (presumably `pippenger_fast`, judging by the call made
// from the affine-returning wrapper below — the declarator lines are not visible here;
// TODO confirm against the header).
//  - handle_edge_cases == false: forwards straight to pippenger_round_parallel<Curve>,
//    passing dedup_hint through.
//  - handle_edge_cases == true: clips the inputs to the points available from
//    `scalars.start_index` onwards and runs pippenger_round_parallel_jacobian_fast.
//    dedup_hint is not consulted on this path.
// Returns the point at infinity when there are no scalars or when start_index lies at or
// beyond the end of `points`.
// Side effect on the edge-case path: the caller's scalar storage is rewritten in place
// (out of Montgomery form, then back again) through a const_cast.
2862template <typename Curve>
2865 bool handle_edge_cases,
2866 bool dedup_hint) noexcept
2867{
2868 using Element = typename Curve::Element;
2869 using ScalarField = typename Curve::ScalarField;
2870 if (!handle_edge_cases) {
2871 return pippenger_round_parallel<Curve>(scalars, points, dedup_hint);
2872 }
2873 // Edge-case-handling path: route through the Jacobian fast-path. It uses
2874 // Jacobian additions throughout, so point-at-infinity and equal-x bucket
2875 // collisions don't trigger the affine-add edge-case bug. We need to convert
2876 // PolynomialSpan to a plain ScalarField span: the jacobian fast-path takes
2877 // a contiguous std::span and ignores `start_index`.
2878 const size_t n = scalars.span.size();
2879 if (n == 0) {
2880 return Curve::Group::point_at_infinity;
2881 }
2882 // Trivially small N: skip Pippenger / Jacobian-fast-path scaffolding entirely.
2883 // Affine operator* + Jacobian sum already handles all edge cases.
// NOTE(review): this check runs before the start_index bounds check below, so
// trivial_msm receives the unclipped `scalars` / `points` — confirm it applies
// start_index and the points.size() limit itself.
2884 if (n < 4) {
2885 return trivial_msm<Curve>(scalars, points);
2886 }
2887 const auto& start = scalars.start_index;
2888 if (start >= points.size()) {
2889 return Curve::Group::point_at_infinity;
2890 }
// Use only as many scalars as there are points between start_index and the end of
// `points`; both slices below share that length.
2891 const size_t n_used = std::min<size_t>(n, points.size() - start);
2892 std::span<const typename Curve::AffineElement> point_slice(points.data() + start, n_used);
2893 std::span<const ScalarField> scalar_slice(scalars.span.data(), n_used);
2894 // Convert scalars to non-Montgomery form for the jacobian path's bit-extraction loop,
2895 // then restore. Mirrors the round-parallel fast-path's scalar lifecycle.
2896 // Use the `_reduced` variant: the bit-extraction loop reads only bits 0..253
2897 // (NUM_BITS = 254). Plain `self_from_montgomery_form` leaves the value in [0, 2p),
2898 // so values in [2^254, 2p) would have bit 254 set and silently drop the contribution
2899 // of that bit. `_reduced` brings the value into [0, p) ⊂ [0, 2^254).
// NOTE(review): the pointer below is obtained from a span of const elements and is then
// written through. This relies on the underlying storage being writable and on nobody
// else reading these scalars until the restore loop has finished — confirm with callers.
2900 auto* mutable_scalars =
2901 const_cast<ScalarField*>(scalar_slice.data()); // NOLINT(cppcoreguidelines-pro-type-const-cast)
2902 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
2903 for (size_t i : chunk.range(n_used)) {
2904 mutable_scalars[i].self_from_montgomery_form_reduced();
2905 }
2906 });
// The third argument is the literal 0; the explicit instantiation at the bottom of this
// file names that parameter `min_pts_per_thread_override`.
2907 const Element result =
2908 round_parallel_detail::pippenger_round_parallel_jacobian_fast<Curve>(scalar_slice, point_slice, 0);
// Put the caller's scalars back into Montgomery form (only the first n_used were touched).
2909 bb::parallel_for(bb::get_num_cpus(), [&](const ThreadChunk& chunk) {
2910 for (size_t i : chunk.range(n_used)) {
2911 mutable_scalars[i].self_to_montgomery_form();
2912 }
2913 });
2914 return result;
2915}
2916
// Affine-returning convenience wrapper: runs pippenger_fast<Curve> with the given
// `scalars`, `points`, `handle_edge_cases` and `dedup_hint`, then converts the result to
// an AffineElement.
// NOTE(review): presumably this is MSM_fast<Curve>::msm — the declarator lines are not
// visible in this listing; confirm the name and parameter order against the header.
2917template <typename Curve>
2920 bool handle_edge_cases,
2921 bool dedup_hint) noexcept
2922{
2923 return AffineElement(pippenger_fast<Curve>(scalars, points, handle_edge_cases, dedup_hint));
2924}
2925
2926#include "./pippenger_batched.hpp"
2927
2928// Explicit instantiations.
2932 bool dedup_hint) noexcept;
2936 bool dedup_hint) noexcept;
2939 bool handle_edge_cases,
2940 bool dedup_hint) noexcept;
2944 bool handle_edge_cases,
2945 bool dedup_hint) noexcept;
// Instantiate the whole MSM_fast class template for both supported curves.
2946template class MSM_fast<curve::BN254>;
2947template class MSM_fast<curve::Grumpkin>;
2948
2952 bool dedup_hint,
2954 std::span<std::byte> external_arena) noexcept;
2955
2959 bool dedup_hint,
2961 std::span<std::byte> external_arena) noexcept;
2962
2966
2970
2974
2978
// Instantiate the Jacobian fast-path for both curves inside its own namespace; the
// comment near the top of this file says the header declares it via `extern template`.
2979namespace round_parallel_detail {
2980template curve::BN254::Element pippenger_round_parallel_jacobian_fast<curve::BN254>(
2983 size_t min_pts_per_thread_override) noexcept;
2984
2985template curve::Grumpkin::Element pippenger_round_parallel_jacobian_fast<curve::Grumpkin>(
2988 size_t min_pts_per_thread_override) noexcept;
2989} // namespace round_parallel_detail
2990
// NOTE(review): only the BN254 instantiation of compute_arena_bytes_for_msm is visible
// here; no Grumpkin counterpart appears in this listing — confirm that is intended.
2991template size_t compute_arena_bytes_for_msm<curve::BN254>(size_t, bool, bool) noexcept;
2992
2993} // namespace bb::scalar_multiplication
#define BB_ASSERT_GTE(left, right,...)
Definition assert.hpp:128
#define BB_ASSERT_GT(left, right,...)
Definition assert.hpp:113
#define BB_ASSERT_EQ(actual, expected,...)
Definition assert.hpp:83
#define BB_ASSERT_LTE(left, right,...)
Definition assert.hpp:158
#define BB_BENCH_NAME(name)
Definition bb_bench.hpp:264
typename Group::element Element
Definition bn254.hpp:21
typename Group::affine_element AffineElement
Definition bn254.hpp:22
typename Group::element Element
Definition grumpkin.hpp:63
typename Group::affine_element AffineElement
Definition grumpkin.hpp:64
static AffineElement msm(std::span< const AffineElement > points, PolynomialSpan< const ScalarField > scalars, bool handle_edge_cases=false, bool dedup_hint=false) noexcept
Single MSM_fast convenience wrapper — returns the result as an AffineElement.
#define info(...)
Definition log.hpp:93
FF a
FF b
uint32_t get_constantine_packed_digit(const uint64_t *scalar_data, uint32_t lo_limb, uint32_t hi_limb, uint32_t lo_off, uint32_t lo_bits, uint32_t lo_mask, uint32_t hi_mask, bool slice_localised_to_one_u64, size_t window_bits) noexcept
Read (window_bits+1) bits from scalar_data (uint64 limbs) using precomputed slice params and apply Co...
ConstantineSlicePath classify_slice_path_u32(const ConstantineSliceParamsU32 &sp) noexcept
size_t compute_global_max_overflow_per_window(size_t n, size_t num_threads, size_t subchunk_entries_cap) noexcept
size_t compute_phase_one_prologue_bytes(size_t n, bool use_glv, bool inline_glv_double, size_t profile_threads) noexcept
void store_constantine_packed_digits_x4_bottom(uint32_t *dst, const uint32_t *scalar_data_0, const uint32_t *scalar_data_1, const uint32_t *scalar_data_2, const uint32_t *scalar_data_3, uint32_t hi_limb, uint32_t lo_bits, SimdU32x4 hi_mask_v, SimdU32x4 one_v, SimdU32x4 val_mask, uint32_t window_bits) noexcept
size_t solve_wpb(size_t per_window_bytes, size_t available_budget, size_t W_R) noexcept
void store_constantine_packed_digits_x4_boundary(uint32_t *dst, const uint32_t *scalar_data_0, const uint32_t *scalar_data_1, const uint32_t *scalar_data_2, const uint32_t *scalar_data_3, uint32_t lo_limb, uint32_t hi_limb, uint32_t lo_off, uint32_t lo_bits, SimdU32x4 lo_mask_v, SimdU32x4 hi_mask_v, SimdU32x4 one_v, SimdU32x4 val_mask, uint32_t window_bits) noexcept
size_t compute_bucket_partials_max(size_t B_eff, size_t num_threads) noexcept
uint32_t __attribute__((vector_size(16))) SimdU32x4
PhaseACaps compute_phase_a_caps(size_t n, size_t num_threads) noexcept
ConstantineSliceParams compute_constantine_slice_params(size_t bit_offset, size_t window_bits, size_t num_uint64_limbs) noexcept
Curve::Element pippenger_round_parallel_jacobian_fast(std::span< const typename Curve::ScalarField > scalars, std::span< const typename Curve::AffineElement > points, size_t min_pts_per_thread_override) noexcept
Small-N fast-path: per-thread Jacobian Pippenger over a partition of the input.
void store_constantine_packed_digits_x4_localised(uint32_t *dst, const uint32_t *scalar_data_0, const uint32_t *scalar_data_1, const uint32_t *scalar_data_2, const uint32_t *scalar_data_3, uint32_t lo_limb, uint32_t lo_off, SimdU32x4 lo_mask_v, SimdU32x4 one_v, SimdU32x4 val_mask, uint32_t window_bits) noexcept
size_t compute_dense_stride(size_t B_eff, size_t num_threads) noexcept
uint32_t choose_window_bits(size_t num_points, size_t num_bits, size_t n_input, size_t num_logical_threads) noexcept
VariableWindowSchedule build_var_window_schedule(size_t num_bits, size_t window_bits) noexcept
ConstantineSliceParamsU32 compute_constantine_slice_params_u32(size_t bit_offset, size_t window_bits, size_t num_u32_limbs) noexcept
template curve::BN254::Element pippenger_fast< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars, std::span< const curve::BN254::AffineElement > points, bool handle_edge_cases, bool dedup_hint) noexcept
size_t compute_arena_bytes_for_msm(size_t n_input, bool external_glv_provided, bool dedup_active) noexcept
Round-parallel Pippenger MSM_fast. Windows process sequentially (high-to-low) but each window is full...
template size_t compute_arena_bytes_for_msm< curve::BN254 >(size_t, bool, bool) noexcept
Curve::Element pippenger_unsafe_fast(PolynomialSpan< const typename Curve::ScalarField > scalars, std::span< const typename Curve::AffineElement > points, bool dedup_hint) noexcept
template curve::BN254::Element pippenger_round_parallel< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars, std::span< const curve::BN254::AffineElement > points, bool dedup_hint, std::span< const curve::BN254::AffineElement > external_glv_doubled, std::span< std::byte > external_arena) noexcept
template curve::BN254::Element trivial_msm_threaded< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars_span, std::span< const curve::BN254::AffineElement > all_points) noexcept
template curve::Grumpkin::Element pippenger_round_parallel< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars, std::span< const curve::Grumpkin::AffineElement > points, bool dedup_hint, std::span< const curve::Grumpkin::AffineElement > external_glv_doubled, std::span< std::byte > external_arena) noexcept
template curve::Grumpkin::Element trivial_msm< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars_span, std::span< const curve::Grumpkin::AffineElement > all_points) noexcept
template curve::Grumpkin::Element pippenger_fast< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars, std::span< const curve::Grumpkin::AffineElement > points, bool handle_edge_cases, bool dedup_hint) noexcept
Curve::Element pippenger_fast(PolynomialSpan< const typename Curve::ScalarField > scalars, std::span< const typename Curve::AffineElement > points, bool handle_edge_cases, bool dedup_hint) noexcept
template curve::BN254::Element trivial_msm< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars_span, std::span< const curve::BN254::AffineElement > all_points) noexcept
template curve::Grumpkin::Element pippenger_unsafe_fast< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars, std::span< const curve::Grumpkin::AffineElement > points, bool dedup_hint) noexcept
size_t window_bits_tuning_oversub_factor(size_t n_input)
N-dependent oversubscription factor used ONLY for choose_window_bits' target_load formula (not for ac...
template curve::BN254::Element pippenger_unsafe_fast< curve::BN254 >(PolynomialSpan< const curve::BN254::ScalarField > scalars, std::span< const curve::BN254::AffineElement > points, bool dedup_hint) noexcept
Curve::Element pippenger_round_parallel(PolynomialSpan< const typename Curve::ScalarField > scalars_span, std::span< const typename Curve::AffineElement > all_points, bool dedup_hint, std::span< const typename Curve::AffineElement > external_glv_doubled, std::span< std::byte > external_arena) noexcept
State of the art pippenger_fast multiscalar multiplication algorithm.
template curve::Grumpkin::Element trivial_msm_threaded< curve::Grumpkin >(PolynomialSpan< const curve::Grumpkin::ScalarField > scalars_span, std::span< const curve::Grumpkin::AffineElement > all_points) noexcept
size_t get_num_cpus()
Definition thread.cpp:33
C slice(C const &container, size_t start)
Definition container.hpp:9
Inner sum(Cont< Inner, Args... > const &in)
Definition container.hpp:70
void parallel_for(size_t num_iterations, const std::function< void(size_t)> &func)
Definition thread.cpp:111
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
uint8_t len
std::span< uint32_t > affine_bucket_indices
uintptr_t base_addr
std::byte * data
std::span< BaseField > affine_bucket_inversion_scratch
std::span< AffineElement > points_to_add
std::span< uint32_t > pair_dest
std::span< uint8_t > is_present
std::span< AffineElement > overflow_pts
std::unique_ptr< std::byte[]> local_owner
std::span< std::pair< uint32_t, uint32_t > > affine_bucket_pairs
std::span< uint32_t > overflow_slots
size_t affine_bucket_stride
std::span< AffineElement > curr_pts
std::span< uint32_t > curr_buckets
std::span< BaseField > inversion_scratch
std::span< AffineElement > dense_buckets
std::span< AffineBucketChunkInfo > chunk_infos
Curve::Element Element
size_t thread_index
Definition thread.hpp:150
auto range(size_t size, size_t offset=0) const
Definition thread.hpp:152
Per-window precomputed slice parameters for the carry-less signed-Booth window recoding....
std::span< typename Curve::AffineElement > extra_points