
feat(avm): dynamic work distribution in AVM sumcheck prover round#22643

Merged
jeanmon merged 2 commits into merge-train/avm from jean/avm-sumcheck-dynamic-thread-chunks
Apr 21, 2026

Conversation

@jeanmon
Contributor

@jeanmon jeanmon commented Apr 17, 2026

Switch compute_univariate_avm from static per-thread assignment to dynamic chunk distribution via the atomic thread pool. Each work item processes all relations for rows_per_chunk=16 consecutive rows, and threads atomically pick up the next chunk as they finish.

Here are benchmark results for 32, 15, and 7 CPU cores over 20 runs.

Best-of-5 mean (sumcheck ms) across thread counts:

                                                    
| Threads | Baseline | Dynamic | Dynamic vs Baseline |
|---------|----------|---------|---------------------|
| 32      | 2440.4   | 2009.6  | −17.6%              |
| 15      | 3432.2   | 2993.6  | −12.8%              |
| 7       | 6448.4   | 5548.8  | −14.0%              |

Median:


| Threads | Baseline med | Dynamic med | Δ (median) |
|---------|--------------|-------------|------------|
| 32      | 2558.0       | 2209.0      | −13.6%     |
| 15      | 3733.5       | 3211.0      | −14.0%     |
| 7       | 6887.5       | 5995.0      | −13.0%     |

The value rows_per_chunk = 16 was determined experimentally; values of 8 and 32 led to slightly worse performance.

@jeanmon jeanmon force-pushed the jean/avm-sumcheck-dynamic-thread-chunks branch from 49f5125 to 211e033 Compare April 17, 2026 16:59
@jeanmon jeanmon marked this pull request as ready for review April 17, 2026 17:01
@jeanmon jeanmon changed the title perf: dynamic work distribution in AVM sumcheck prover round feat(avm): dynamic work distribution in AVM sumcheck prover round Apr 17, 2026
Contributor

@fcarreiro fcarreiro left a comment

Looks good! I'd suggest you run and benchmark with other CPU counts as well, e.g. smaller, and maybe a non-power-of-2 count.


