ab_proof_of_space_gpu/
host.rs

1#[cfg(all(test, not(miri)))]
2mod tests;
3
4use crate::shader::constants::{
5    MAX_BUCKET_SIZE, MAX_TABLE_SIZE, NUM_BUCKETS, NUM_MATCH_BUCKETS, NUM_S_BUCKETS,
6    REDUCED_MATCHES_COUNT,
7};
8use crate::shader::find_matches_and_compute_f7::{NUM_ELEMENTS_PER_S_BUCKET, ProofTargets};
9use crate::shader::find_proofs::ProofsHost;
10use crate::shader::types::{Metadata, Position, PositionR};
11use crate::shader::{compute_f1, find_proofs, select_shader_features_limits};
12use ab_chacha8::{ChaCha8Block, ChaCha8State, block_to_bytes};
13use ab_core_primitives::pieces::{PieceOffset, Record};
14use ab_core_primitives::pos::PosSeed;
15use ab_core_primitives::sectors::SectorId;
16use ab_erasure_coding::ErasureCoding;
17use ab_farmer_components::plotting::RecordsEncoder;
18use ab_farmer_components::sector::SectorContentsMap;
19use async_lock::Mutex as AsyncMutex;
20use futures::stream::FuturesOrdered;
21use futures::{StreamExt, TryStreamExt};
22use parking_lot::Mutex;
23use rayon::prelude::*;
24use rayon::{ThreadPool, ThreadPoolBuilder};
25use rclite::Arc;
26use std::fmt;
27use std::num::NonZeroU8;
28use std::simd::Simd;
29use std::sync::Arc as StdArc;
30use std::sync::atomic::{AtomicBool, Ordering};
31use tracing::{debug, warn};
32use wgpu::{
33    AdapterInfo, Backend, BackendOptions, Backends, BindGroup, BindGroupDescriptor, BindGroupEntry,
34    BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, Buffer, BufferAddress,
35    BufferAsyncError, BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor,
36    ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor, DeviceDescriptor,
37    DeviceType, Instance, InstanceDescriptor, InstanceFlags, MapMode, MemoryBudgetThresholds,
38    PipelineCompilationOptions, PipelineLayoutDescriptor, PollError, PollType, Queue,
39    RequestDeviceError, ShaderModule, ShaderStages,
40};
41
42/// Proof creation error
43#[derive(Debug, thiserror::Error)]
44pub enum RecordEncodingError {
45    /// Too many records
46    #[error("Too many records: {0}")]
47    TooManyRecords(usize),
48    /// Proof creation failed previously and the device is now considered broken
49    #[error("Proof creation failed previously and the device is now considered broken")]
50    DeviceBroken,
51    /// Failed to map buffer
52    #[error("Failed to map buffer: {0}")]
53    BufferMapping(#[from] BufferAsyncError),
54    /// Poll error
55    #[error("Poll error: {0}")]
56    DevicePoll(#[from] PollError),
57}
58
59struct ProofsHostWrapper<'a> {
60    proofs: &'a ProofsHost,
61    proofs_host: &'a Buffer,
62}
63
64impl Drop for ProofsHostWrapper<'_> {
65    fn drop(&mut self) {
66        self.proofs_host.unmap();
67    }
68}
69
70/// Wrapper data structure encapsulating a single compatible device
71#[derive(Clone)]
72pub struct Device {
73    id: u32,
74    devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
75    adapter_info: AdapterInfo,
76}
77
78impl fmt::Debug for Device {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        f.debug_struct("Device")
81            .field("id", &self.id)
82            .field("name", &self.adapter_info.name)
83            .field("device_type", &self.adapter_info.device_type)
84            .field("driver", &self.adapter_info.driver)
85            .field("driver_info", &self.adapter_info.driver_info)
86            .field("backend", &self.adapter_info.backend)
87            .finish_non_exhaustive()
88    }
89}
90
91impl Device {
92    /// Returns [`Device`] for each available device
93    pub async fn enumerate<NOQ>(number_of_queues: NOQ) -> Vec<Self>
94    where
95        NOQ: Fn(DeviceType) -> NonZeroU8,
96    {
97        let backends = Backends::from_env().unwrap_or(Backends::METAL | Backends::VULKAN);
98        let instance = Instance::new(&InstanceDescriptor {
99            backends,
100            flags: if cfg!(debug_assertions) {
101                InstanceFlags::debugging().with_env()
102            } else {
103                InstanceFlags::from_env_or_default()
104            },
105            memory_budget_thresholds: MemoryBudgetThresholds::default(),
106            backend_options: BackendOptions::from_env_or_default(),
107        });
108
109        let adapters = instance.enumerate_adapters(backends);
110        let number_of_queues = &number_of_queues;
111
112        adapters
113            .into_iter()
114            .zip(0..)
115            .map(|(adapter, id)| async move {
116                let adapter_info = adapter.get_info();
117
118                let (shader, required_features, required_limits) =
119                    match select_shader_features_limits(&adapter) {
120                        Some((shader, required_features, required_limits)) => {
121                            debug!(
122                                %id,
123                                adapter_info = ?adapter_info,
124                                "Compatible adapter found"
125                            );
126
127                            (shader, required_features, required_limits)
128                        }
129                        None => {
130                            debug!(
131                                %id,
132                                adapter_info = ?adapter_info,
133                                "Incompatible adapter found"
134                            );
135
136                            return None;
137                        }
138                    };
139
140                // TODO: creation of multiple devices is a workaround for lack of support for
141                //  multiple queues: https://github.com/gfx-rs/wgpu/issues/1066
142                let devices = (0..number_of_queues(adapter_info.device_type).get())
143                    .map(|_| async {
144                        let (device, queue) = adapter
145                            .request_device(&DeviceDescriptor {
146                                label: None,
147                                required_features,
148                                required_limits: required_limits.clone(),
149                                ..DeviceDescriptor::default()
150                            })
151                            .await
152                            .inspect_err(|error| {
153                                warn!(%id, ?adapter_info, %error, "Failed to request the device");
154                            })?;
155                        let module = device.create_shader_module(shader.clone());
156
157                        Ok::<_, RequestDeviceError>((device, queue, module))
158                    })
159                    .collect::<FuturesOrdered<_>>()
160                    .try_collect::<Vec<_>>()
161                    .await
162                    .ok()?;
163
164                Some(Self {
165                    id,
166                    devices,
167                    adapter_info,
168                })
169            })
170            .collect::<FuturesOrdered<_>>()
171            .filter_map(|device| async move { device })
172            .collect()
173            .await
174    }
175
176    /// Gpu ID
177    pub fn id(&self) -> u32 {
178        self.id
179    }
180
181    /// Device name
182    pub fn name(&self) -> &str {
183        &self.adapter_info.name
184    }
185
186    /// Device type
187    pub fn device_type(&self) -> DeviceType {
188        self.adapter_info.device_type
189    }
190
191    /// Driver
192    pub fn driver(&self) -> &str {
193        &self.adapter_info.driver
194    }
195
196    /// Driver info
197    pub fn driver_info(&self) -> &str {
198        &self.adapter_info.driver_info
199    }
200
201    /// Backend
202    pub fn backend(&self) -> Backend {
203        self.adapter_info.backend
204    }
205
206    pub fn instantiate(
207        &self,
208        erasure_coding: ErasureCoding,
209        global_mutex: StdArc<AsyncMutex<()>>,
210    ) -> anyhow::Result<GpuRecordsEncoder> {
211        GpuRecordsEncoder::new(
212            self.id,
213            self.devices.clone(),
214            self.adapter_info.clone(),
215            erasure_coding,
216            global_mutex,
217        )
218    }
219}
220
221pub struct GpuRecordsEncoder {
222    id: u32,
223    instances: Vec<Mutex<GpuRecordsEncoderInstance>>,
224    thread_pool: ThreadPool,
225    adapter_info: AdapterInfo,
226    erasure_coding: ErasureCoding,
227    global_mutex: StdArc<AsyncMutex<()>>,
228}
229
230impl fmt::Debug for GpuRecordsEncoder {
231    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232        f.debug_struct("GpuRecordsEncoder")
233            .field("id", &self.id)
234            .field("name", &self.adapter_info.name)
235            .field("device_type", &self.adapter_info.device_type)
236            .field("driver", &self.adapter_info.driver)
237            .field("driver_info", &self.adapter_info.driver_info)
238            .field("backend", &self.adapter_info.backend)
239            .finish_non_exhaustive()
240    }
241}
242
243impl RecordsEncoder for GpuRecordsEncoder {
244    // TODO: Run more than one encoding per device concurrently
245    fn encode_records(
246        &mut self,
247        sector_id: &SectorId,
248        records: &mut [Record],
249        abort_early: &AtomicBool,
250    ) -> anyhow::Result<SectorContentsMap> {
251        let mut sector_contents_map = SectorContentsMap::new(
252            u16::try_from(records.len())
253                .map_err(|_| RecordEncodingError::TooManyRecords(records.len()))?,
254        );
255
256        let maybe_error = self.thread_pool.install(|| {
257            records
258                .par_iter_mut()
259                .zip(sector_contents_map.iter_record_chunks_used_mut())
260                .enumerate()
261                .find_map_any(|(piece_offset, (record, record_chunks_used))| {
262                    // Take mutex briefly to make sure encoding is allowed right now
263                    self.global_mutex.lock_blocking();
264
265                    let mut parity_record_chunks = Record::new_boxed();
266
267                    // TODO: Do erasure coding on the GPU
268                    // Erasure code source record chunks
269                    self.erasure_coding
270                        .extend(record.iter(), parity_record_chunks.iter_mut())
271                        .expect("Statically guaranteed valid inputs; qed");
272
273                    if abort_early.load(Ordering::Relaxed) {
274                        return None;
275                    }
276                    let seed =
277                        sector_id.derive_evaluation_seed(PieceOffset::from(piece_offset as u16));
278                    let thread_index = rayon::current_thread_index().unwrap_or_default();
279                    let mut encoder_instance = self.instances[thread_index]
280                        .try_lock()
281                        .expect("1:1 mapping between threads and devices; qed");
282                    let proofs = match encoder_instance.create_proofs(&seed) {
283                        Ok(proofs) => proofs,
284                        Err(error) => {
285                            return Some(error);
286                        }
287                    };
288                    let proofs = proofs.proofs;
289
290                    *record_chunks_used = proofs.found_proofs;
291
292                    // TODO: Record encoding on the GPU
293                    let mut num_found_proofs = 0_usize;
294                    for (s_buckets, found_proofs) in (0..Record::NUM_S_BUCKETS)
295                        .array_chunks::<{ u8::BITS as usize }>()
296                        .zip(record_chunks_used)
297                    {
298                        for (proof_offset, s_bucket) in s_buckets.into_iter().enumerate() {
299                            if num_found_proofs == Record::NUM_CHUNKS {
300                                // Enough proofs collected, clear the rest of the bits
301                                *found_proofs &=
302                                    u8::MAX.unbounded_shr(u8::BITS - proof_offset as u32);
303                                break;
304                            }
305                            if (*found_proofs & (1 << proof_offset)) != 0 {
306                                let record_chunk = if s_bucket < Record::NUM_CHUNKS {
307                                    record[s_bucket]
308                                } else {
309                                    parity_record_chunks[s_bucket - Record::NUM_CHUNKS]
310                                };
311
312                                record[num_found_proofs] = (Simd::from(record_chunk)
313                                    ^ Simd::from(*proofs.proofs[s_bucket].hash()))
314                                .to_array();
315                                num_found_proofs += 1;
316                            }
317                        }
318                    }
319
320                    None
321                })
322        });
323
324        if let Some(error) = maybe_error {
325            return Err(error.into());
326        }
327
328        Ok(sector_contents_map)
329    }
330}
331
332impl GpuRecordsEncoder {
333    fn new(
334        id: u32,
335        devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
336        adapter_info: AdapterInfo,
337        erasure_coding: ErasureCoding,
338        global_mutex: StdArc<AsyncMutex<()>>,
339    ) -> anyhow::Result<Self> {
340        let thread_pool = ThreadPoolBuilder::new()
341            .thread_name(move |thread_index| format!("pos-gpu-{id}.{thread_index}"))
342            .num_threads(devices.len())
343            .build()?;
344
345        Ok(Self {
346            id,
347            instances: devices
348                .into_iter()
349                .map(|(device, queue, module)| {
350                    Mutex::new(GpuRecordsEncoderInstance::new(device, queue, module))
351                })
352                .collect(),
353            thread_pool,
354            adapter_info,
355            erasure_coding,
356            global_mutex,
357        })
358    }
359}
360
361struct GpuRecordsEncoderInstance {
362    device: wgpu::Device,
363    queue: Queue,
364    mapping_error: Arc<Mutex<Option<BufferAsyncError>>>,
365    tainted: bool,
366    initial_state_host: Buffer,
367    initial_state_gpu: Buffer,
368    proofs_host: Buffer,
369    proofs_gpu: Buffer,
370    bind_group_compute_f1: BindGroup,
371    compute_pipeline_compute_f1: ComputePipeline,
372    bind_group_sort_buckets_a: BindGroup,
373    compute_pipeline_sort_buckets_a: ComputePipeline,
374    bind_group_sort_buckets_b: BindGroup,
375    compute_pipeline_sort_buckets_b: ComputePipeline,
376    bind_group_find_matches_and_compute_f2: BindGroup,
377    compute_pipeline_find_matches_and_compute_f2: ComputePipeline,
378    bind_group_find_matches_and_compute_f3: BindGroup,
379    compute_pipeline_find_matches_and_compute_f3: ComputePipeline,
380    bind_group_find_matches_and_compute_f4: BindGroup,
381    compute_pipeline_find_matches_and_compute_f4: ComputePipeline,
382    bind_group_find_matches_and_compute_f5: BindGroup,
383    compute_pipeline_find_matches_and_compute_f5: ComputePipeline,
384    bind_group_find_matches_and_compute_f6: BindGroup,
385    compute_pipeline_find_matches_and_compute_f6: ComputePipeline,
386    bind_group_find_matches_and_compute_f7: BindGroup,
387    compute_pipeline_find_matches_and_compute_f7: ComputePipeline,
388    bind_group_find_proofs: BindGroup,
389    compute_pipeline_find_proofs: ComputePipeline,
390}
391
392impl GpuRecordsEncoderInstance {
393    fn new(device: wgpu::Device, queue: Queue, module: ShaderModule) -> Self {
394        let initial_state_host = device.create_buffer(&BufferDescriptor {
395            label: Some("initial_state_host"),
396            size: size_of::<ChaCha8Block>() as BufferAddress,
397            usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
398            mapped_at_creation: true,
399        });
400
401        let initial_state_gpu = device.create_buffer(&BufferDescriptor {
402            label: Some("initial_state_gpu"),
403            size: initial_state_host.size(),
404            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
405            mapped_at_creation: false,
406        });
407
408        let bucket_sizes_gpu_buffer_size = size_of::<[u32; NUM_BUCKETS]>() as BufferAddress;
409        let table_6_proof_targets_sizes_gpu_buffer_size =
410            size_of::<[u32; NUM_S_BUCKETS]>() as BufferAddress;
411        // TODO: Sizes are excessive, for `bucket_sizes_gpu` are less than `u16` and could use SWAR
412        //  approach for storing bucket sizes. Similarly, `table_6_proof_targets_sizes_gpu` sizes
413        //  are less than `u8` and could use SWAR too with even higher compression ratio
414        let bucket_sizes_gpu = device.create_buffer(&BufferDescriptor {
415            label: Some("bucket_sizes_gpu"),
416            size: bucket_sizes_gpu_buffer_size.max(table_6_proof_targets_sizes_gpu_buffer_size),
417            usage: BufferUsages::STORAGE,
418            mapped_at_creation: false,
419        });
420        // Reuse the same buffer as `bucket_sizes_gpu`, they are not overlapping in use
421        let table_6_proof_targets_sizes_gpu = bucket_sizes_gpu.clone();
422
423        let buckets_a_gpu = device.create_buffer(&BufferDescriptor {
424            label: Some("buckets_a_gpu"),
425            size: size_of::<[[PositionR; MAX_BUCKET_SIZE]; NUM_BUCKETS]>() as BufferAddress,
426            usage: BufferUsages::STORAGE,
427            mapped_at_creation: false,
428        });
429
430        let buckets_b_gpu = device.create_buffer(&BufferDescriptor {
431            label: Some("buckets_b_gpu"),
432            size: buckets_a_gpu.size(),
433            usage: BufferUsages::STORAGE,
434            mapped_at_creation: false,
435        });
436
437        let positions_f2_gpu = device.create_buffer(&BufferDescriptor {
438            label: Some("positions_f2_gpu"),
439            size: size_of::<[[[Position; 2]; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>()
440                as BufferAddress,
441            usage: BufferUsages::STORAGE,
442            mapped_at_creation: false,
443        });
444
445        let positions_f3_gpu = device.create_buffer(&BufferDescriptor {
446            label: Some("positions_f3_gpu"),
447            size: positions_f2_gpu.size(),
448            usage: BufferUsages::STORAGE,
449            mapped_at_creation: false,
450        });
451
452        let positions_f4_gpu = device.create_buffer(&BufferDescriptor {
453            label: Some("positions_f4_gpu"),
454            size: positions_f2_gpu.size(),
455            usage: BufferUsages::STORAGE,
456            mapped_at_creation: false,
457        });
458
459        let positions_f5_gpu = device.create_buffer(&BufferDescriptor {
460            label: Some("positions_f5_gpu"),
461            size: positions_f2_gpu.size(),
462            usage: BufferUsages::STORAGE,
463            mapped_at_creation: false,
464        });
465
466        let positions_f6_gpu = device.create_buffer(&BufferDescriptor {
467            label: Some("positions_f6_gpu"),
468            size: positions_f2_gpu.size(),
469            usage: BufferUsages::STORAGE,
470            mapped_at_creation: false,
471        });
472
473        let metadatas_gpu_buffer_size =
474            size_of::<[[Metadata; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>() as BufferAddress;
475        let table_6_proof_targets_gpu_buffer_size = size_of::<
476            [[ProofTargets; NUM_ELEMENTS_PER_S_BUCKET]; NUM_S_BUCKETS],
477        >() as BufferAddress;
478        let metadatas_a_gpu = device.create_buffer(&BufferDescriptor {
479            label: Some("metadatas_a_gpu"),
480            size: metadatas_gpu_buffer_size.max(table_6_proof_targets_gpu_buffer_size),
481            usage: BufferUsages::STORAGE,
482            mapped_at_creation: false,
483        });
484        // Reuse the same buffer as `metadatas_a_gpu`, they are not overlapping in use
485        let table_6_proof_targets_gpu = metadatas_a_gpu.clone();
486
487        let metadatas_b_gpu = device.create_buffer(&BufferDescriptor {
488            label: Some("metadatas_b_gpu"),
489            size: metadatas_gpu_buffer_size,
490            usage: BufferUsages::STORAGE,
491            mapped_at_creation: false,
492        });
493
494        let proofs_host = device.create_buffer(&BufferDescriptor {
495            label: Some("proofs_host"),
496            size: size_of::<ProofsHost>() as BufferAddress,
497            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
498            mapped_at_creation: false,
499        });
500
501        let proofs_gpu = device.create_buffer(&BufferDescriptor {
502            label: Some("proofs_gpu"),
503            size: proofs_host.size(),
504            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
505            mapped_at_creation: false,
506        });
507
508        let (bind_group_compute_f1, compute_pipeline_compute_f1) =
509            bind_group_and_pipeline_compute_f1(
510                &device,
511                &module,
512                &initial_state_gpu,
513                &bucket_sizes_gpu,
514                &buckets_a_gpu,
515            );
516        let (bind_group_sort_buckets_a, compute_pipeline_sort_buckets_a) =
517            bind_group_and_pipeline_sort_buckets(
518                &device,
519                &module,
520                &bucket_sizes_gpu,
521                &buckets_a_gpu,
522            );
523
524        let (bind_group_sort_buckets_b, compute_pipeline_sort_buckets_b) =
525            bind_group_and_pipeline_sort_buckets(
526                &device,
527                &module,
528                &bucket_sizes_gpu,
529                &buckets_b_gpu,
530            );
531
532        let (bind_group_find_matches_and_compute_f2, compute_pipeline_find_matches_and_compute_f2) =
533            bind_group_and_pipeline_find_matches_and_compute_f2(
534                &device,
535                &module,
536                &buckets_a_gpu,
537                &bucket_sizes_gpu,
538                &buckets_b_gpu,
539                &positions_f2_gpu,
540                &metadatas_b_gpu,
541            );
542
543        let (bind_group_find_matches_and_compute_f3, compute_pipeline_find_matches_and_compute_f3) =
544            bind_group_and_pipeline_find_matches_and_compute_fn::<3>(
545                &device,
546                &module,
547                &buckets_b_gpu,
548                &metadatas_b_gpu,
549                &bucket_sizes_gpu,
550                &buckets_a_gpu,
551                &positions_f3_gpu,
552                &metadatas_a_gpu,
553            );
554
555        let (bind_group_find_matches_and_compute_f4, compute_pipeline_find_matches_and_compute_f4) =
556            bind_group_and_pipeline_find_matches_and_compute_fn::<4>(
557                &device,
558                &module,
559                &buckets_a_gpu,
560                &metadatas_a_gpu,
561                &bucket_sizes_gpu,
562                &buckets_b_gpu,
563                &positions_f4_gpu,
564                &metadatas_b_gpu,
565            );
566
567        let (bind_group_find_matches_and_compute_f5, compute_pipeline_find_matches_and_compute_f5) =
568            bind_group_and_pipeline_find_matches_and_compute_fn::<5>(
569                &device,
570                &module,
571                &buckets_b_gpu,
572                &metadatas_b_gpu,
573                &bucket_sizes_gpu,
574                &buckets_a_gpu,
575                &positions_f5_gpu,
576                &metadatas_a_gpu,
577            );
578
579        let (bind_group_find_matches_and_compute_f6, compute_pipeline_find_matches_and_compute_f6) =
580            bind_group_and_pipeline_find_matches_and_compute_fn::<6>(
581                &device,
582                &module,
583                &buckets_a_gpu,
584                &metadatas_a_gpu,
585                &bucket_sizes_gpu,
586                &buckets_b_gpu,
587                &positions_f6_gpu,
588                &metadatas_b_gpu,
589            );
590
591        let (bind_group_find_matches_and_compute_f7, compute_pipeline_find_matches_and_compute_f7) =
592            bind_group_and_pipeline_find_matches_and_compute_f7(
593                &device,
594                &module,
595                &buckets_b_gpu,
596                &metadatas_b_gpu,
597                &table_6_proof_targets_sizes_gpu,
598                &table_6_proof_targets_gpu,
599            );
600
601        let (bind_group_find_proofs, compute_pipeline_find_proofs) =
602            bind_group_and_pipeline_find_proofs(
603                &device,
604                &module,
605                &positions_f2_gpu,
606                &positions_f3_gpu,
607                &positions_f4_gpu,
608                &positions_f5_gpu,
609                &positions_f6_gpu,
610                &table_6_proof_targets_sizes_gpu,
611                &table_6_proof_targets_gpu,
612                &proofs_gpu,
613            );
614
615        Self {
616            device,
617            queue,
618            mapping_error: Arc::new(Mutex::new(None)),
619            tainted: false,
620            initial_state_host,
621            initial_state_gpu,
622            proofs_host,
623            proofs_gpu,
624            bind_group_compute_f1,
625            compute_pipeline_compute_f1,
626            bind_group_sort_buckets_a,
627            compute_pipeline_sort_buckets_a,
628            bind_group_sort_buckets_b,
629            compute_pipeline_sort_buckets_b,
630            bind_group_find_matches_and_compute_f2,
631            compute_pipeline_find_matches_and_compute_f2,
632            bind_group_find_matches_and_compute_f3,
633            compute_pipeline_find_matches_and_compute_f3,
634            bind_group_find_matches_and_compute_f4,
635            compute_pipeline_find_matches_and_compute_f4,
636            bind_group_find_matches_and_compute_f5,
637            compute_pipeline_find_matches_and_compute_f5,
638            bind_group_find_matches_and_compute_f6,
639            compute_pipeline_find_matches_and_compute_f6,
640            bind_group_find_matches_and_compute_f7,
641            compute_pipeline_find_matches_and_compute_f7,
642            bind_group_find_proofs,
643            compute_pipeline_find_proofs,
644        }
645    }
646
647    fn create_proofs(
648        &mut self,
649        seed: &PosSeed,
650    ) -> Result<ProofsHostWrapper<'_>, RecordEncodingError> {
651        if self.tainted {
652            return Err(RecordEncodingError::DeviceBroken);
653        }
654        self.tainted = true;
655
656        let mut encoder = self
657            .device
658            .create_command_encoder(&CommandEncoderDescriptor {
659                label: Some("create_proofs"),
660            });
661
662        // Mapped initially and re-mapped at the end of the computation
663        self.initial_state_host
664            .get_mapped_range_mut(..)
665            .copy_from_slice(&block_to_bytes(
666                &ChaCha8State::init(seed, &[0; _]).to_repr(),
667            ));
668        self.initial_state_host.unmap();
669
670        encoder.copy_buffer_to_buffer(
671            &self.initial_state_host,
672            0,
673            &self.initial_state_gpu,
674            0,
675            self.initial_state_host.size(),
676        );
677
678        {
679            let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
680                label: Some("create_proofs"),
681                timestamp_writes: None,
682            });
683
684            cpass.set_bind_group(0, &self.bind_group_compute_f1, &[]);
685            cpass.set_pipeline(&self.compute_pipeline_compute_f1);
686            cpass.dispatch_workgroups(
687                MAX_TABLE_SIZE
688                    .div_ceil(compute_f1::WORKGROUP_SIZE * compute_f1::ELEMENTS_PER_INVOCATION),
689                1,
690                1,
691            );
692
693            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
694            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
695            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
696
697            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f2, &[]);
698            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f2);
699            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
700
701            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
702            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
703            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
704
705            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f3, &[]);
706            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f3);
707            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
708
709            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
710            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
711            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
712
713            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f4, &[]);
714            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f4);
715            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
716
717            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
718            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
719            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
720
721            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f5, &[]);
722            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f5);
723            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
724
725            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
726            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
727            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
728
729            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f6, &[]);
730            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f6);
731            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
732
733            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
734            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
735            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
736
737            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f7, &[]);
738            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f7);
739            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
740
741            cpass.set_bind_group(0, &self.bind_group_find_proofs, &[]);
742            cpass.set_pipeline(&self.compute_pipeline_find_proofs);
743            cpass.dispatch_workgroups(NUM_S_BUCKETS as u32 / find_proofs::WORKGROUP_SIZE, 1, 1);
744        }
745
746        encoder.copy_buffer_to_buffer(
747            &self.proofs_gpu,
748            0,
749            &self.proofs_host,
750            0,
751            self.proofs_host.size(),
752        );
753
754        // Map initial state for writes for the next iteration
755        encoder.map_buffer_on_submit(&self.initial_state_host, MapMode::Write, .., {
756            let mapping_error = Arc::clone(&self.mapping_error);
757
758            move |r| {
759                if let Err(error) = r {
760                    mapping_error.lock().replace(error);
761                }
762            }
763        });
764        encoder.map_buffer_on_submit(&self.proofs_host, MapMode::Read, .., {
765            let mapping_error = Arc::clone(&self.mapping_error);
766
767            move |r| {
768                if let Err(error) = r {
769                    mapping_error.lock().replace(error);
770                }
771            }
772        });
773
774        let submission_index = self.queue.submit([encoder.finish()]);
775
776        self.device.poll(PollType::Wait {
777            submission_index: Some(submission_index),
778            timeout: None,
779        })?;
780
781        if let Some(error) = self.mapping_error.lock().take() {
782            return Err(RecordEncodingError::BufferMapping(error));
783        }
784
785        let proofs = {
786            let proofs_host_ptr = self
787                .proofs_host
788                .get_mapped_range(..)
789                .as_ptr()
790                .cast::<ProofsHost>();
791            // SAFETY: Initialized on the GPU
792            unsafe { &*proofs_host_ptr }
793        };
794
795        self.tainted = false;
796
797        Ok(ProofsHostWrapper {
798            proofs,
799            proofs_host: &self.proofs_host,
800        })
801    }
802}
803
804fn bind_group_and_pipeline_compute_f1(
805    device: &wgpu::Device,
806    module: &ShaderModule,
807    initial_state_gpu: &Buffer,
808    bucket_sizes_gpu: &Buffer,
809    buckets_gpu: &Buffer,
810) -> (BindGroup, ComputePipeline) {
811    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
812        label: Some("compute_f1"),
813        entries: &[
814            BindGroupLayoutEntry {
815                binding: 0,
816                count: None,
817                visibility: ShaderStages::COMPUTE,
818                ty: BindingType::Buffer {
819                    has_dynamic_offset: false,
820                    min_binding_size: None,
821                    ty: BufferBindingType::Uniform,
822                },
823            },
824            BindGroupLayoutEntry {
825                binding: 1,
826                count: None,
827                visibility: ShaderStages::COMPUTE,
828                ty: BindingType::Buffer {
829                    has_dynamic_offset: false,
830                    min_binding_size: None,
831                    ty: BufferBindingType::Storage { read_only: false },
832                },
833            },
834            BindGroupLayoutEntry {
835                binding: 2,
836                count: None,
837                visibility: ShaderStages::COMPUTE,
838                ty: BindingType::Buffer {
839                    has_dynamic_offset: false,
840                    min_binding_size: None,
841                    ty: BufferBindingType::Storage { read_only: false },
842                },
843            },
844        ],
845    });
846
847    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
848        label: Some("compute_f1"),
849        bind_group_layouts: &[&bind_group_layout],
850        push_constant_ranges: &[],
851    });
852
853    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
854        compilation_options: PipelineCompilationOptions {
855            constants: &[],
856            zero_initialize_workgroup_memory: false,
857        },
858        cache: None,
859        label: Some("compute_f1"),
860        layout: Some(&pipeline_layout),
861        module,
862        entry_point: Some("compute_f1"),
863    });
864
865    let bind_group = device.create_bind_group(&BindGroupDescriptor {
866        label: Some("compute_f1"),
867        layout: &bind_group_layout,
868        entries: &[
869            BindGroupEntry {
870                binding: 0,
871                resource: initial_state_gpu.as_entire_binding(),
872            },
873            BindGroupEntry {
874                binding: 1,
875                resource: bucket_sizes_gpu.as_entire_binding(),
876            },
877            BindGroupEntry {
878                binding: 2,
879                resource: buckets_gpu.as_entire_binding(),
880            },
881        ],
882    });
883
884    (bind_group, compute_pipeline)
885}
886
887fn bind_group_and_pipeline_sort_buckets(
888    device: &wgpu::Device,
889    module: &ShaderModule,
890    bucket_sizes_gpu: &Buffer,
891    buckets_gpu: &Buffer,
892) -> (BindGroup, ComputePipeline) {
893    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
894        label: Some("sort_buckets"),
895        entries: &[
896            BindGroupLayoutEntry {
897                binding: 0,
898                count: None,
899                visibility: ShaderStages::COMPUTE,
900                ty: BindingType::Buffer {
901                    has_dynamic_offset: false,
902                    min_binding_size: None,
903                    ty: BufferBindingType::Storage { read_only: false },
904                },
905            },
906            BindGroupLayoutEntry {
907                binding: 1,
908                count: None,
909                visibility: ShaderStages::COMPUTE,
910                ty: BindingType::Buffer {
911                    has_dynamic_offset: false,
912                    min_binding_size: None,
913                    ty: BufferBindingType::Storage { read_only: false },
914                },
915            },
916        ],
917    });
918
919    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
920        label: Some("sort_buckets"),
921        bind_group_layouts: &[&bind_group_layout],
922        push_constant_ranges: &[],
923    });
924
925    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
926        compilation_options: PipelineCompilationOptions {
927            constants: &[],
928            zero_initialize_workgroup_memory: false,
929        },
930        cache: None,
931        label: Some("sort_buckets"),
932        layout: Some(&pipeline_layout),
933        module,
934        entry_point: Some("sort_buckets"),
935    });
936
937    let bind_group = device.create_bind_group(&BindGroupDescriptor {
938        label: Some("sort_buckets"),
939        layout: &bind_group_layout,
940        entries: &[
941            BindGroupEntry {
942                binding: 0,
943                resource: bucket_sizes_gpu.as_entire_binding(),
944            },
945            BindGroupEntry {
946                binding: 1,
947                resource: buckets_gpu.as_entire_binding(),
948            },
949        ],
950    });
951
952    (bind_group, compute_pipeline)
953}
954
955fn bind_group_and_pipeline_find_matches_and_compute_f2(
956    device: &wgpu::Device,
957    module: &ShaderModule,
958    parent_buckets_gpu: &Buffer,
959    bucket_sizes_gpu: &Buffer,
960    buckets_gpu: &Buffer,
961    positions_gpu: &Buffer,
962    metadatas_gpu: &Buffer,
963) -> (BindGroup, ComputePipeline) {
964    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
965        label: Some("find_matches_and_compute_f2"),
966        entries: &[
967            BindGroupLayoutEntry {
968                binding: 0,
969                count: None,
970                visibility: ShaderStages::COMPUTE,
971                ty: BindingType::Buffer {
972                    has_dynamic_offset: false,
973                    min_binding_size: None,
974                    ty: BufferBindingType::Storage { read_only: true },
975                },
976            },
977            BindGroupLayoutEntry {
978                binding: 1,
979                count: None,
980                visibility: ShaderStages::COMPUTE,
981                ty: BindingType::Buffer {
982                    has_dynamic_offset: false,
983                    min_binding_size: None,
984                    ty: BufferBindingType::Storage { read_only: false },
985                },
986            },
987            BindGroupLayoutEntry {
988                binding: 2,
989                count: None,
990                visibility: ShaderStages::COMPUTE,
991                ty: BindingType::Buffer {
992                    has_dynamic_offset: false,
993                    min_binding_size: None,
994                    ty: BufferBindingType::Storage { read_only: false },
995                },
996            },
997            BindGroupLayoutEntry {
998                binding: 3,
999                count: None,
1000                visibility: ShaderStages::COMPUTE,
1001                ty: BindingType::Buffer {
1002                    has_dynamic_offset: false,
1003                    min_binding_size: None,
1004                    ty: BufferBindingType::Storage { read_only: false },
1005                },
1006            },
1007            BindGroupLayoutEntry {
1008                binding: 4,
1009                count: None,
1010                visibility: ShaderStages::COMPUTE,
1011                ty: BindingType::Buffer {
1012                    has_dynamic_offset: false,
1013                    min_binding_size: None,
1014                    ty: BufferBindingType::Storage { read_only: false },
1015                },
1016            },
1017        ],
1018    });
1019
1020    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1021        label: Some("find_matches_and_compute_f2"),
1022        bind_group_layouts: &[&bind_group_layout],
1023        push_constant_ranges: &[],
1024    });
1025
1026    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1027        compilation_options: PipelineCompilationOptions {
1028            constants: &[],
1029            zero_initialize_workgroup_memory: true,
1030        },
1031        cache: None,
1032        label: Some("find_matches_and_compute_f2"),
1033        layout: Some(&pipeline_layout),
1034        module,
1035        entry_point: Some("find_matches_and_compute_f2"),
1036    });
1037
1038    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1039        label: Some("find_matches_and_compute_f2"),
1040        layout: &bind_group_layout,
1041        entries: &[
1042            BindGroupEntry {
1043                binding: 0,
1044                resource: parent_buckets_gpu.as_entire_binding(),
1045            },
1046            BindGroupEntry {
1047                binding: 1,
1048                resource: bucket_sizes_gpu.as_entire_binding(),
1049            },
1050            BindGroupEntry {
1051                binding: 2,
1052                resource: buckets_gpu.as_entire_binding(),
1053            },
1054            BindGroupEntry {
1055                binding: 3,
1056                resource: positions_gpu.as_entire_binding(),
1057            },
1058            BindGroupEntry {
1059                binding: 4,
1060                resource: metadatas_gpu.as_entire_binding(),
1061            },
1062        ],
1063    });
1064
1065    (bind_group, compute_pipeline)
1066}
1067
1068#[expect(
1069    clippy::too_many_arguments,
1070    reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1071)]
1072fn bind_group_and_pipeline_find_matches_and_compute_fn<const TABLE_NUMBER: u8>(
1073    device: &wgpu::Device,
1074    module: &ShaderModule,
1075    parent_buckets_gpu: &Buffer,
1076    parent_metadatas_gpu: &Buffer,
1077    bucket_sizes_gpu: &Buffer,
1078    buckets_gpu: &Buffer,
1079    positions_gpu: &Buffer,
1080    metadatas_gpu: &Buffer,
1081) -> (BindGroup, ComputePipeline) {
1082    let label = format!("find_matches_and_compute_f{TABLE_NUMBER}");
1083    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1084        label: Some(&label),
1085        entries: &[
1086            BindGroupLayoutEntry {
1087                binding: 0,
1088                count: None,
1089                visibility: ShaderStages::COMPUTE,
1090                ty: BindingType::Buffer {
1091                    has_dynamic_offset: false,
1092                    min_binding_size: None,
1093                    ty: BufferBindingType::Storage { read_only: true },
1094                },
1095            },
1096            BindGroupLayoutEntry {
1097                binding: 1,
1098                count: None,
1099                visibility: ShaderStages::COMPUTE,
1100                ty: BindingType::Buffer {
1101                    has_dynamic_offset: false,
1102                    min_binding_size: None,
1103                    ty: BufferBindingType::Storage { read_only: true },
1104                },
1105            },
1106            BindGroupLayoutEntry {
1107                binding: 2,
1108                count: None,
1109                visibility: ShaderStages::COMPUTE,
1110                ty: BindingType::Buffer {
1111                    has_dynamic_offset: false,
1112                    min_binding_size: None,
1113                    ty: BufferBindingType::Storage { read_only: false },
1114                },
1115            },
1116            BindGroupLayoutEntry {
1117                binding: 3,
1118                count: None,
1119                visibility: ShaderStages::COMPUTE,
1120                ty: BindingType::Buffer {
1121                    has_dynamic_offset: false,
1122                    min_binding_size: None,
1123                    ty: BufferBindingType::Storage { read_only: false },
1124                },
1125            },
1126            BindGroupLayoutEntry {
1127                binding: 4,
1128                count: None,
1129                visibility: ShaderStages::COMPUTE,
1130                ty: BindingType::Buffer {
1131                    has_dynamic_offset: false,
1132                    min_binding_size: None,
1133                    ty: BufferBindingType::Storage { read_only: false },
1134                },
1135            },
1136            BindGroupLayoutEntry {
1137                binding: 5,
1138                count: None,
1139                visibility: ShaderStages::COMPUTE,
1140                ty: BindingType::Buffer {
1141                    has_dynamic_offset: false,
1142                    min_binding_size: None,
1143                    ty: BufferBindingType::Storage { read_only: false },
1144                },
1145            },
1146        ],
1147    });
1148
1149    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1150        label: Some(&label),
1151        bind_group_layouts: &[&bind_group_layout],
1152        push_constant_ranges: &[],
1153    });
1154
1155    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1156        compilation_options: PipelineCompilationOptions {
1157            constants: &[],
1158            zero_initialize_workgroup_memory: true,
1159        },
1160        cache: None,
1161        label: Some(&label),
1162        layout: Some(&pipeline_layout),
1163        module,
1164        entry_point: Some(&format!("find_matches_and_compute_f{TABLE_NUMBER}")),
1165    });
1166
1167    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1168        label: Some(&label),
1169        layout: &bind_group_layout,
1170        entries: &[
1171            BindGroupEntry {
1172                binding: 0,
1173                resource: parent_buckets_gpu.as_entire_binding(),
1174            },
1175            BindGroupEntry {
1176                binding: 1,
1177                resource: parent_metadatas_gpu.as_entire_binding(),
1178            },
1179            BindGroupEntry {
1180                binding: 2,
1181                resource: bucket_sizes_gpu.as_entire_binding(),
1182            },
1183            BindGroupEntry {
1184                binding: 3,
1185                resource: buckets_gpu.as_entire_binding(),
1186            },
1187            BindGroupEntry {
1188                binding: 4,
1189                resource: positions_gpu.as_entire_binding(),
1190            },
1191            BindGroupEntry {
1192                binding: 5,
1193                resource: metadatas_gpu.as_entire_binding(),
1194            },
1195        ],
1196    });
1197
1198    (bind_group, compute_pipeline)
1199}
1200
1201fn bind_group_and_pipeline_find_matches_and_compute_f7(
1202    device: &wgpu::Device,
1203    module: &ShaderModule,
1204    parent_buckets_gpu: &Buffer,
1205    parent_metadatas_gpu: &Buffer,
1206    table_6_proof_targets_sizes_gpu: &Buffer,
1207    table_6_proof_targets_gpu: &Buffer,
1208) -> (BindGroup, ComputePipeline) {
1209    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1210        label: Some("find_matches_and_compute_f7"),
1211        entries: &[
1212            BindGroupLayoutEntry {
1213                binding: 0,
1214                count: None,
1215                visibility: ShaderStages::COMPUTE,
1216                ty: BindingType::Buffer {
1217                    has_dynamic_offset: false,
1218                    min_binding_size: None,
1219                    ty: BufferBindingType::Storage { read_only: true },
1220                },
1221            },
1222            BindGroupLayoutEntry {
1223                binding: 1,
1224                count: None,
1225                visibility: ShaderStages::COMPUTE,
1226                ty: BindingType::Buffer {
1227                    has_dynamic_offset: false,
1228                    min_binding_size: None,
1229                    ty: BufferBindingType::Storage { read_only: true },
1230                },
1231            },
1232            BindGroupLayoutEntry {
1233                binding: 2,
1234                count: None,
1235                visibility: ShaderStages::COMPUTE,
1236                ty: BindingType::Buffer {
1237                    has_dynamic_offset: false,
1238                    min_binding_size: None,
1239                    ty: BufferBindingType::Storage { read_only: false },
1240                },
1241            },
1242            BindGroupLayoutEntry {
1243                binding: 3,
1244                count: None,
1245                visibility: ShaderStages::COMPUTE,
1246                ty: BindingType::Buffer {
1247                    has_dynamic_offset: false,
1248                    min_binding_size: None,
1249                    ty: BufferBindingType::Storage { read_only: false },
1250                },
1251            },
1252        ],
1253    });
1254
1255    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1256        label: Some("find_matches_and_compute_f7"),
1257        bind_group_layouts: &[&bind_group_layout],
1258        push_constant_ranges: &[],
1259    });
1260
1261    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1262        compilation_options: PipelineCompilationOptions {
1263            constants: &[],
1264            zero_initialize_workgroup_memory: true,
1265        },
1266        cache: None,
1267        label: Some("find_matches_and_compute_f7"),
1268        layout: Some(&pipeline_layout),
1269        module,
1270        entry_point: Some("find_matches_and_compute_f7"),
1271    });
1272
1273    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1274        label: Some("find_matches_and_compute_f7"),
1275        layout: &bind_group_layout,
1276        entries: &[
1277            BindGroupEntry {
1278                binding: 0,
1279                resource: parent_buckets_gpu.as_entire_binding(),
1280            },
1281            BindGroupEntry {
1282                binding: 1,
1283                resource: parent_metadatas_gpu.as_entire_binding(),
1284            },
1285            BindGroupEntry {
1286                binding: 2,
1287                resource: table_6_proof_targets_sizes_gpu.as_entire_binding(),
1288            },
1289            BindGroupEntry {
1290                binding: 3,
1291                resource: table_6_proof_targets_gpu.as_entire_binding(),
1292            },
1293        ],
1294    });
1295
1296    (bind_group, compute_pipeline)
1297}
1298
1299#[expect(
1300    clippy::too_many_arguments,
1301    reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1302)]
1303fn bind_group_and_pipeline_find_proofs(
1304    device: &wgpu::Device,
1305    module: &ShaderModule,
1306    table_2_positions_gpu: &Buffer,
1307    table_3_positions_gpu: &Buffer,
1308    table_4_positions_gpu: &Buffer,
1309    table_5_positions_gpu: &Buffer,
1310    table_6_positions_gpu: &Buffer,
1311    bucket_sizes_gpu: &Buffer,
1312    buckets_gpu: &Buffer,
1313    proofs_gpu: &Buffer,
1314) -> (BindGroup, ComputePipeline) {
1315    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1316        label: Some("find_proofs"),
1317        entries: &[
1318            BindGroupLayoutEntry {
1319                binding: 0,
1320                count: None,
1321                visibility: ShaderStages::COMPUTE,
1322                ty: BindingType::Buffer {
1323                    has_dynamic_offset: false,
1324                    min_binding_size: None,
1325                    ty: BufferBindingType::Storage { read_only: true },
1326                },
1327            },
1328            BindGroupLayoutEntry {
1329                binding: 1,
1330                count: None,
1331                visibility: ShaderStages::COMPUTE,
1332                ty: BindingType::Buffer {
1333                    has_dynamic_offset: false,
1334                    min_binding_size: None,
1335                    ty: BufferBindingType::Storage { read_only: true },
1336                },
1337            },
1338            BindGroupLayoutEntry {
1339                binding: 2,
1340                count: None,
1341                visibility: ShaderStages::COMPUTE,
1342                ty: BindingType::Buffer {
1343                    has_dynamic_offset: false,
1344                    min_binding_size: None,
1345                    ty: BufferBindingType::Storage { read_only: true },
1346                },
1347            },
1348            BindGroupLayoutEntry {
1349                binding: 3,
1350                count: None,
1351                visibility: ShaderStages::COMPUTE,
1352                ty: BindingType::Buffer {
1353                    has_dynamic_offset: false,
1354                    min_binding_size: None,
1355                    ty: BufferBindingType::Storage { read_only: true },
1356                },
1357            },
1358            BindGroupLayoutEntry {
1359                binding: 4,
1360                count: None,
1361                visibility: ShaderStages::COMPUTE,
1362                ty: BindingType::Buffer {
1363                    has_dynamic_offset: false,
1364                    min_binding_size: None,
1365                    ty: BufferBindingType::Storage { read_only: true },
1366                },
1367            },
1368            BindGroupLayoutEntry {
1369                binding: 5,
1370                count: None,
1371                visibility: ShaderStages::COMPUTE,
1372                ty: BindingType::Buffer {
1373                    has_dynamic_offset: false,
1374                    min_binding_size: None,
1375                    ty: BufferBindingType::Storage { read_only: false },
1376                },
1377            },
1378            BindGroupLayoutEntry {
1379                binding: 6,
1380                count: None,
1381                visibility: ShaderStages::COMPUTE,
1382                ty: BindingType::Buffer {
1383                    has_dynamic_offset: false,
1384                    min_binding_size: None,
1385                    ty: BufferBindingType::Storage { read_only: true },
1386                },
1387            },
1388            BindGroupLayoutEntry {
1389                binding: 7,
1390                count: None,
1391                visibility: ShaderStages::COMPUTE,
1392                ty: BindingType::Buffer {
1393                    has_dynamic_offset: false,
1394                    min_binding_size: None,
1395                    ty: BufferBindingType::Storage { read_only: false },
1396                },
1397            },
1398        ],
1399    });
1400
1401    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1402        label: Some("find_proofs"),
1403        bind_group_layouts: &[&bind_group_layout],
1404        push_constant_ranges: &[],
1405    });
1406
1407    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1408        compilation_options: PipelineCompilationOptions {
1409            constants: &[],
1410            zero_initialize_workgroup_memory: false,
1411        },
1412        cache: None,
1413        label: Some("find_proofs"),
1414        layout: Some(&pipeline_layout),
1415        module,
1416        entry_point: Some("find_proofs"),
1417    });
1418
1419    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1420        label: Some("find_proofs"),
1421        layout: &bind_group_layout,
1422        entries: &[
1423            BindGroupEntry {
1424                binding: 0,
1425                resource: table_2_positions_gpu.as_entire_binding(),
1426            },
1427            BindGroupEntry {
1428                binding: 1,
1429                resource: table_3_positions_gpu.as_entire_binding(),
1430            },
1431            BindGroupEntry {
1432                binding: 2,
1433                resource: table_4_positions_gpu.as_entire_binding(),
1434            },
1435            BindGroupEntry {
1436                binding: 3,
1437                resource: table_5_positions_gpu.as_entire_binding(),
1438            },
1439            BindGroupEntry {
1440                binding: 4,
1441                resource: table_6_positions_gpu.as_entire_binding(),
1442            },
1443            BindGroupEntry {
1444                binding: 5,
1445                resource: bucket_sizes_gpu.as_entire_binding(),
1446            },
1447            BindGroupEntry {
1448                binding: 6,
1449                resource: buckets_gpu.as_entire_binding(),
1450            },
1451            BindGroupEntry {
1452                binding: 7,
1453                resource: proofs_gpu.as_entire_binding(),
1454            },
1455        ],
1456    });
1457
1458    (bind_group, compute_pipeline)
1459}