Skip to main content

ab_proof_of_space_gpu/
host.rs

1#[cfg(all(test, not(miri)))]
2mod tests;
3
4use crate::shader::constants::{
5    MAX_BUCKET_SIZE, MAX_TABLE_SIZE, NUM_BUCKETS, NUM_MATCH_BUCKETS, NUM_S_BUCKETS,
6    REDUCED_MATCHES_COUNT,
7};
8use crate::shader::find_matches_and_compute_f7::{NUM_ELEMENTS_PER_S_BUCKET, ProofTargets};
9use crate::shader::find_proofs::ProofsHost;
10use crate::shader::types::{Metadata, Position, PositionR};
11use crate::shader::{compute_f1, find_proofs, select_shader_features_limits};
12use ab_chacha8::{ChaCha8Block, ChaCha8State, block_to_bytes};
13use ab_core_primitives::pieces::{PieceOffset, Record};
14use ab_core_primitives::pos::PosSeed;
15use ab_core_primitives::sectors::SectorId;
16use ab_erasure_coding::ErasureCoding;
17use ab_farmer_components::plotting::RecordsEncoder;
18use ab_farmer_components::sector::SectorContentsMap;
19use async_lock::Mutex as AsyncMutex;
20use futures::stream::FuturesOrdered;
21use futures::{StreamExt, TryStreamExt};
22use parking_lot::Mutex;
23use rayon::prelude::*;
24use rayon::{ThreadPool, ThreadPoolBuilder};
25use rclite::Arc;
26use std::fmt;
27use std::num::NonZeroU8;
28use std::simd::Simd;
29use std::sync::Arc as StdArc;
30use std::sync::atomic::{AtomicBool, Ordering};
31use tracing::{debug, warn};
32use wgpu::{
33    AdapterInfo, Backend, BackendOptions, Backends, BindGroup, BindGroupDescriptor, BindGroupEntry,
34    BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, Buffer, BufferAddress,
35    BufferAsyncError, BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor,
36    ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor, DeviceDescriptor,
37    DeviceType, Instance, InstanceDescriptor, InstanceFlags, MapMode, MemoryBudgetThresholds,
38    PipelineCompilationOptions, PipelineLayoutDescriptor, PollError, PollType, Queue,
39    RequestDeviceError, ShaderModule, ShaderStages,
40};
41
/// Errors that can occur while encoding records (including GPU proof creation)
#[derive(Debug, thiserror::Error)]
enum RecordEncodingError {
    /// Too many records to fit into a sector contents map (length must fit in `u16`)
    #[error("Too many records: {0}")]
    TooManyRecords(usize),
    /// Proof creation failed previously and the device is now considered broken
    #[error("Proof creation failed previously and the device is now considered broken")]
    DeviceBroken,
    /// Failed to map buffer
    #[error("Failed to map buffer: {0}")]
    BufferMapping(#[from] BufferAsyncError),
    /// Poll error
    #[error("Poll error: {0}")]
    DevicePoll(#[from] PollError),
}
58
/// Borrowed view over GPU-computed proofs backed by a mapped host buffer.
///
/// Keeps the mapped `proofs_host` buffer alive for as long as `proofs` is
/// borrowed from it; the buffer is unmapped on drop (see the `Drop` impl).
struct ProofsHostWrapper<'a> {
    // Proofs read in place from the mapped range of `proofs_host`
    proofs: &'a ProofsHost,
    // Host-visible buffer that `proofs` points into
    proofs_host: &'a Buffer,
}
63
impl Drop for ProofsHostWrapper<'_> {
    fn drop(&mut self) {
        // `proofs` points into the mapped range, so unmapping is only safe
        // once this wrapper (and with it the borrow) is dropped
        self.proofs_host.unmap();
    }
}
69
/// Wrapper data structure encapsulating a single compatible device
#[derive(Clone)]
pub struct Device {
    // Sequential index assigned to the adapter during enumeration
    id: u32,
    // One logical device/queue/shader-module triple per requested queue;
    // multiple devices work around wgpu's lack of multi-queue support
    devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
    // Adapter details (name, driver, backend, etc.) exposed via accessors
    adapter_info: AdapterInfo,
}
77
78impl fmt::Debug for Device {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        f.debug_struct("Device")
81            .field("id", &self.id)
82            .field("name", &self.adapter_info.name)
83            .field("device_type", &self.adapter_info.device_type)
84            .field("driver", &self.adapter_info.driver)
85            .field("driver_info", &self.adapter_info.driver_info)
86            .field("backend", &self.adapter_info.backend)
87            .finish_non_exhaustive()
88    }
89}
90
impl Device {
    /// Returns [`Device`] for each available device
    ///
    /// `number_of_queues` maps an adapter's [`DeviceType`] to the number of
    /// logical devices (and therefore queues) to create for it. Adapters that
    /// fail shader/feature/limit selection or device creation are skipped.
    pub async fn enumerate<NOQ>(number_of_queues: NOQ) -> Vec<Self>
    where
        NOQ: Fn(DeviceType) -> NonZeroU8,
    {
        // `WGPU_BACKEND` env var can override the default Metal+Vulkan set
        let backends = Backends::from_env().unwrap_or(Backends::METAL | Backends::VULKAN);
        let instance = Instance::new(InstanceDescriptor {
            backends,
            // Extra validation/debugging in debug builds, env-driven otherwise
            flags: if cfg!(debug_assertions) {
                InstanceFlags::debugging().with_env()
            } else {
                InstanceFlags::from_env_or_default()
            },
            memory_budget_thresholds: MemoryBudgetThresholds::default(),
            backend_options: BackendOptions::from_env_or_default(),
            display: None,
        });

        let adapters = instance.enumerate_adapters(backends).await;
        // Borrow so the closure below can be `move` without consuming the Fn
        let number_of_queues = &number_of_queues;

        adapters
            .into_iter()
            // Assign sequential ids in enumeration order
            .zip(0..)
            .map(|(adapter, id)| async move {
                let adapter_info = adapter.get_info();

                let (shader, required_features, required_limits) =
                    match select_shader_features_limits(&adapter) {
                        Some((shader, required_features, required_limits)) => {
                            debug!(
                                %id,
                                adapter_info = ?adapter_info,
                                "Compatible adapter found"
                            );

                            (shader, required_features, required_limits)
                        }
                        None => {
                            debug!(
                                %id,
                                adapter_info = ?adapter_info,
                                "Incompatible adapter found"
                            );

                            // Skip incompatible adapters entirely
                            return None;
                        }
                    };

                // TODO: creation of multiple devices is a workaround for lack of support for
                //  multiple queues: https://github.com/gfx-rs/wgpu/issues/1066
                let devices = (0..number_of_queues(adapter_info.device_type).get())
                    .map(|_| async {
                        let (device, queue) = adapter
                            .request_device(&DeviceDescriptor {
                                label: None,
                                required_features,
                                required_limits: required_limits.clone(),
                                ..DeviceDescriptor::default()
                            })
                            .await
                            .inspect_err(|error| {
                                warn!(%id, ?adapter_info, %error, "Failed to request the device");
                            })?;
                        let module = device.create_shader_module(shader.clone());

                        Ok::<_, RequestDeviceError>((device, queue, module))
                    })
                    .collect::<FuturesOrdered<_>>()
                    .try_collect::<Vec<_>>()
                    .await
                    // If any device creation fails, the whole adapter is skipped
                    .ok()?;

                Some(Self {
                    id,
                    devices,
                    adapter_info,
                })
            })
            // Preserve id order while resolving the per-adapter futures
            .collect::<FuturesOrdered<_>>()
            .filter_map(|device| async move { device })
            .collect()
            .await
    }