// Accumulate the contribution from each sub-relation across each edge of the hyper-cube
- parallel_for(num_threads, [&](size_t thread_idx) {
+ parallel_for(total_chunks, [&](size_t chunk_id) {
Contributor

I always forget that parallel_for will always have a pool of get_num_cpus() and not of the number of iterations passed.

Contributor Author

Yes, this is something I realized while working on this! The first argument of parallel_for() is the number of tasks distributed to the thread pool. The variable name "num_threads" was a bit misleading.

// Accumulate the contribution from each sub-relation across each edge of the hyper-cube
- parallel_for(num_threads, [&](size_t thread_idx) {
+ parallel_for(total_chunks, [&](size_t chunk_id) {
+     thread_local size_t thread_idx = thread_counter.fetch_add(1);
Contributor

This seems ok. I wonder if it could be a problem if multiple sumcheck rounds are run. Maybe ask @ludamad. Worst case, parallel_for could be modified to provide the thread_idx on top of the chunk_id, if the lambda accepts it. It would be provided here: https://github.com/AztecProtocol/aztec-packages/blob/next/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp#L87 and needs to be plumbed from here https://github.com/AztecProtocol/aztec-packages/blob/next/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp#L134 via the unused size_t.

Contributor Author

In https://github.com/AztecProtocol/aztec-packages/blob/next/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp#L157
there is a safeguard that prevents a parallel_for from being nested.
@fcarreiro I guess this should be fine.

Contributor

It's not about nesting. It's about running one sumcheck and then another: the static thread_local will keep the same index a thread used in the previous sumcheck round. I think it will be OK, but in general it's conceptually "dirty". This class has state, and that state should in principle only live on the same class instance; however, the thread index will carry across class instances.

Contributor

For example, if the number of CPUs changed in between (which it shouldn't), things would break.

Contributor Author

@fcarreiro ah I got it now. Thanks for the clarifications. It is a very good catch.

Contributor Author

I think the static keyword for the atomic thread_counter should not be required. The main thread needs to wait until each thread is done with its task in parallel_for(), otherwise we have another, bigger problem. In other words, thread_counter as a plain local variable will be accessible to every thread of the current run before the main thread continues to the next computation (for (auto& accumulators : thread_univariate_accumulators) {).
WDYT? @fcarreiro

Contributor Author

Sorry, it seems there is an issue if we simply remove the static keyword. Namely, if a thread in the second call runs for the first time, it will be allocated thread_idx = 0, which would collide with a thread from the previous run, and we get a race condition.

Contributor

That's right. In some sense the "static std::atomic<size_t> thread_counter{ 0 };" could probably be moved to the line just before the "thread_local size_t thread_idx = thread_counter.fetch_add(1);"; in any case it will be static.

That's the downside of this solution: those two statics (thread_local implies static storage duration) are difficult to reason about. I think the cleanest solution would be to be able to get the thread index as I mentioned before.

Contributor Author

@fcarreiro @ledwards2225 I have a change in progress which would solve this without touching the thread.hpp interfaces.
We allocate a number of tasks equal to the number of threads, and each thread uses an atomic counter to fetch the next available chunk. According to Claude it is even faster than the mutex in the pool.

Contributor Author

@fcarreiro @ledwards2225 I pushed a second commit implementing the new variant I mentioned above. Benchmarks showed no regression (even a 2% speedup on average, but that is too small to say for sure it is faster).

jeanmon added 2 commits April 20, 2026 17:27
Switch compute_univariate_avm from static per-thread assignment to
dynamic chunk distribution via the atomic thread pool. Each work item
processes all relations for rows_per_chunk=16 consecutive rows, and
threads atomically pick up the next chunk as they finish.

On a large AVM proof (173706 edges, 32 threads), this improves the
best-of-5 sumcheck time from 2440ms to 2010ms (-17.6%).
Replace the static atomic counter + thread_local slot assignment with a
work-stealing loop: parallel_for now dispatches get_num_cpus() outer
tasks, each owning a fixed accumulator slot and pulling chunks from a
local atomic counter. This removes the program-scoped state that
carried across SumcheckProverRound instances.

No functional change. Small median improvement (~1-2%) observed on
/tmp/bb-RxZhBt (10 runs, 32 threads).
@jeanmon jeanmon force-pushed the jean/avm-sumcheck-dynamic-thread-chunks branch from 211e033 to 28b569a Compare April 20, 2026 17:31
@ledwards2225
Contributor

Nice!

Question: Would bumping chunk size from 2 to 16 with the old static approach have given about the same improvement? I was surprised to see that the previous (contiguous) chunk was only size 2, which seems rough for cache. I assume 2 was chosen to balance thread work across the inhomogeneous trace but I would have guessed a much larger chunk (e.g. 16-64) would still be sufficient for that purpose. Number of expected cache misses is a function of this contiguous chunk size and is independent of thread count - could that explain why the benefit is roughly flat across different thread counts?

In any case I like this change - simple and robust.

Contributor

@fcarreiro fcarreiro left a comment

Interesting that this is faster than the baseline approach since it's so similar :) notable that letting threads "pick up" chunks makes so much of a difference in the long term. cool!

PS: I hate while(true)s but that's life!

@fcarreiro
Contributor

> Nice!
>
> Question: Would bumping chunk size from 2 to 16 with the old static approach have given about the same improvement? I was surprised to see that the previous (contiguous) chunk was only size 2, which seems rough for cache. I assume 2 was chosen to balance thread work across the inhomogeneous trace but I would have guessed a much larger chunk (e.g. 16-64) would still be sufficient for that purpose. Number of expected cache misses is a function of this contiguous chunk size and is independent of thread count - could that explain why the benefit is roughly flat across different thread counts?
>
> In any case I like this change - simple and robust.

I would hope that that wouldn't be enough. Cache will not help here, I think, especially for the AVM. As Jean mentioned, cache goes column-wise and sumcheck reads row-wise; I would never expect that reading one row keeps something in cache for the next row.

I think the gain comes from not prescribing chunks for each thread and allocating them dynamically.

@ledwards2225
Copy link
Copy Markdown
Contributor

> I think the gain comes from not prescribing chunks for each thread and allocating them dynamically.

But that would suggest that a chunk size of 2 is not sufficient to basically achieve uniform distribution of work - how could that be? Wouldn't that suggest that there are at least two single rows that are ~30% more expensive to execute than any other two rows?

@jeanmon
Contributor Author

jeanmon commented Apr 21, 2026

> I think the gain comes from not prescribing chunks for each thread and allocating them dynamically.

> But that would suggest that a chunk size of 2 is not sufficient to basically achieve uniform distribution of work - how could that be? Wouldn't that suggest that there are at least two single rows that are ~30% more expensive to execute than any other two rows?

For the former approach, I remember that we benchmarked different chunk sizes and increasing it did not perform better. I can try a higher value.
For the current approach, the chunk size of 16 was experimentally the sweet spot, but that was before I introduced the atomic counter. I should probably measure again.

@ledwards2225 I cannot say definitively what the real reason for the improvement is, but a static work allocation to the threads is very likely never optimal for most traces, as it is very inflexible. The AVM trace is so inhomogeneous that I can imagine the difference between the lightest and heaviest task (across 32 cores) easily reaching 15%.

@jeanmon
Contributor Author

jeanmon commented Apr 21, 2026

@ledwards2225 You are right that the static version with 16 rows provides better performance. However, the dynamic variant (this PR) is the clear winner.

| Config     | N  | Min  | Median | Mean   | Max  |
|------------|----|------|--------|--------|------|
| base-rpt16 | 10 | 2206 | 2336.5 | 2379.0 | 2654 |
| base-rpt8  | 10 | 2105 | 2388.0 | 2402.8 | 2574 |
| base-rpt2  | 10 | 2000 | 2566.5 | 2504.4 | 2675 |
| dyn-rpc16  | 10 | 1985 | 2028.5 | 2040.6 | 2111 |
                                                                                                                                                                                            
Dynamic (rpc=16) vs static baseline:
- vs base-rpt16: −14.2% mean, −13.2% median
- vs base-rpt8: −15.1% mean, −15.1% median
- vs base-rpt2: −18.5% mean, −21.0% median

P.S. I also ran benchmarks with chunk sizes 32 and 64, and it gets worse; in particular, you sometimes get outliers (Max) of up to 5-6 seconds.

@jeanmon jeanmon merged commit a31bef5 into merge-train/avm Apr 21, 2026
14 checks passed
@jeanmon jeanmon deleted the jean/avm-sumcheck-dynamic-thread-chunks branch April 21, 2026 08:12
@ledwards2225
Contributor

Interesting! Thanks for running these

Yeah, TBC I agree dynamic is best regardless; I'm just always looking to better understand the tradeoffs in sumcheck because intuition has not always been a good guide. I set up the same approach to be used with Honk and found 64 to be the ideal chunk size; we see about a 15% improvement. I was hoping for even more, since we were previously doing something more naive than AVM (just splitting the trace into num_threads many contiguous chunks), but we also have substantially less regional variation than you guys.