    /// Gpu ID
    pub fn id(&self) -> u32 {
        self.id
    }

    /// Device name
    pub fn name(&self) -> &str {
        &self.adapter_info.name
    }

    /// Device type
    pub fn device_type(&self) -> DeviceType {
        self.adapter_info.device_type
    }

    /// Driver
    pub fn driver(&self) -> &str {
        &self.adapter_info.driver
    }

    /// Driver info
    pub fn driver_info(&self) -> &str {
        &self.adapter_info.driver_info
    }

    /// Backend
    pub fn backend(&self) -> Backend {
        self.adapter_info.backend
    }

    /// Create a [`GpuRecordsEncoder`] backed by this device's queues
    pub fn instantiate(
        &self,
        erasure_coding: ErasureCoding,
        global_mutex: StdArc<AsyncMutex<()>>,
    ) -> anyhow::Result<GpuRecordsEncoder> {
        GpuRecordsEncoder::new(
            self.id,
            self.devices.clone(),
            self.adapter_info.clone(),
            erasure_coding,
            global_mutex,
        )
    }
}
221
/// GPU-backed records encoder with one encoder instance per device/queue
pub struct GpuRecordsEncoder {
    // Gpu ID this encoder was created from
    id: u32,
    // One instance per device; indexed by rayon thread index in `encode_records`
    instances: Vec<Mutex<GpuRecordsEncoderInstance>>,
    // Thread pool sized to match `instances` (1:1 thread-to-device mapping)
    thread_pool: ThreadPool,
    // Adapter details used for Debug output
    adapter_info: AdapterInfo,
    // Used for erasure coding of record chunks on the CPU
    erasure_coding: ErasureCoding,
    // Taken briefly before each record encoding to allow external pausing
    global_mutex: StdArc<AsyncMutex<()>>,
}
230
231impl fmt::Debug for GpuRecordsEncoder {
232    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233        f.debug_struct("GpuRecordsEncoder")
234            .field("id", &self.id)
235            .field("name", &self.adapter_info.name)
236            .field("device_type", &self.adapter_info.device_type)
237            .field("driver", &self.adapter_info.driver)
238            .field("driver_info", &self.adapter_info.driver_info)
239            .field("backend", &self.adapter_info.backend)
240            .finish_non_exhaustive()
241    }
242}
243
impl RecordsEncoder for GpuRecordsEncoder {
    // TODO: Run more than one encoding per device concurrently
    /// Encodes `records` in place using GPU-computed proofs, returning the
    /// sector contents map recording which chunks were encoded.
    ///
    /// Records are processed in parallel, one rayon thread per GPU device
    /// instance. Returns an error if `records` does not fit in `u16`, if a
    /// device fails, or (implicitly, via short-circuit) stops early when
    /// `abort_early` is set.
    fn encode_records(
        &mut self,
        sector_id: &SectorId,
        records: &mut [Record],
        abort_early: &AtomicBool,
    ) -> anyhow::Result<SectorContentsMap> {
        let mut sector_contents_map = SectorContentsMap::new(
            u16::try_from(records.len())
                .map_err(|_| RecordEncodingError::TooManyRecords(records.len()))?,
        );

        // `find_map_any` returns the first error produced by any worker (or
        // `None` when all records encoded successfully / abort was requested)
        let maybe_error = self.thread_pool.install(|| {
            records
                .par_iter_mut()
                .zip(sector_contents_map.iter_record_chunks_used_mut())
                .enumerate()
                .find_map_any(|(piece_offset, (record, record_chunks_used))| {
                    // Take mutex briefly to make sure encoding is allowed right now
                    self.global_mutex.lock_blocking();

                    let mut parity_record_chunks = Record::new_boxed();

                    // TODO: Do erasure coding on the GPU
                    // Erasure code source record chunks
                    self.erasure_coding
                        .extend(record.iter(), parity_record_chunks.iter_mut())
                        .expect("Statically guaranteed valid inputs; qed");

                    if abort_early.load(Ordering::Relaxed) {
                        return None;
                    }
                    let seed =
                        sector_id.derive_evaluation_seed(PieceOffset::from(piece_offset as u16));
                    // 1:1 mapping between rayon threads and device instances,
                    // so `try_lock` can never contend
                    let thread_index = rayon::current_thread_index().unwrap_or_default();
                    let mut encoder_instance = self.instances[thread_index]
                        .try_lock()
                        .expect("1:1 mapping between threads and devices; qed");
                    let proofs = match encoder_instance.create_proofs(&seed) {
                        Ok(proofs) => proofs,
                        Err(error) => {
                            return Some(error);
                        }
                    };
                    // Borrow the `ProofsHost` out of the wrapper (buffer stays
                    // mapped until the wrapper is dropped)
                    let proofs = proofs.proofs;

                    *record_chunks_used = proofs.found_proofs;

                    // TODO: Record encoding on the GPU
                    // XOR record chunks with proof hashes, compacting encoded
                    // chunks to the front of `record`
                    let mut num_found_proofs = 0_usize;
                    for (s_buckets, found_proofs) in (0..Record::NUM_S_BUCKETS)
                        .array_chunks::<{ u8::BITS as usize }>()
                        .zip(record_chunks_used)
                    {
                        // Each `found_proofs` byte is a bitmap for 8 s-buckets
                        for (proof_offset, s_bucket) in s_buckets.into_iter().enumerate() {
                            if num_found_proofs == Record::NUM_CHUNKS {
                                // Enough proofs collected, clear the rest of the bits
                                *found_proofs &=
                                    u8::MAX.unbounded_shr(u8::BITS - proof_offset as u32);
                                break;
                            }
                            if (*found_proofs & (1 << proof_offset)) != 0 {
                                // Source chunks come first, then parity chunks
                                let record_chunk = if s_bucket < Record::NUM_CHUNKS {
                                    record[s_bucket]
                                } else {
                                    parity_record_chunks[s_bucket - Record::NUM_CHUNKS]
                                };

                                record[num_found_proofs] = (Simd::from(record_chunk)
                                    ^ Simd::from(*proofs.proofs[s_bucket].hash()))
                                .to_array();
                                num_found_proofs += 1;
                            }
                        }
                    }

                    None
                })
        });

        if let Some(error) = maybe_error {
            return Err(error.into());
        }

        Ok(sector_contents_map)
    }
}
332
333impl GpuRecordsEncoder {
334    fn new(
335        id: u32,
336        devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
337        adapter_info: AdapterInfo,
338        erasure_coding: ErasureCoding,
339        global_mutex: StdArc<AsyncMutex<()>>,
340    ) -> anyhow::Result<Self> {
341        let thread_pool = ThreadPoolBuilder::new()
342            .thread_name(move |thread_index| format!("pos-gpu-{id}.{thread_index}"))
343            .num_threads(devices.len())
344            .build()?;
345
346        Ok(Self {
347            id,
348            instances: devices
349                .into_iter()
350                .map(|(device, queue, module)| {
351                    Mutex::new(GpuRecordsEncoderInstance::new(device, queue, module))
352                })
353                .collect(),
354            thread_pool,
355            adapter_info,
356            erasure_coding,
357            global_mutex,
358        })
359    }
360}
361
/// Per-device state: buffers, bind groups and pipelines for one proof run
struct GpuRecordsEncoderInstance {
    device: wgpu::Device,
    queue: Queue,
    // Latest buffer mapping error reported by a map-on-submit callback
    mapping_error: Arc<Mutex<Option<BufferAsyncError>>>,
    // Set at the start of `create_proofs` and cleared only on success, so a
    // failed run permanently marks the device as broken
    tainted: bool,
    // Host-visible staging buffer for the initial ChaCha8 state (MAP_WRITE)
    initial_state_host: Buffer,
    // GPU-side uniform copy of the initial state
    initial_state_gpu: Buffer,
    // Host-visible readback buffer for computed proofs (MAP_READ)
    proofs_host: Buffer,
    // GPU-side storage buffer the proofs are computed into
    proofs_gpu: Buffer,
    // Bind group + pipeline pairs for each stage of the proof computation;
    // tables ping-pong between the "a" and "b" bucket/metadata buffers
    bind_group_compute_f1: BindGroup,
    compute_pipeline_compute_f1: ComputePipeline,
    bind_group_sort_buckets_a: BindGroup,
    compute_pipeline_sort_buckets_a: ComputePipeline,
    bind_group_sort_buckets_b: BindGroup,
    compute_pipeline_sort_buckets_b: ComputePipeline,
    bind_group_find_matches_and_compute_f2: BindGroup,
    compute_pipeline_find_matches_and_compute_f2: ComputePipeline,
    bind_group_find_matches_and_compute_f3: BindGroup,
    compute_pipeline_find_matches_and_compute_f3: ComputePipeline,
    bind_group_find_matches_and_compute_f4: BindGroup,
    compute_pipeline_find_matches_and_compute_f4: ComputePipeline,
    bind_group_find_matches_and_compute_f5: BindGroup,
    compute_pipeline_find_matches_and_compute_f5: ComputePipeline,
    bind_group_find_matches_and_compute_f6: BindGroup,
    compute_pipeline_find_matches_and_compute_f6: ComputePipeline,
    bind_group_find_matches_and_compute_f7: BindGroup,
    compute_pipeline_find_matches_and_compute_f7: ComputePipeline,
    bind_group_find_proofs: BindGroup,
    compute_pipeline_find_proofs: ComputePipeline,
}
392
impl GpuRecordsEncoderInstance {
    /// Allocates all buffers and builds all bind groups/pipelines up front so
    /// that `create_proofs` only has to record and submit commands.
    ///
    /// Note: `initial_state_host` is created mapped, matching the
    /// "mapped initially, re-mapped at the end of the computation" protocol in
    /// `create_proofs`.
    fn new(device: wgpu::Device, queue: Queue, module: ShaderModule) -> Self {
        let initial_state_host = device.create_buffer(&BufferDescriptor {
            label: Some("initial_state_host"),
            size: size_of::<ChaCha8Block>() as BufferAddress,
            usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
            mapped_at_creation: true,
        });

        let initial_state_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("initial_state_gpu"),
            size: initial_state_host.size(),
            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let bucket_sizes_gpu_buffer_size = size_of::<[u32; NUM_BUCKETS]>() as BufferAddress;
        let table_6_proof_targets_sizes_gpu_buffer_size =
            size_of::<[u32; NUM_S_BUCKETS]>() as BufferAddress;
        // TODO: Sizes are excessive, for `bucket_sizes_gpu` are less than `u16` and could use SWAR
        //  approach for storing bucket sizes. Similarly, `table_6_proof_targets_sizes_gpu` sizes
        //  are less than `u8` and could use SWAR too with even higher compression ratio
        let bucket_sizes_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("bucket_sizes_gpu"),
            // Sized for the larger of its two uses (see aliasing note below)
            size: bucket_sizes_gpu_buffer_size.max(table_6_proof_targets_sizes_gpu_buffer_size),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });
        // Reuse the same buffer as `bucket_sizes_gpu`, they are not overlapping in use
        let table_6_proof_targets_sizes_gpu = bucket_sizes_gpu.clone();

        let buckets_a_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("buckets_a_gpu"),
            size: size_of::<[[PositionR; MAX_BUCKET_SIZE]; NUM_BUCKETS]>() as BufferAddress,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        // Second bucket buffer: tables ping-pong between "a" and "b"
        let buckets_b_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("buckets_b_gpu"),
            size: buckets_a_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        // Per-table match positions, retained for `find_proofs` backtracking
        let positions_f2_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f2_gpu"),
            size: size_of::<[[[Position; 2]; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>()
                as BufferAddress,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f3_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f3_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f4_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f4_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f5_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f5_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f6_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f6_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let metadatas_gpu_buffer_size =
            size_of::<[[Metadata; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>() as BufferAddress;
        let table_6_proof_targets_gpu_buffer_size = size_of::<
            [[ProofTargets; NUM_ELEMENTS_PER_S_BUCKET]; NUM_S_BUCKETS],
        >() as BufferAddress;
        let metadatas_a_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("metadatas_a_gpu"),
            size: metadatas_gpu_buffer_size.max(table_6_proof_targets_gpu_buffer_size),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });
        // Reuse the same buffer as `metadatas_a_gpu`, they are not overlapping in use
        let table_6_proof_targets_gpu = metadatas_a_gpu.clone();

        let metadatas_b_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("metadatas_b_gpu"),
            size: metadatas_gpu_buffer_size,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let proofs_host = device.create_buffer(&BufferDescriptor {
            label: Some("proofs_host"),
            size: size_of::<ProofsHost>() as BufferAddress,
            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let proofs_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("proofs_gpu"),
            size: proofs_host.size(),
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        // Table 1: f1 writes unsorted entries into the "a" buckets
        let (bind_group_compute_f1, compute_pipeline_compute_f1) =
            bind_group_and_pipeline_compute_f1(
                &device,
                &module,
                &initial_state_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
            );
        let (bind_group_sort_buckets_a, compute_pipeline_sort_buckets_a) =
            bind_group_and_pipeline_sort_buckets(
                &device,
                &module,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
            );

        let (bind_group_sort_buckets_b, compute_pipeline_sort_buckets_b) =
            bind_group_and_pipeline_sort_buckets(
                &device,
                &module,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
            );

        // Tables 2..=6 alternate reading from one bucket/metadata buffer pair
        // and writing the next table into the other ("a" <-> "b" ping-pong)
        let (bind_group_find_matches_and_compute_f2, compute_pipeline_find_matches_and_compute_f2) =
            bind_group_and_pipeline_find_matches_and_compute_f2(
                &device,
                &module,
                &buckets_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f2_gpu,
                &metadatas_b_gpu,
            );

        let (bind_group_find_matches_and_compute_f3, compute_pipeline_find_matches_and_compute_f3) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<3>(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
                &positions_f3_gpu,
                &metadatas_a_gpu,
            );

        let (bind_group_find_matches_and_compute_f4, compute_pipeline_find_matches_and_compute_f4) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<4>(
                &device,
                &module,
                &buckets_a_gpu,
                &metadatas_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f4_gpu,
                &metadatas_b_gpu,
            );

        let (bind_group_find_matches_and_compute_f5, compute_pipeline_find_matches_and_compute_f5) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<5>(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
                &positions_f5_gpu,
                &metadatas_a_gpu,
            );

        let (bind_group_find_matches_and_compute_f6, compute_pipeline_find_matches_and_compute_f6) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<6>(
                &device,
                &module,
                &buckets_a_gpu,
                &metadatas_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f6_gpu,
                &metadatas_b_gpu,
            );

        // f7 writes into the aliased `bucket_sizes_gpu`/`metadatas_a_gpu`
        // buffers (see reuse notes above) — safe because their original uses
        // are finished by this stage
        let (bind_group_find_matches_and_compute_f7, compute_pipeline_find_matches_and_compute_f7) =
            bind_group_and_pipeline_find_matches_and_compute_f7(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &table_6_proof_targets_sizes_gpu,
                &table_6_proof_targets_gpu,
            );

        let (bind_group_find_proofs, compute_pipeline_find_proofs) =
            bind_group_and_pipeline_find_proofs(
                &device,
                &module,
                &positions_f2_gpu,
                &positions_f3_gpu,
                &positions_f4_gpu,
                &positions_f5_gpu,
                &positions_f6_gpu,
                &table_6_proof_targets_sizes_gpu,
                &table_6_proof_targets_gpu,
                &proofs_gpu,
            );

        Self {
            device,
            queue,
            mapping_error: Arc::new(Mutex::new(None)),
            tainted: false,
            initial_state_host,
            initial_state_gpu,
            proofs_host,
            proofs_gpu,
            bind_group_compute_f1,
            compute_pipeline_compute_f1,
            bind_group_sort_buckets_a,
            compute_pipeline_sort_buckets_a,
            bind_group_sort_buckets_b,
            compute_pipeline_sort_buckets_b,
            bind_group_find_matches_and_compute_f2,
            compute_pipeline_find_matches_and_compute_f2,
            bind_group_find_matches_and_compute_f3,
            compute_pipeline_find_matches_and_compute_f3,
            bind_group_find_matches_and_compute_f4,
            compute_pipeline_find_matches_and_compute_f4,
            bind_group_find_matches_and_compute_f5,
            compute_pipeline_find_matches_and_compute_f5,
            bind_group_find_matches_and_compute_f6,
            compute_pipeline_find_matches_and_compute_f6,
            bind_group_find_matches_and_compute_f7,
            compute_pipeline_find_matches_and_compute_f7,
            bind_group_find_proofs,
            compute_pipeline_find_proofs,
        }
    }

    /// Runs the full proof computation for `seed` on the GPU and returns the
    /// proofs mapped into host memory.
    ///
    /// # Errors
    ///
    /// Returns [`RecordEncodingError::DeviceBroken`] if a previous call failed
    /// (any early error return leaves `tainted` set), plus buffer-mapping and
    /// device-poll errors.
    fn create_proofs(
        &mut self,
        seed: &PosSeed,
    ) -> Result<ProofsHostWrapper<'_>, RecordEncodingError> {
        if self.tainted {
            return Err(RecordEncodingError::DeviceBroken);
        }
        // Mark broken until the run completes; cleared just before returning
        // success so any `?` below leaves the device tainted
        self.tainted = true;

        let mut encoder = self
            .device
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("create_proofs"),
            });

        // Mapped initially and re-mapped at the end of the computation
        self.initial_state_host
            .get_mapped_range_mut(..)
            .copy_from_slice(&block_to_bytes(
                &ChaCha8State::init(seed, &[0; _]).to_repr(),
            ));
        self.initial_state_host.unmap();

        // Upload the initial ChaCha8 state to the GPU-side uniform buffer
        encoder.copy_buffer_to_buffer(
            &self.initial_state_host,
            0,
            &self.initial_state_gpu,
            0,
            self.initial_state_host.size(),
        );

        {
            let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("create_proofs"),
                timestamp_writes: None,
            });

            // Table 1: compute f1 into the "a" buckets
            cpass.set_bind_group(0, &self.bind_group_compute_f1, &[]);
            cpass.set_pipeline(&self.compute_pipeline_compute_f1);
            cpass.dispatch_workgroups(
                MAX_TABLE_SIZE
                    .div_ceil(compute_f1::WORKGROUP_SIZE * compute_f1::ELEMENTS_PER_INVOCATION),
                1,
                1,
            );

            // Each table is sorted before matching; the sort/match pairs below
            // alternate between the "a" and "b" buffers (ping-pong)
            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f2, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f2);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f3, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f3);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f4, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f4);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f5, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f5);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f6, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f6);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f7, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f7);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            // Backtrack through the position tables to assemble proofs
            cpass.set_bind_group(0, &self.bind_group_find_proofs, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_proofs);
            cpass.dispatch_workgroups(NUM_S_BUCKETS as u32 / find_proofs::WORKGROUP_SIZE, 1, 1);
        }

        // Read proofs back to the host-visible buffer
        encoder.copy_buffer_to_buffer(
            &self.proofs_gpu,
            0,
            &self.proofs_host,
            0,
            self.proofs_host.size(),
        );

        // Map initial state for writes for the next iteration
        encoder.map_buffer_on_submit(&self.initial_state_host, MapMode::Write, .., {
            let mapping_error = Arc::clone(&self.mapping_error);

            move |r| {
                if let Err(error) = r {
                    mapping_error.lock().replace(error);
                }
            }
        });
        encoder.map_buffer_on_submit(&self.proofs_host, MapMode::Read, .., {
            let mapping_error = Arc::clone(&self.mapping_error);

            move |r| {
                if let Err(error) = r {
                    mapping_error.lock().replace(error);
                }
            }
        });

        let submission_index = self.queue.submit([encoder.finish()]);

        // Block until this submission (and the mapping callbacks) complete
        self.device.poll(PollType::Wait {
            submission_index: Some(submission_index),
            timeout: None,
        })?;

        if let Some(error) = self.mapping_error.lock().take() {
            return Err(RecordEncodingError::BufferMapping(error));
        }

        let proofs = {
            let proofs_host_ptr = self
                .proofs_host
                .get_mapped_range(..)
                .as_ptr()
                .cast::<ProofsHost>();
            // SAFETY: Initialized on the GPU
            unsafe { &*proofs_host_ptr }
        };

        // Run completed successfully, device is healthy again
        self.tainted = false;

        Ok(ProofsHostWrapper {
            proofs,
            proofs_host: &self.proofs_host,
        })
    }
}
804
805fn bind_group_and_pipeline_compute_f1(
806    device: &wgpu::Device,
807    module: &ShaderModule,
808    initial_state_gpu: &Buffer,
809    bucket_sizes_gpu: &Buffer,
810    buckets_gpu: &Buffer,
811) -> (BindGroup, ComputePipeline) {
812    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
813        label: Some("compute_f1"),
814        entries: &[
815            BindGroupLayoutEntry {
816                binding: 0,
817                count: None,
818                visibility: ShaderStages::COMPUTE,
819                ty: BindingType::Buffer {
820                    has_dynamic_offset: false,
821                    min_binding_size: None,
822                    ty: BufferBindingType::Uniform,
823                },
824            },
825            BindGroupLayoutEntry {
826                binding: 1,
827                count: None,
828                visibility: ShaderStages::COMPUTE,
829                ty: BindingType::Buffer {
830                    has_dynamic_offset: false,
831                    min_binding_size: None,
832                    ty: BufferBindingType::Storage { read_only: false },
833                },
834            },
835            BindGroupLayoutEntry {
836                binding: 2,
837                count: None,
838                visibility: ShaderStages::COMPUTE,
839                ty: BindingType::Buffer {
840                    has_dynamic_offset: false,
841                    min_binding_size: None,
842                    ty: BufferBindingType::Storage { read_only: false },
843                },
844            },
845        ],
846    });
847
848    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
849        label: Some("compute_f1"),
850        bind_group_layouts: &[Some(&bind_group_layout)],
851        immediate_size: 0,
852    });
853
854    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
855        compilation_options: PipelineCompilationOptions {
856            constants: &[],
857            zero_initialize_workgroup_memory: false,
858        },
859        cache: None,
860        label: Some("compute_f1"),
861        layout: Some(&pipeline_layout),
862        module,
863        entry_point: Some("compute_f1"),
864    });
865
866    let bind_group = device.create_bind_group(&BindGroupDescriptor {
867        label: Some("compute_f1"),
868        layout: &bind_group_layout,
869        entries: &[
870            BindGroupEntry {
871                binding: 0,
872                resource: initial_state_gpu.as_entire_binding(),
873            },
874            BindGroupEntry {
875                binding: 1,
876                resource: bucket_sizes_gpu.as_entire_binding(),
877            },
878            BindGroupEntry {
879                binding: 2,
880                resource: buckets_gpu.as_entire_binding(),
881            },
882        ],
883    });
884
885    (bind_group, compute_pipeline)
886}
887
888fn bind_group_and_pipeline_sort_buckets(
889    device: &wgpu::Device,
890    module: &ShaderModule,
891    bucket_sizes_gpu: &Buffer,
892    buckets_gpu: &Buffer,
893) -> (BindGroup, ComputePipeline) {
894    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
895        label: Some("sort_buckets"),
896        entries: &[
897            BindGroupLayoutEntry {
898                binding: 0,
899                count: None,
900                visibility: ShaderStages::COMPUTE,
901                ty: BindingType::Buffer {
902                    has_dynamic_offset: false,
903                    min_binding_size: None,
904                    ty: BufferBindingType::Storage { read_only: false },
905                },
906            },
907            BindGroupLayoutEntry {
908                binding: 1,
909                count: None,
910                visibility: ShaderStages::COMPUTE,
911                ty: BindingType::Buffer {
912                    has_dynamic_offset: false,
913                    min_binding_size: None,
914                    ty: BufferBindingType::Storage { read_only: false },
915                },
916            },
917        ],
918    });
919
920    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
921        label: Some("sort_buckets"),
922        bind_group_layouts: &[Some(&bind_group_layout)],
923        immediate_size: 0,
924    });
925
926    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
927        compilation_options: PipelineCompilationOptions {
928            constants: &[],
929            zero_initialize_workgroup_memory: false,
930        },
931        cache: None,
932        label: Some("sort_buckets"),
933        layout: Some(&pipeline_layout),
934        module,
935        entry_point: Some("sort_buckets"),
936    });
937
938    let bind_group = device.create_bind_group(&BindGroupDescriptor {
939        label: Some("sort_buckets"),
940        layout: &bind_group_layout,
941        entries: &[
942            BindGroupEntry {
943                binding: 0,
944                resource: bucket_sizes_gpu.as_entire_binding(),
945            },
946            BindGroupEntry {
947                binding: 1,
948                resource: buckets_gpu.as_entire_binding(),
949            },
950        ],
951    });
952
953    (bind_group, compute_pipeline)
954}
955
956fn bind_group_and_pipeline_find_matches_and_compute_f2(
957    device: &wgpu::Device,
958    module: &ShaderModule,
959    parent_buckets_gpu: &Buffer,
960    bucket_sizes_gpu: &Buffer,
961    buckets_gpu: &Buffer,
962    positions_gpu: &Buffer,
963    metadatas_gpu: &Buffer,
964) -> (BindGroup, ComputePipeline) {
965    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
966        label: Some("find_matches_and_compute_f2"),
967        entries: &[
968            BindGroupLayoutEntry {
969                binding: 0,
970                count: None,
971                visibility: ShaderStages::COMPUTE,
972                ty: BindingType::Buffer {
973                    has_dynamic_offset: false,
974                    min_binding_size: None,
975                    ty: BufferBindingType::Storage { read_only: true },
976                },
977            },
978            BindGroupLayoutEntry {
979                binding: 1,
980                count: None,
981                visibility: ShaderStages::COMPUTE,
982                ty: BindingType::Buffer {
983                    has_dynamic_offset: false,
984                    min_binding_size: None,
985                    ty: BufferBindingType::Storage { read_only: false },
986                },
987            },
988            BindGroupLayoutEntry {
989                binding: 2,
990                count: None,
991                visibility: ShaderStages::COMPUTE,
992                ty: BindingType::Buffer {
993                    has_dynamic_offset: false,
994                    min_binding_size: None,
995                    ty: BufferBindingType::Storage { read_only: false },
996                },
997            },
998            BindGroupLayoutEntry {
999                binding: 3,
1000                count: None,
1001                visibility: ShaderStages::COMPUTE,
1002                ty: BindingType::Buffer {
1003                    has_dynamic_offset: false,
1004                    min_binding_size: None,
1005                    ty: BufferBindingType::Storage { read_only: false },
1006                },
1007            },
1008            BindGroupLayoutEntry {
1009                binding: 4,
1010                count: None,
1011                visibility: ShaderStages::COMPUTE,
1012                ty: BindingType::Buffer {
1013                    has_dynamic_offset: false,
1014                    min_binding_size: None,
1015                    ty: BufferBindingType::Storage { read_only: false },
1016                },
1017            },
1018        ],
1019    });
1020
1021    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1022        label: Some("find_matches_and_compute_f2"),
1023        bind_group_layouts: &[Some(&bind_group_layout)],
1024        immediate_size: 0,
1025    });
1026
1027    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1028        compilation_options: PipelineCompilationOptions {
1029            constants: &[],
1030            zero_initialize_workgroup_memory: true,
1031        },
1032        cache: None,
1033        label: Some("find_matches_and_compute_f2"),
1034        layout: Some(&pipeline_layout),
1035        module,
1036        entry_point: Some("find_matches_and_compute_f2"),
1037    });
1038
1039    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1040        label: Some("find_matches_and_compute_f2"),
1041        layout: &bind_group_layout,
1042        entries: &[
1043            BindGroupEntry {
1044                binding: 0,
1045                resource: parent_buckets_gpu.as_entire_binding(),
1046            },
1047            BindGroupEntry {
1048                binding: 1,
1049                resource: bucket_sizes_gpu.as_entire_binding(),
1050            },
1051            BindGroupEntry {
1052                binding: 2,
1053                resource: buckets_gpu.as_entire_binding(),
1054            },
1055            BindGroupEntry {
1056                binding: 3,
1057                resource: positions_gpu.as_entire_binding(),
1058            },
1059            BindGroupEntry {
1060                binding: 4,
1061                resource: metadatas_gpu.as_entire_binding(),
1062            },
1063        ],
1064    });
1065
1066    (bind_group, compute_pipeline)
1067}
1068
1069#[expect(
1070    clippy::too_many_arguments,
1071    reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1072)]
1073fn bind_group_and_pipeline_find_matches_and_compute_fn<const TABLE_NUMBER: u8>(
1074    device: &wgpu::Device,
1075    module: &ShaderModule,
1076    parent_buckets_gpu: &Buffer,
1077    parent_metadatas_gpu: &Buffer,
1078    bucket_sizes_gpu: &Buffer,
1079    buckets_gpu: &Buffer,
1080    positions_gpu: &Buffer,
1081    metadatas_gpu: &Buffer,
1082) -> (BindGroup, ComputePipeline) {
1083    let label = format!("find_matches_and_compute_f{TABLE_NUMBER}");
1084    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1085        label: Some(&label),
1086        entries: &[
1087            BindGroupLayoutEntry {
1088                binding: 0,
1089                count: None,
1090                visibility: ShaderStages::COMPUTE,
1091                ty: BindingType::Buffer {
1092                    has_dynamic_offset: false,
1093                    min_binding_size: None,
1094                    ty: BufferBindingType::Storage { read_only: true },
1095                },
1096            },
1097            BindGroupLayoutEntry {
1098                binding: 1,
1099                count: None,
1100                visibility: ShaderStages::COMPUTE,
1101                ty: BindingType::Buffer {
1102                    has_dynamic_offset: false,
1103                    min_binding_size: None,
1104                    ty: BufferBindingType::Storage { read_only: true },
1105                },
1106            },
1107            BindGroupLayoutEntry {
1108                binding: 2,
1109                count: None,
1110                visibility: ShaderStages::COMPUTE,
1111                ty: BindingType::Buffer {
1112                    has_dynamic_offset: false,
1113                    min_binding_size: None,
1114                    ty: BufferBindingType::Storage { read_only: false },
1115                },
1116            },
1117            BindGroupLayoutEntry {
1118                binding: 3,
1119                count: None,
1120                visibility: ShaderStages::COMPUTE,
1121                ty: BindingType::Buffer {
1122                    has_dynamic_offset: false,
1123                    min_binding_size: None,
1124                    ty: BufferBindingType::Storage { read_only: false },
1125                },
1126            },
1127            BindGroupLayoutEntry {
1128                binding: 4,
1129                count: None,
1130                visibility: ShaderStages::COMPUTE,
1131                ty: BindingType::Buffer {
1132                    has_dynamic_offset: false,
1133                    min_binding_size: None,
1134                    ty: BufferBindingType::Storage { read_only: false },
1135                },
1136            },
1137            BindGroupLayoutEntry {
1138                binding: 5,
1139                count: None,
1140                visibility: ShaderStages::COMPUTE,
1141                ty: BindingType::Buffer {
1142                    has_dynamic_offset: false,
1143                    min_binding_size: None,
1144                    ty: BufferBindingType::Storage { read_only: false },
1145                },
1146            },
1147        ],
1148    });
1149
1150    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1151        label: Some(&label),
1152        bind_group_layouts: &[Some(&bind_group_layout)],
1153        immediate_size: 0,
1154    });
1155
1156    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1157        compilation_options: PipelineCompilationOptions {
1158            constants: &[],
1159            zero_initialize_workgroup_memory: true,
1160        },
1161        cache: None,
1162        label: Some(&label),
1163        layout: Some(&pipeline_layout),
1164        module,
1165        entry_point: Some(&format!("find_matches_and_compute_f{TABLE_NUMBER}")),
1166    });
1167
1168    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1169        label: Some(&label),
1170        layout: &bind_group_layout,
1171        entries: &[
1172            BindGroupEntry {
1173                binding: 0,
1174                resource: parent_buckets_gpu.as_entire_binding(),
1175            },
1176            BindGroupEntry {
1177                binding: 1,
1178                resource: parent_metadatas_gpu.as_entire_binding(),
1179            },
1180            BindGroupEntry {
1181                binding: 2,
1182                resource: bucket_sizes_gpu.as_entire_binding(),
1183            },
1184            BindGroupEntry {
1185                binding: 3,
1186                resource: buckets_gpu.as_entire_binding(),
1187            },
1188            BindGroupEntry {
1189                binding: 4,
1190                resource: positions_gpu.as_entire_binding(),
1191            },
1192            BindGroupEntry {
1193                binding: 5,
1194                resource: metadatas_gpu.as_entire_binding(),
1195            },
1196        ],
1197    });
1198
1199    (bind_group, compute_pipeline)
1200}
1201
1202fn bind_group_and_pipeline_find_matches_and_compute_f7(
1203    device: &wgpu::Device,
1204    module: &ShaderModule,
1205    parent_buckets_gpu: &Buffer,
1206    parent_metadatas_gpu: &Buffer,
1207    table_6_proof_targets_sizes_gpu: &Buffer,
1208    table_6_proof_targets_gpu: &Buffer,
1209) -> (BindGroup, ComputePipeline) {
1210    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1211        label: Some("find_matches_and_compute_f7"),
1212        entries: &[
1213            BindGroupLayoutEntry {
1214                binding: 0,
1215                count: None,
1216                visibility: ShaderStages::COMPUTE,
1217                ty: BindingType::Buffer {
1218                    has_dynamic_offset: false,
1219                    min_binding_size: None,
1220                    ty: BufferBindingType::Storage { read_only: true },
1221                },
1222            },
1223            BindGroupLayoutEntry {
1224                binding: 1,
1225                count: None,
1226                visibility: ShaderStages::COMPUTE,
1227                ty: BindingType::Buffer {
1228                    has_dynamic_offset: false,
1229                    min_binding_size: None,
1230                    ty: BufferBindingType::Storage { read_only: true },
1231                },
1232            },
1233            BindGroupLayoutEntry {
1234                binding: 2,
1235                count: None,
1236                visibility: ShaderStages::COMPUTE,
1237                ty: BindingType::Buffer {
1238                    has_dynamic_offset: false,
1239                    min_binding_size: None,
1240                    ty: BufferBindingType::Storage { read_only: false },
1241                },
1242            },
1243            BindGroupLayoutEntry {
1244                binding: 3,
1245                count: None,
1246                visibility: ShaderStages::COMPUTE,
1247                ty: BindingType::Buffer {
1248                    has_dynamic_offset: false,
1249                    min_binding_size: None,
1250                    ty: BufferBindingType::Storage { read_only: false },
1251                },
1252            },
1253        ],
1254    });
1255
1256    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1257        label: Some("find_matches_and_compute_f7"),
1258        bind_group_layouts: &[Some(&bind_group_layout)],
1259        immediate_size: 0,
1260    });
1261
1262    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1263        compilation_options: PipelineCompilationOptions {
1264            constants: &[],
1265            zero_initialize_workgroup_memory: true,
1266        },
1267        cache: None,
1268        label: Some("find_matches_and_compute_f7"),
1269        layout: Some(&pipeline_layout),
1270        module,
1271        entry_point: Some("find_matches_and_compute_f7"),
1272    });
1273
1274    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1275        label: Some("find_matches_and_compute_f7"),
1276        layout: &bind_group_layout,
1277        entries: &[
1278            BindGroupEntry {
1279                binding: 0,
1280                resource: parent_buckets_gpu.as_entire_binding(),
1281            },
1282            BindGroupEntry {
1283                binding: 1,
1284                resource: parent_metadatas_gpu.as_entire_binding(),
1285            },
1286            BindGroupEntry {
1287                binding: 2,
1288                resource: table_6_proof_targets_sizes_gpu.as_entire_binding(),
1289            },
1290            BindGroupEntry {
1291                binding: 3,
1292                resource: table_6_proof_targets_gpu.as_entire_binding(),
1293            },
1294        ],
1295    });
1296
1297    (bind_group, compute_pipeline)
1298}
1299
1300#[expect(
1301    clippy::too_many_arguments,
1302    reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1303)]
1304fn bind_group_and_pipeline_find_proofs(
1305    device: &wgpu::Device,
1306    module: &ShaderModule,
1307    table_2_positions_gpu: &Buffer,
1308    table_3_positions_gpu: &Buffer,
1309    table_4_positions_gpu: &Buffer,
1310    table_5_positions_gpu: &Buffer,
1311    table_6_positions_gpu: &Buffer,
1312    bucket_sizes_gpu: &Buffer,
1313    buckets_gpu: &Buffer,
1314    proofs_gpu: &Buffer,
1315) -> (BindGroup, ComputePipeline) {
1316    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1317        label: Some("find_proofs"),
1318        entries: &[
1319            BindGroupLayoutEntry {
1320                binding: 0,
1321                count: None,
1322                visibility: ShaderStages::COMPUTE,
1323                ty: BindingType::Buffer {
1324                    has_dynamic_offset: false,
1325                    min_binding_size: None,
1326                    ty: BufferBindingType::Storage { read_only: true },
1327                },
1328            },
1329            BindGroupLayoutEntry {
1330                binding: 1,
1331                count: None,
1332                visibility: ShaderStages::COMPUTE,
1333                ty: BindingType::Buffer {
1334                    has_dynamic_offset: false,
1335                    min_binding_size: None,
1336                    ty: BufferBindingType::Storage { read_only: true },
1337                },
1338            },
1339            BindGroupLayoutEntry {
1340                binding: 2,
1341                count: None,
1342                visibility: ShaderStages::COMPUTE,
1343                ty: BindingType::Buffer {
1344                    has_dynamic_offset: false,
1345                    min_binding_size: None,
1346                    ty: BufferBindingType::Storage { read_only: true },
1347                },
1348            },
1349            BindGroupLayoutEntry {
1350                binding: 3,
1351                count: None,
1352                visibility: ShaderStages::COMPUTE,
1353                ty: BindingType::Buffer {
1354                    has_dynamic_offset: false,
1355                    min_binding_size: None,
1356                    ty: BufferBindingType::Storage { read_only: true },
1357                },
1358            },
1359            BindGroupLayoutEntry {
1360                binding: 4,
1361                count: None,
1362                visibility: ShaderStages::COMPUTE,
1363                ty: BindingType::Buffer {
1364                    has_dynamic_offset: false,
1365                    min_binding_size: None,
1366                    ty: BufferBindingType::Storage { read_only: true },
1367                },
1368            },
1369            BindGroupLayoutEntry {
1370                binding: 5,
1371                count: None,
1372                visibility: ShaderStages::COMPUTE,
1373                ty: BindingType::Buffer {
1374                    has_dynamic_offset: false,
1375                    min_binding_size: None,
1376                    ty: BufferBindingType::Storage { read_only: false },
1377                },
1378            },
1379            BindGroupLayoutEntry {
1380                binding: 6,
1381                count: None,
1382                visibility: ShaderStages::COMPUTE,
1383                ty: BindingType::Buffer {
1384                    has_dynamic_offset: false,
1385                    min_binding_size: None,
1386                    ty: BufferBindingType::Storage { read_only: true },
1387                },
1388            },
1389            BindGroupLayoutEntry {
1390                binding: 7,
1391                count: None,
1392                visibility: ShaderStages::COMPUTE,
1393                ty: BindingType::Buffer {
1394                    has_dynamic_offset: false,
1395                    min_binding_size: None,
1396                    ty: BufferBindingType::Storage { read_only: false },
1397                },
1398            },
1399        ],
1400    });
1401
1402    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1403        label: Some("find_proofs"),
1404        bind_group_layouts: &[Some(&bind_group_layout)],
1405        immediate_size: 0,
1406    });
1407
1408    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1409        compilation_options: PipelineCompilationOptions {
1410            constants: &[],
1411            zero_initialize_workgroup_memory: false,
1412        },
1413        cache: None,
1414        label: Some("find_proofs"),
1415        layout: Some(&pipeline_layout),
1416        module,
1417        entry_point: Some("find_proofs"),
1418    });
1419
1420    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1421        label: Some("find_proofs"),
1422        layout: &bind_group_layout,
1423        entries: &[
1424            BindGroupEntry {
1425                binding: 0,
1426                resource: table_2_positions_gpu.as_entire_binding(),
1427            },
1428            BindGroupEntry {
1429                binding: 1,
1430                resource: table_3_positions_gpu.as_entire_binding(),
1431            },
1432            BindGroupEntry {
1433                binding: 2,
1434                resource: table_4_positions_gpu.as_entire_binding(),
1435            },
1436            BindGroupEntry {
1437                binding: 3,
1438                resource: table_5_positions_gpu.as_entire_binding(),
1439            },
1440            BindGroupEntry {
1441                binding: 4,
1442                resource: table_6_positions_gpu.as_entire_binding(),
1443            },
1444            BindGroupEntry {
1445                binding: 5,
1446                resource: bucket_sizes_gpu.as_entire_binding(),
1447            },
1448            BindGroupEntry {
1449                binding: 6,
1450                resource: buckets_gpu.as_entire_binding(),
1451            },
1452            BindGroupEntry {
1453                binding: 7,
1454                resource: proofs_gpu.as_entire_binding(),
1455            },
1456        ],
1457    });
1458
1459    (bind_group, compute_pipeline)
1460}