Skip to main content

ab_proof_of_space_gpu/
host.rs

1#[cfg(all(test, not(miri)))]
2mod tests;
3
4use crate::shader::constants::{
5    MAX_BUCKET_SIZE, MAX_TABLE_SIZE, NUM_BUCKETS, NUM_MATCH_BUCKETS, NUM_S_BUCKETS,
6    REDUCED_MATCHES_COUNT,
7};
8use crate::shader::find_matches_and_compute_f7::{NUM_ELEMENTS_PER_S_BUCKET, ProofTargets};
9use crate::shader::find_proofs::ProofsHost;
10use crate::shader::types::{Metadata, Position, PositionR};
11use crate::shader::{compute_f1, find_proofs, select_shader_features_limits};
12use ab_chacha8::{ChaCha8Block, ChaCha8State, block_to_bytes};
13use ab_core_primitives::pieces::{PieceOffset, Record};
14use ab_core_primitives::pos::PosSeed;
15use ab_core_primitives::sectors::SectorId;
16use ab_erasure_coding::ErasureCoding;
17use ab_farmer_components::plotting::RecordsEncoder;
18use ab_farmer_components::sector::SectorContentsMap;
19use async_lock::Mutex as AsyncMutex;
20use futures::stream::FuturesOrdered;
21use futures::{StreamExt, TryStreamExt};
22use parking_lot::Mutex;
23use rayon::prelude::*;
24use rayon::{ThreadPool, ThreadPoolBuilder};
25use rclite::Arc;
26use std::num::NonZeroU8;
27use std::simd::Simd;
28use std::sync::Arc as StdArc;
29use std::sync::atomic::{AtomicBool, Ordering};
30use std::{fmt, iter};
31use tracing::{debug, warn};
32use wgpu::{
33    AdapterInfo, Backend, BackendOptions, Backends, BindGroup, BindGroupDescriptor, BindGroupEntry,
34    BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, Buffer, BufferAddress,
35    BufferAsyncError, BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor,
36    ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor, DeviceDescriptor,
37    DeviceType, Instance, InstanceDescriptor, InstanceFlags, MapMode, MemoryBudgetThresholds,
38    PipelineCompilationOptions, PipelineLayoutDescriptor, PollError, PollType, Queue,
39    RequestDeviceError, ShaderModule, ShaderStages,
40};
41
42/// Proof creation error
43#[derive(Debug, thiserror::Error)]
44enum RecordEncodingError {
45    /// Too many records
46    #[error("Too many records: {0}")]
47    TooManyRecords(usize),
48    /// Proof creation failed previously and the device is now considered broken
49    #[error("Proof creation failed previously and the device is now considered broken")]
50    DeviceBroken,
51    /// Failed to map buffer
52    #[error("Failed to map buffer: {0}")]
53    BufferMapping(#[from] BufferAsyncError),
54    /// Poll error
55    #[error("Poll error: {0}")]
56    DevicePoll(#[from] PollError),
57}
58
59struct ProofsHostWrapper<'a> {
60    proofs: &'a ProofsHost,
61    proofs_host: &'a Buffer,
62}
63
64impl Drop for ProofsHostWrapper<'_> {
65    fn drop(&mut self) {
66        self.proofs_host.unmap();
67    }
68}
69
70/// Wrapper data structure encapsulating a single compatible device
71#[derive(Clone)]
72pub struct Device {
73    id: u32,
74    devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
75    adapter_info: AdapterInfo,
76}
77
78impl fmt::Debug for Device {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        f.debug_struct("Device")
81            .field("id", &self.id)
82            .field("name", &self.adapter_info.name)
83            .field("device_type", &self.adapter_info.device_type)
84            .field("driver", &self.adapter_info.driver)
85            .field("driver_info", &self.adapter_info.driver_info)
86            .field("backend", &self.adapter_info.backend)
87            .finish_non_exhaustive()
88    }
89}
90
91impl Device {
92    /// Returns [`Device`] for each available device
93    pub async fn enumerate<NOQ>(number_of_queues: NOQ) -> Vec<Self>
94    where
95        NOQ: Fn(DeviceType) -> NonZeroU8,
96    {
97        let backends = Backends::from_env().unwrap_or(Backends::METAL | Backends::VULKAN);
98        let instance = Instance::new(InstanceDescriptor {
99            backends,
100            flags: if cfg!(debug_assertions) {
101                InstanceFlags::debugging().with_env()
102            } else {
103                InstanceFlags::from_env_or_default()
104            },
105            memory_budget_thresholds: MemoryBudgetThresholds::default(),
106            backend_options: BackendOptions::from_env_or_default(),
107            display: None,
108        });
109
110        let adapters = instance.enumerate_adapters(backends).await;
111        let number_of_queues = &number_of_queues;
112
113        adapters
114            .into_iter()
115            .zip(0..)
116            .map(|(adapter, id)| async move {
117                let adapter_info = adapter.get_info();
118
119                let (shader, required_features, required_limits) =
120                    if let Some((shader, required_features, required_limits)) =
121                        select_shader_features_limits(&adapter)
122                    {
123                        debug!(
124                            %id,
125                            adapter_info = ?adapter_info,
126                            "Compatible adapter found"
127                        );
128
129                        (shader, required_features, required_limits)
130                    } else {
131                        debug!(
132                            %id,
133                            adapter_info = ?adapter_info,
134                            "Incompatible adapter found"
135                        );
136
137                        return None;
138                    };
139
140                // TODO: creation of multiple devices is a workaround for lack of support for
141                //  multiple queues: https://github.com/gfx-rs/wgpu/issues/1066
142                let devices = iter::repeat_with(|| async {
143                    let (device, queue) = adapter
144                        .request_device(&DeviceDescriptor {
145                            label: None,
146                            required_features,
147                            required_limits: required_limits.clone(),
148                            ..DeviceDescriptor::default()
149                        })
150                        .await
151                        .inspect_err(|error| {
152                            warn!(%id, ?adapter_info, %error, "Failed to request the device");
153                        })?;
154                    let module = device.create_shader_module(shader.clone());
155
156                    Ok::<_, RequestDeviceError>((device, queue, module))
157                })
158                .take(usize::from(
159                    number_of_queues(adapter_info.device_type).get(),
160                ))
161                .collect::<FuturesOrdered<_>>()
162                .try_collect::<Vec<_>>()
163                .await
164                .ok()?;
165
166                Some(Self {
167                    id,
168                    devices,
169                    adapter_info,
170                })
171            })
172            .collect::<FuturesOrdered<_>>()
173            .filter_map(|device| async move { device })
174            .collect()
175            .await
176    }
177
178    /// Gpu ID
179    pub fn id(&self) -> u32 {
180        self.id
181    }
182
183    /// Device name
184    pub fn name(&self) -> &str {
185        &self.adapter_info.name
186    }
187
188    /// Device type
189    pub fn device_type(&self) -> DeviceType {
190        self.adapter_info.device_type
191    }
192
193    /// Driver
194    pub fn driver(&self) -> &str {
195        &self.adapter_info.driver
196    }
197
198    /// Driver info
199    pub fn driver_info(&self) -> &str {
200        &self.adapter_info.driver_info
201    }
202
203    /// Backend
204    pub fn backend(&self) -> Backend {
205        self.adapter_info.backend
206    }
207
208    pub fn instantiate(
209        &self,
210        erasure_coding: ErasureCoding,
211        global_mutex: StdArc<AsyncMutex<()>>,
212    ) -> anyhow::Result<GpuRecordsEncoder> {
213        GpuRecordsEncoder::new(
214            self.id,
215            self.devices.clone(),
216            self.adapter_info.clone(),
217            erasure_coding,
218            global_mutex,
219        )
220    }
221}
222
223pub struct GpuRecordsEncoder {
224    id: u32,
225    instances: Vec<Mutex<GpuRecordsEncoderInstance>>,
226    thread_pool: ThreadPool,
227    adapter_info: AdapterInfo,
228    erasure_coding: ErasureCoding,
229    global_mutex: StdArc<AsyncMutex<()>>,
230}
231
232impl fmt::Debug for GpuRecordsEncoder {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        f.debug_struct("GpuRecordsEncoder")
235            .field("id", &self.id)
236            .field("name", &self.adapter_info.name)
237            .field("device_type", &self.adapter_info.device_type)
238            .field("driver", &self.adapter_info.driver)
239            .field("driver_info", &self.adapter_info.driver_info)
240            .field("backend", &self.adapter_info.backend)
241            .finish_non_exhaustive()
242    }
243}
244
245impl RecordsEncoder for GpuRecordsEncoder {
246    // TODO: Run more than one encoding per device concurrently
247    fn encode_records(
248        &mut self,
249        sector_id: &SectorId,
250        records: &mut [Record],
251        abort_early: &AtomicBool,
252    ) -> anyhow::Result<SectorContentsMap> {
253        let mut sector_contents_map = SectorContentsMap::new(
254            u16::try_from(records.len())
255                .map_err(|_error| RecordEncodingError::TooManyRecords(records.len()))?,
256        );
257
258        let maybe_error = self.thread_pool.install(|| {
259            records
260                .par_iter_mut()
261                .zip(sector_contents_map.iter_record_chunks_used_mut())
262                .enumerate()
263                .find_map_any(|(piece_offset, (record, record_chunks_used))| {
264                    // Take mutex briefly to make sure encoding is allowed right now
265                    self.global_mutex.lock_blocking();
266
267                    let mut parity_record_chunks = Record::new_boxed();
268
269                    // TODO: Do erasure coding on the GPU
270                    // Erasure code source record chunks
271                    self.erasure_coding
272                        .extend(record.iter(), parity_record_chunks.iter_mut())
273                        .expect("Statically guaranteed valid inputs; qed");
274
275                    if abort_early.load(Ordering::Relaxed) {
276                        return None;
277                    }
278                    let seed =
279                        sector_id.derive_evaluation_seed(PieceOffset::from(piece_offset as u16));
280                    let thread_index = rayon::current_thread_index().unwrap_or_default();
281                    let mut encoder_instance = self.instances[thread_index]
282                        .try_lock()
283                        .expect("1:1 mapping between threads and devices; qed");
284                    let proofs = match encoder_instance.create_proofs(&seed) {
285                        Ok(proofs) => proofs,
286                        Err(error) => {
287                            return Some(error);
288                        }
289                    };
290                    let proofs = proofs.proofs;
291
292                    *record_chunks_used = proofs.found_proofs;
293
294                    // TODO: Record encoding on the GPU
295                    let mut num_found_proofs = 0_usize;
296                    for (s_buckets, found_proofs) in (0..Record::NUM_S_BUCKETS)
297                        .array_chunks::<{ u8::BITS as usize }>()
298                        .zip(record_chunks_used)
299                    {
300                        for (proof_offset, s_bucket) in s_buckets.into_iter().enumerate() {
301                            if num_found_proofs == Record::NUM_CHUNKS {
302                                // Enough proofs collected, clear the rest of the bits
303                                *found_proofs &=
304                                    u8::MAX.unbounded_shr(u8::BITS - proof_offset as u32);
305                                break;
306                            }
307                            if (*found_proofs & (1 << proof_offset)) != 0 {
308                                let record_chunk = if s_bucket < Record::NUM_CHUNKS {
309                                    record[s_bucket]
310                                } else {
311                                    parity_record_chunks[s_bucket - Record::NUM_CHUNKS]
312                                };
313
314                                record[num_found_proofs] = (Simd::from(record_chunk)
315                                    ^ Simd::from(*proofs.proofs[s_bucket].hash()))
316                                .to_array();
317                                num_found_proofs += 1;
318                            }
319                        }
320                    }
321
322                    None
323                })
324        });
325
326        if let Some(error) = maybe_error {
327            return Err(error.into());
328        }
329
330        Ok(sector_contents_map)
331    }
332}
333
334impl GpuRecordsEncoder {
335    fn new(
336        id: u32,
337        devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
338        adapter_info: AdapterInfo,
339        erasure_coding: ErasureCoding,
340        global_mutex: StdArc<AsyncMutex<()>>,
341    ) -> anyhow::Result<Self> {
342        let thread_pool = ThreadPoolBuilder::new()
343            .thread_name(move |thread_index| format!("pos-gpu-{id}.{thread_index}"))
344            .num_threads(devices.len())
345            .build()?;
346
347        Ok(Self {
348            id,
349            instances: devices
350                .into_iter()
351                .map(|(device, queue, module)| {
352                    Mutex::new(GpuRecordsEncoderInstance::new(device, queue, module))
353                })
354                .collect(),
355            thread_pool,
356            adapter_info,
357            erasure_coding,
358            global_mutex,
359        })
360    }
361}
362
363struct GpuRecordsEncoderInstance {
364    device: wgpu::Device,
365    queue: Queue,
366    mapping_error: Arc<Mutex<Option<BufferAsyncError>>>,
367    tainted: bool,
368    initial_state_host: Buffer,
369    initial_state_gpu: Buffer,
370    proofs_host: Buffer,
371    proofs_gpu: Buffer,
372    bind_group_compute_f1: BindGroup,
373    compute_pipeline_compute_f1: ComputePipeline,
374    bind_group_sort_buckets_a: BindGroup,
375    compute_pipeline_sort_buckets_a: ComputePipeline,
376    bind_group_sort_buckets_b: BindGroup,
377    compute_pipeline_sort_buckets_b: ComputePipeline,
378    bind_group_find_matches_and_compute_f2: BindGroup,
379    compute_pipeline_find_matches_and_compute_f2: ComputePipeline,
380    bind_group_find_matches_and_compute_f3: BindGroup,
381    compute_pipeline_find_matches_and_compute_f3: ComputePipeline,
382    bind_group_find_matches_and_compute_f4: BindGroup,
383    compute_pipeline_find_matches_and_compute_f4: ComputePipeline,
384    bind_group_find_matches_and_compute_f5: BindGroup,
385    compute_pipeline_find_matches_and_compute_f5: ComputePipeline,
386    bind_group_find_matches_and_compute_f6: BindGroup,
387    compute_pipeline_find_matches_and_compute_f6: ComputePipeline,
388    bind_group_find_matches_and_compute_f7: BindGroup,
389    compute_pipeline_find_matches_and_compute_f7: ComputePipeline,
390    bind_group_find_proofs: BindGroup,
391    compute_pipeline_find_proofs: ComputePipeline,
392}
393
394impl GpuRecordsEncoderInstance {
395    fn new(device: wgpu::Device, queue: Queue, module: ShaderModule) -> Self {
396        let initial_state_host = device.create_buffer(&BufferDescriptor {
397            label: Some("initial_state_host"),
398            size: size_of::<ChaCha8Block>() as BufferAddress,
399            usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
400            mapped_at_creation: true,
401        });
402
403        let initial_state_gpu = device.create_buffer(&BufferDescriptor {
404            label: Some("initial_state_gpu"),
405            size: initial_state_host.size(),
406            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
407            mapped_at_creation: false,
408        });
409
410        let bucket_sizes_gpu_buffer_size = size_of::<[u32; NUM_BUCKETS]>() as BufferAddress;
411        let table_6_proof_targets_sizes_gpu_buffer_size =
412            size_of::<[u32; NUM_S_BUCKETS]>() as BufferAddress;
413        // TODO: Sizes are excessive, for `bucket_sizes_gpu` are less than `u16` and could use SWAR
414        //  approach for storing bucket sizes. Similarly, `table_6_proof_targets_sizes_gpu` sizes
415        //  are less than `u8` and could use SWAR too with even higher compression ratio
416        let bucket_sizes_gpu = device.create_buffer(&BufferDescriptor {
417            label: Some("bucket_sizes_gpu"),
418            size: bucket_sizes_gpu_buffer_size.max(table_6_proof_targets_sizes_gpu_buffer_size),
419            usage: BufferUsages::STORAGE,
420            mapped_at_creation: false,
421        });
422        // Reuse the same buffer as `bucket_sizes_gpu`, they are not overlapping in use
423        let table_6_proof_targets_sizes_gpu = bucket_sizes_gpu.clone();
424
425        let buckets_a_gpu = device.create_buffer(&BufferDescriptor {
426            label: Some("buckets_a_gpu"),
427            size: size_of::<[[PositionR; MAX_BUCKET_SIZE]; NUM_BUCKETS]>() as BufferAddress,
428            usage: BufferUsages::STORAGE,
429            mapped_at_creation: false,
430        });
431
432        let buckets_b_gpu = device.create_buffer(&BufferDescriptor {
433            label: Some("buckets_b_gpu"),
434            size: buckets_a_gpu.size(),
435            usage: BufferUsages::STORAGE,
436            mapped_at_creation: false,
437        });
438
439        let positions_f2_gpu = device.create_buffer(&BufferDescriptor {
440            label: Some("positions_f2_gpu"),
441            size: size_of::<[[[Position; 2]; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>()
442                as BufferAddress,
443            usage: BufferUsages::STORAGE,
444            mapped_at_creation: false,
445        });
446
447        let positions_f3_gpu = device.create_buffer(&BufferDescriptor {
448            label: Some("positions_f3_gpu"),
449            size: positions_f2_gpu.size(),
450            usage: BufferUsages::STORAGE,
451            mapped_at_creation: false,
452        });
453
454        let positions_f4_gpu = device.create_buffer(&BufferDescriptor {
455            label: Some("positions_f4_gpu"),
456            size: positions_f2_gpu.size(),
457            usage: BufferUsages::STORAGE,
458            mapped_at_creation: false,
459        });
460
461        let positions_f5_gpu = device.create_buffer(&BufferDescriptor {
462            label: Some("positions_f5_gpu"),
463            size: positions_f2_gpu.size(),
464            usage: BufferUsages::STORAGE,
465            mapped_at_creation: false,
466        });
467
468        let positions_f6_gpu = device.create_buffer(&BufferDescriptor {
469            label: Some("positions_f6_gpu"),
470            size: positions_f2_gpu.size(),
471            usage: BufferUsages::STORAGE,
472            mapped_at_creation: false,
473        });
474
475        let metadatas_gpu_buffer_size =
476            size_of::<[[Metadata; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>() as BufferAddress;
477        let table_6_proof_targets_gpu_buffer_size = size_of::<
478            [[ProofTargets; NUM_ELEMENTS_PER_S_BUCKET]; NUM_S_BUCKETS],
479        >() as BufferAddress;
480        let metadatas_a_gpu = device.create_buffer(&BufferDescriptor {
481            label: Some("metadatas_a_gpu"),
482            size: metadatas_gpu_buffer_size.max(table_6_proof_targets_gpu_buffer_size),
483            usage: BufferUsages::STORAGE,
484            mapped_at_creation: false,
485        });
486        // Reuse the same buffer as `metadatas_a_gpu`, they are not overlapping in use
487        let table_6_proof_targets_gpu = metadatas_a_gpu.clone();
488
489        let metadatas_b_gpu = device.create_buffer(&BufferDescriptor {
490            label: Some("metadatas_b_gpu"),
491            size: metadatas_gpu_buffer_size,
492            usage: BufferUsages::STORAGE,
493            mapped_at_creation: false,
494        });
495
496        let proofs_host = device.create_buffer(&BufferDescriptor {
497            label: Some("proofs_host"),
498            size: size_of::<ProofsHost>() as BufferAddress,
499            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
500            mapped_at_creation: false,
501        });
502
503        let proofs_gpu = device.create_buffer(&BufferDescriptor {
504            label: Some("proofs_gpu"),
505            size: proofs_host.size(),
506            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
507            mapped_at_creation: false,
508        });
509
510        let (bind_group_compute_f1, compute_pipeline_compute_f1) =
511            bind_group_and_pipeline_compute_f1(
512                &device,
513                &module,
514                &initial_state_gpu,
515                &bucket_sizes_gpu,
516                &buckets_a_gpu,
517            );
518        let (bind_group_sort_buckets_a, compute_pipeline_sort_buckets_a) =
519            bind_group_and_pipeline_sort_buckets(
520                &device,
521                &module,
522                &bucket_sizes_gpu,
523                &buckets_a_gpu,
524            );
525
526        let (bind_group_sort_buckets_b, compute_pipeline_sort_buckets_b) =
527            bind_group_and_pipeline_sort_buckets(
528                &device,
529                &module,
530                &bucket_sizes_gpu,
531                &buckets_b_gpu,
532            );
533
534        let (bind_group_find_matches_and_compute_f2, compute_pipeline_find_matches_and_compute_f2) =
535            bind_group_and_pipeline_find_matches_and_compute_f2(
536                &device,
537                &module,
538                &buckets_a_gpu,
539                &bucket_sizes_gpu,
540                &buckets_b_gpu,
541                &positions_f2_gpu,
542                &metadatas_b_gpu,
543            );
544
545        let (bind_group_find_matches_and_compute_f3, compute_pipeline_find_matches_and_compute_f3) =
546            bind_group_and_pipeline_find_matches_and_compute_fn::<3>(
547                &device,
548                &module,
549                &buckets_b_gpu,
550                &metadatas_b_gpu,
551                &bucket_sizes_gpu,
552                &buckets_a_gpu,
553                &positions_f3_gpu,
554                &metadatas_a_gpu,
555            );
556
557        let (bind_group_find_matches_and_compute_f4, compute_pipeline_find_matches_and_compute_f4) =
558            bind_group_and_pipeline_find_matches_and_compute_fn::<4>(
559                &device,
560                &module,
561                &buckets_a_gpu,
562                &metadatas_a_gpu,
563                &bucket_sizes_gpu,
564                &buckets_b_gpu,
565                &positions_f4_gpu,
566                &metadatas_b_gpu,
567            );
568
569        let (bind_group_find_matches_and_compute_f5, compute_pipeline_find_matches_and_compute_f5) =
570            bind_group_and_pipeline_find_matches_and_compute_fn::<5>(
571                &device,
572                &module,
573                &buckets_b_gpu,
574                &metadatas_b_gpu,
575                &bucket_sizes_gpu,
576                &buckets_a_gpu,
577                &positions_f5_gpu,
578                &metadatas_a_gpu,
579            );
580
581        let (bind_group_find_matches_and_compute_f6, compute_pipeline_find_matches_and_compute_f6) =
582            bind_group_and_pipeline_find_matches_and_compute_fn::<6>(
583                &device,
584                &module,
585                &buckets_a_gpu,
586                &metadatas_a_gpu,
587                &bucket_sizes_gpu,
588                &buckets_b_gpu,
589                &positions_f6_gpu,
590                &metadatas_b_gpu,
591            );
592
593        let (bind_group_find_matches_and_compute_f7, compute_pipeline_find_matches_and_compute_f7) =
594            bind_group_and_pipeline_find_matches_and_compute_f7(
595                &device,
596                &module,
597                &buckets_b_gpu,
598                &metadatas_b_gpu,
599                &table_6_proof_targets_sizes_gpu,
600                &table_6_proof_targets_gpu,
601            );
602
603        let (bind_group_find_proofs, compute_pipeline_find_proofs) =
604            bind_group_and_pipeline_find_proofs(
605                &device,
606                &module,
607                &positions_f2_gpu,
608                &positions_f3_gpu,
609                &positions_f4_gpu,
610                &positions_f5_gpu,
611                &positions_f6_gpu,
612                &table_6_proof_targets_sizes_gpu,
613                &table_6_proof_targets_gpu,
614                &proofs_gpu,
615            );
616
617        Self {
618            device,
619            queue,
620            mapping_error: Arc::new(Mutex::new(None)),
621            tainted: false,
622            initial_state_host,
623            initial_state_gpu,
624            proofs_host,
625            proofs_gpu,
626            bind_group_compute_f1,
627            compute_pipeline_compute_f1,
628            bind_group_sort_buckets_a,
629            compute_pipeline_sort_buckets_a,
630            bind_group_sort_buckets_b,
631            compute_pipeline_sort_buckets_b,
632            bind_group_find_matches_and_compute_f2,
633            compute_pipeline_find_matches_and_compute_f2,
634            bind_group_find_matches_and_compute_f3,
635            compute_pipeline_find_matches_and_compute_f3,
636            bind_group_find_matches_and_compute_f4,
637            compute_pipeline_find_matches_and_compute_f4,
638            bind_group_find_matches_and_compute_f5,
639            compute_pipeline_find_matches_and_compute_f5,
640            bind_group_find_matches_and_compute_f6,
641            compute_pipeline_find_matches_and_compute_f6,
642            bind_group_find_matches_and_compute_f7,
643            compute_pipeline_find_matches_and_compute_f7,
644            bind_group_find_proofs,
645            compute_pipeline_find_proofs,
646        }
647    }
648
649    fn create_proofs(
650        &mut self,
651        seed: &PosSeed,
652    ) -> Result<ProofsHostWrapper<'_>, RecordEncodingError> {
653        if self.tainted {
654            return Err(RecordEncodingError::DeviceBroken);
655        }
656        self.tainted = true;
657
658        let mut encoder = self
659            .device
660            .create_command_encoder(&CommandEncoderDescriptor {
661                label: Some("create_proofs"),
662            });
663
664        // Mapped initially and re-mapped at the end of the computation
665        self.initial_state_host
666            .get_mapped_range_mut(..)
667            .copy_from_slice(&block_to_bytes(
668                &ChaCha8State::init(seed, &[0; _]).to_repr(),
669            ));
670        self.initial_state_host.unmap();
671
672        encoder.copy_buffer_to_buffer(
673            &self.initial_state_host,
674            0,
675            &self.initial_state_gpu,
676            0,
677            self.initial_state_host.size(),
678        );
679
680        {
681            let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
682                label: Some("create_proofs"),
683                timestamp_writes: None,
684            });
685
686            cpass.set_bind_group(0, &self.bind_group_compute_f1, &[]);
687            cpass.set_pipeline(&self.compute_pipeline_compute_f1);
688            cpass.dispatch_workgroups(
689                MAX_TABLE_SIZE
690                    .div_ceil(compute_f1::WORKGROUP_SIZE * compute_f1::ELEMENTS_PER_INVOCATION),
691                1,
692                1,
693            );
694
695            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
696            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
697            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
698
699            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f2, &[]);
700            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f2);
701            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
702
703            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
704            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
705            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
706
707            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f3, &[]);
708            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f3);
709            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
710
711            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
712            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
713            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
714
715            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f4, &[]);
716            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f4);
717            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
718
719            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
720            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
721            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
722
723            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f5, &[]);
724            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f5);
725            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
726
727            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
728            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
729            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
730
731            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f6, &[]);
732            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f6);
733            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
734
735            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
736            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
737            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
738
739            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f7, &[]);
740            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f7);
741            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
742
743            cpass.set_bind_group(0, &self.bind_group_find_proofs, &[]);
744            cpass.set_pipeline(&self.compute_pipeline_find_proofs);
745            cpass.dispatch_workgroups(NUM_S_BUCKETS as u32 / find_proofs::WORKGROUP_SIZE, 1, 1);
746        }
747
748        encoder.copy_buffer_to_buffer(
749            &self.proofs_gpu,
750            0,
751            &self.proofs_host,
752            0,
753            self.proofs_host.size(),
754        );
755
756        // Map initial state for writes for the next iteration
757        encoder.map_buffer_on_submit(&self.initial_state_host, MapMode::Write, .., {
758            let mapping_error = Arc::clone(&self.mapping_error);
759
760            move |r| {
761                if let Err(error) = r {
762                    mapping_error.lock().replace(error);
763                }
764            }
765        });
766        encoder.map_buffer_on_submit(&self.proofs_host, MapMode::Read, .., {
767            let mapping_error = Arc::clone(&self.mapping_error);
768
769            move |r| {
770                if let Err(error) = r {
771                    mapping_error.lock().replace(error);
772                }
773            }
774        });
775
776        let submission_index = self.queue.submit([encoder.finish()]);
777
778        self.device.poll(PollType::Wait {
779            submission_index: Some(submission_index),
780            timeout: None,
781        })?;
782
783        if let Some(error) = self.mapping_error.lock().take() {
784            return Err(RecordEncodingError::BufferMapping(error));
785        }
786
787        let proofs = {
788            let proofs_host_ptr = self
789                .proofs_host
790                .get_mapped_range(..)
791                .as_ptr()
792                .cast::<ProofsHost>();
793            // SAFETY: Initialized on the GPU
794            unsafe { &*proofs_host_ptr }
795        };
796
797        self.tainted = false;
798
799        Ok(ProofsHostWrapper {
800            proofs,
801            proofs_host: &self.proofs_host,
802        })
803    }
804}
805
806fn bind_group_and_pipeline_compute_f1(
807    device: &wgpu::Device,
808    module: &ShaderModule,
809    initial_state_gpu: &Buffer,
810    bucket_sizes_gpu: &Buffer,
811    buckets_gpu: &Buffer,
812) -> (BindGroup, ComputePipeline) {
813    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
814        label: Some("compute_f1"),
815        entries: &[
816            BindGroupLayoutEntry {
817                binding: 0,
818                count: None,
819                visibility: ShaderStages::COMPUTE,
820                ty: BindingType::Buffer {
821                    has_dynamic_offset: false,
822                    min_binding_size: None,
823                    ty: BufferBindingType::Uniform,
824                },
825            },
826            BindGroupLayoutEntry {
827                binding: 1,
828                count: None,
829                visibility: ShaderStages::COMPUTE,
830                ty: BindingType::Buffer {
831                    has_dynamic_offset: false,
832                    min_binding_size: None,
833                    ty: BufferBindingType::Storage { read_only: false },
834                },
835            },
836            BindGroupLayoutEntry {
837                binding: 2,
838                count: None,
839                visibility: ShaderStages::COMPUTE,
840                ty: BindingType::Buffer {
841                    has_dynamic_offset: false,
842                    min_binding_size: None,
843                    ty: BufferBindingType::Storage { read_only: false },
844                },
845            },
846        ],
847    });
848
849    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
850        label: Some("compute_f1"),
851        bind_group_layouts: &[Some(&bind_group_layout)],
852        immediate_size: 0,
853    });
854
855    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
856        compilation_options: PipelineCompilationOptions {
857            constants: &[],
858            zero_initialize_workgroup_memory: false,
859        },
860        cache: None,
861        label: Some("compute_f1"),
862        layout: Some(&pipeline_layout),
863        module,
864        entry_point: Some("compute_f1"),
865    });
866
867    let bind_group = device.create_bind_group(&BindGroupDescriptor {
868        label: Some("compute_f1"),
869        layout: &bind_group_layout,
870        entries: &[
871            BindGroupEntry {
872                binding: 0,
873                resource: initial_state_gpu.as_entire_binding(),
874            },
875            BindGroupEntry {
876                binding: 1,
877                resource: bucket_sizes_gpu.as_entire_binding(),
878            },
879            BindGroupEntry {
880                binding: 2,
881                resource: buckets_gpu.as_entire_binding(),
882            },
883        ],
884    });
885
886    (bind_group, compute_pipeline)
887}
888
889fn bind_group_and_pipeline_sort_buckets(
890    device: &wgpu::Device,
891    module: &ShaderModule,
892    bucket_sizes_gpu: &Buffer,
893    buckets_gpu: &Buffer,
894) -> (BindGroup, ComputePipeline) {
895    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
896        label: Some("sort_buckets"),
897        entries: &[
898            BindGroupLayoutEntry {
899                binding: 0,
900                count: None,
901                visibility: ShaderStages::COMPUTE,
902                ty: BindingType::Buffer {
903                    has_dynamic_offset: false,
904                    min_binding_size: None,
905                    ty: BufferBindingType::Storage { read_only: false },
906                },
907            },
908            BindGroupLayoutEntry {
909                binding: 1,
910                count: None,
911                visibility: ShaderStages::COMPUTE,
912                ty: BindingType::Buffer {
913                    has_dynamic_offset: false,
914                    min_binding_size: None,
915                    ty: BufferBindingType::Storage { read_only: false },
916                },
917            },
918        ],
919    });
920
921    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
922        label: Some("sort_buckets"),
923        bind_group_layouts: &[Some(&bind_group_layout)],
924        immediate_size: 0,
925    });
926
927    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
928        compilation_options: PipelineCompilationOptions {
929            constants: &[],
930            zero_initialize_workgroup_memory: false,
931        },
932        cache: None,
933        label: Some("sort_buckets"),
934        layout: Some(&pipeline_layout),
935        module,
936        entry_point: Some("sort_buckets"),
937    });
938
939    let bind_group = device.create_bind_group(&BindGroupDescriptor {
940        label: Some("sort_buckets"),
941        layout: &bind_group_layout,
942        entries: &[
943            BindGroupEntry {
944                binding: 0,
945                resource: bucket_sizes_gpu.as_entire_binding(),
946            },
947            BindGroupEntry {
948                binding: 1,
949                resource: buckets_gpu.as_entire_binding(),
950            },
951        ],
952    });
953
954    (bind_group, compute_pipeline)
955}
956
957fn bind_group_and_pipeline_find_matches_and_compute_f2(
958    device: &wgpu::Device,
959    module: &ShaderModule,
960    parent_buckets_gpu: &Buffer,
961    bucket_sizes_gpu: &Buffer,
962    buckets_gpu: &Buffer,
963    positions_gpu: &Buffer,
964    metadatas_gpu: &Buffer,
965) -> (BindGroup, ComputePipeline) {
966    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
967        label: Some("find_matches_and_compute_f2"),
968        entries: &[
969            BindGroupLayoutEntry {
970                binding: 0,
971                count: None,
972                visibility: ShaderStages::COMPUTE,
973                ty: BindingType::Buffer {
974                    has_dynamic_offset: false,
975                    min_binding_size: None,
976                    ty: BufferBindingType::Storage { read_only: true },
977                },
978            },
979            BindGroupLayoutEntry {
980                binding: 1,
981                count: None,
982                visibility: ShaderStages::COMPUTE,
983                ty: BindingType::Buffer {
984                    has_dynamic_offset: false,
985                    min_binding_size: None,
986                    ty: BufferBindingType::Storage { read_only: false },
987                },
988            },
989            BindGroupLayoutEntry {
990                binding: 2,
991                count: None,
992                visibility: ShaderStages::COMPUTE,
993                ty: BindingType::Buffer {
994                    has_dynamic_offset: false,
995                    min_binding_size: None,
996                    ty: BufferBindingType::Storage { read_only: false },
997                },
998            },
999            BindGroupLayoutEntry {
1000                binding: 3,
1001                count: None,
1002                visibility: ShaderStages::COMPUTE,
1003                ty: BindingType::Buffer {
1004                    has_dynamic_offset: false,
1005                    min_binding_size: None,
1006                    ty: BufferBindingType::Storage { read_only: false },
1007                },
1008            },
1009            BindGroupLayoutEntry {
1010                binding: 4,
1011                count: None,
1012                visibility: ShaderStages::COMPUTE,
1013                ty: BindingType::Buffer {
1014                    has_dynamic_offset: false,
1015                    min_binding_size: None,
1016                    ty: BufferBindingType::Storage { read_only: false },
1017                },
1018            },
1019        ],
1020    });
1021
1022    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1023        label: Some("find_matches_and_compute_f2"),
1024        bind_group_layouts: &[Some(&bind_group_layout)],
1025        immediate_size: 0,
1026    });
1027
1028    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1029        compilation_options: PipelineCompilationOptions {
1030            constants: &[],
1031            zero_initialize_workgroup_memory: true,
1032        },
1033        cache: None,
1034        label: Some("find_matches_and_compute_f2"),
1035        layout: Some(&pipeline_layout),
1036        module,
1037        entry_point: Some("find_matches_and_compute_f2"),
1038    });
1039
1040    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1041        label: Some("find_matches_and_compute_f2"),
1042        layout: &bind_group_layout,
1043        entries: &[
1044            BindGroupEntry {
1045                binding: 0,
1046                resource: parent_buckets_gpu.as_entire_binding(),
1047            },
1048            BindGroupEntry {
1049                binding: 1,
1050                resource: bucket_sizes_gpu.as_entire_binding(),
1051            },
1052            BindGroupEntry {
1053                binding: 2,
1054                resource: buckets_gpu.as_entire_binding(),
1055            },
1056            BindGroupEntry {
1057                binding: 3,
1058                resource: positions_gpu.as_entire_binding(),
1059            },
1060            BindGroupEntry {
1061                binding: 4,
1062                resource: metadatas_gpu.as_entire_binding(),
1063            },
1064        ],
1065    });
1066
1067    (bind_group, compute_pipeline)
1068}
1069
1070#[expect(
1071    clippy::too_many_arguments,
1072    reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1073)]
1074fn bind_group_and_pipeline_find_matches_and_compute_fn<const TABLE_NUMBER: u8>(
1075    device: &wgpu::Device,
1076    module: &ShaderModule,
1077    parent_buckets_gpu: &Buffer,
1078    parent_metadatas_gpu: &Buffer,
1079    bucket_sizes_gpu: &Buffer,
1080    buckets_gpu: &Buffer,
1081    positions_gpu: &Buffer,
1082    metadatas_gpu: &Buffer,
1083) -> (BindGroup, ComputePipeline) {
1084    let label = format!("find_matches_and_compute_f{TABLE_NUMBER}");
1085    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1086        label: Some(&label),
1087        entries: &[
1088            BindGroupLayoutEntry {
1089                binding: 0,
1090                count: None,
1091                visibility: ShaderStages::COMPUTE,
1092                ty: BindingType::Buffer {
1093                    has_dynamic_offset: false,
1094                    min_binding_size: None,
1095                    ty: BufferBindingType::Storage { read_only: true },
1096                },
1097            },
1098            BindGroupLayoutEntry {
1099                binding: 1,
1100                count: None,
1101                visibility: ShaderStages::COMPUTE,
1102                ty: BindingType::Buffer {
1103                    has_dynamic_offset: false,
1104                    min_binding_size: None,
1105                    ty: BufferBindingType::Storage { read_only: true },
1106                },
1107            },
1108            BindGroupLayoutEntry {
1109                binding: 2,
1110                count: None,
1111                visibility: ShaderStages::COMPUTE,
1112                ty: BindingType::Buffer {
1113                    has_dynamic_offset: false,
1114                    min_binding_size: None,
1115                    ty: BufferBindingType::Storage { read_only: false },
1116                },
1117            },
1118            BindGroupLayoutEntry {
1119                binding: 3,
1120                count: None,
1121                visibility: ShaderStages::COMPUTE,
1122                ty: BindingType::Buffer {
1123                    has_dynamic_offset: false,
1124                    min_binding_size: None,
1125                    ty: BufferBindingType::Storage { read_only: false },
1126                },
1127            },
1128            BindGroupLayoutEntry {
1129                binding: 4,
1130                count: None,
1131                visibility: ShaderStages::COMPUTE,
1132                ty: BindingType::Buffer {
1133                    has_dynamic_offset: false,
1134                    min_binding_size: None,
1135                    ty: BufferBindingType::Storage { read_only: false },
1136                },
1137            },
1138            BindGroupLayoutEntry {
1139                binding: 5,
1140                count: None,
1141                visibility: ShaderStages::COMPUTE,
1142                ty: BindingType::Buffer {
1143                    has_dynamic_offset: false,
1144                    min_binding_size: None,
1145                    ty: BufferBindingType::Storage { read_only: false },
1146                },
1147            },
1148        ],
1149    });
1150
1151    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1152        label: Some(&label),
1153        bind_group_layouts: &[Some(&bind_group_layout)],
1154        immediate_size: 0,
1155    });
1156
1157    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1158        compilation_options: PipelineCompilationOptions {
1159            constants: &[],
1160            zero_initialize_workgroup_memory: true,
1161        },
1162        cache: None,
1163        label: Some(&label),
1164        layout: Some(&pipeline_layout),
1165        module,
1166        entry_point: Some(&format!("find_matches_and_compute_f{TABLE_NUMBER}")),
1167    });
1168
1169    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1170        label: Some(&label),
1171        layout: &bind_group_layout,
1172        entries: &[
1173            BindGroupEntry {
1174                binding: 0,
1175                resource: parent_buckets_gpu.as_entire_binding(),
1176            },
1177            BindGroupEntry {
1178                binding: 1,
1179                resource: parent_metadatas_gpu.as_entire_binding(),
1180            },
1181            BindGroupEntry {
1182                binding: 2,
1183                resource: bucket_sizes_gpu.as_entire_binding(),
1184            },
1185            BindGroupEntry {
1186                binding: 3,
1187                resource: buckets_gpu.as_entire_binding(),
1188            },
1189            BindGroupEntry {
1190                binding: 4,
1191                resource: positions_gpu.as_entire_binding(),
1192            },
1193            BindGroupEntry {
1194                binding: 5,
1195                resource: metadatas_gpu.as_entire_binding(),
1196            },
1197        ],
1198    });
1199
1200    (bind_group, compute_pipeline)
1201}
1202
1203fn bind_group_and_pipeline_find_matches_and_compute_f7(
1204    device: &wgpu::Device,
1205    module: &ShaderModule,
1206    parent_buckets_gpu: &Buffer,
1207    parent_metadatas_gpu: &Buffer,
1208    table_6_proof_targets_sizes_gpu: &Buffer,
1209    table_6_proof_targets_gpu: &Buffer,
1210) -> (BindGroup, ComputePipeline) {
1211    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1212        label: Some("find_matches_and_compute_f7"),
1213        entries: &[
1214            BindGroupLayoutEntry {
1215                binding: 0,
1216                count: None,
1217                visibility: ShaderStages::COMPUTE,
1218                ty: BindingType::Buffer {
1219                    has_dynamic_offset: false,
1220                    min_binding_size: None,
1221                    ty: BufferBindingType::Storage { read_only: true },
1222                },
1223            },
1224            BindGroupLayoutEntry {
1225                binding: 1,
1226                count: None,
1227                visibility: ShaderStages::COMPUTE,
1228                ty: BindingType::Buffer {
1229                    has_dynamic_offset: false,
1230                    min_binding_size: None,
1231                    ty: BufferBindingType::Storage { read_only: true },
1232                },
1233            },
1234            BindGroupLayoutEntry {
1235                binding: 2,
1236                count: None,
1237                visibility: ShaderStages::COMPUTE,
1238                ty: BindingType::Buffer {
1239                    has_dynamic_offset: false,
1240                    min_binding_size: None,
1241                    ty: BufferBindingType::Storage { read_only: false },
1242                },
1243            },
1244            BindGroupLayoutEntry {
1245                binding: 3,
1246                count: None,
1247                visibility: ShaderStages::COMPUTE,
1248                ty: BindingType::Buffer {
1249                    has_dynamic_offset: false,
1250                    min_binding_size: None,
1251                    ty: BufferBindingType::Storage { read_only: false },
1252                },
1253            },
1254        ],
1255    });
1256
1257    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1258        label: Some("find_matches_and_compute_f7"),
1259        bind_group_layouts: &[Some(&bind_group_layout)],
1260        immediate_size: 0,
1261    });
1262
1263    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1264        compilation_options: PipelineCompilationOptions {
1265            constants: &[],
1266            zero_initialize_workgroup_memory: true,
1267        },
1268        cache: None,
1269        label: Some("find_matches_and_compute_f7"),
1270        layout: Some(&pipeline_layout),
1271        module,
1272        entry_point: Some("find_matches_and_compute_f7"),
1273    });
1274
1275    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1276        label: Some("find_matches_and_compute_f7"),
1277        layout: &bind_group_layout,
1278        entries: &[
1279            BindGroupEntry {
1280                binding: 0,
1281                resource: parent_buckets_gpu.as_entire_binding(),
1282            },
1283            BindGroupEntry {
1284                binding: 1,
1285                resource: parent_metadatas_gpu.as_entire_binding(),
1286            },
1287            BindGroupEntry {
1288                binding: 2,
1289                resource: table_6_proof_targets_sizes_gpu.as_entire_binding(),
1290            },
1291            BindGroupEntry {
1292                binding: 3,
1293                resource: table_6_proof_targets_gpu.as_entire_binding(),
1294            },
1295        ],
1296    });
1297
1298    (bind_group, compute_pipeline)
1299}
1300
1301#[expect(
1302    clippy::too_many_arguments,
1303    reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1304)]
1305fn bind_group_and_pipeline_find_proofs(
1306    device: &wgpu::Device,
1307    module: &ShaderModule,
1308    table_2_positions_gpu: &Buffer,
1309    table_3_positions_gpu: &Buffer,
1310    table_4_positions_gpu: &Buffer,
1311    table_5_positions_gpu: &Buffer,
1312    table_6_positions_gpu: &Buffer,
1313    bucket_sizes_gpu: &Buffer,
1314    buckets_gpu: &Buffer,
1315    proofs_gpu: &Buffer,
1316) -> (BindGroup, ComputePipeline) {
1317    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1318        label: Some("find_proofs"),
1319        entries: &[
1320            BindGroupLayoutEntry {
1321                binding: 0,
1322                count: None,
1323                visibility: ShaderStages::COMPUTE,
1324                ty: BindingType::Buffer {
1325                    has_dynamic_offset: false,
1326                    min_binding_size: None,
1327                    ty: BufferBindingType::Storage { read_only: true },
1328                },
1329            },
1330            BindGroupLayoutEntry {
1331                binding: 1,
1332                count: None,
1333                visibility: ShaderStages::COMPUTE,
1334                ty: BindingType::Buffer {
1335                    has_dynamic_offset: false,
1336                    min_binding_size: None,
1337                    ty: BufferBindingType::Storage { read_only: true },
1338                },
1339            },
1340            BindGroupLayoutEntry {
1341                binding: 2,
1342                count: None,
1343                visibility: ShaderStages::COMPUTE,
1344                ty: BindingType::Buffer {
1345                    has_dynamic_offset: false,
1346                    min_binding_size: None,
1347                    ty: BufferBindingType::Storage { read_only: true },
1348                },
1349            },
1350            BindGroupLayoutEntry {
1351                binding: 3,
1352                count: None,
1353                visibility: ShaderStages::COMPUTE,
1354                ty: BindingType::Buffer {
1355                    has_dynamic_offset: false,
1356                    min_binding_size: None,
1357                    ty: BufferBindingType::Storage { read_only: true },
1358                },
1359            },
1360            BindGroupLayoutEntry {
1361                binding: 4,
1362                count: None,
1363                visibility: ShaderStages::COMPUTE,
1364                ty: BindingType::Buffer {
1365                    has_dynamic_offset: false,
1366                    min_binding_size: None,
1367                    ty: BufferBindingType::Storage { read_only: true },
1368                },
1369            },
1370            BindGroupLayoutEntry {
1371                binding: 5,
1372                count: None,
1373                visibility: ShaderStages::COMPUTE,
1374                ty: BindingType::Buffer {
1375                    has_dynamic_offset: false,
1376                    min_binding_size: None,
1377                    ty: BufferBindingType::Storage { read_only: false },
1378                },
1379            },
1380            BindGroupLayoutEntry {
1381                binding: 6,
1382                count: None,
1383                visibility: ShaderStages::COMPUTE,
1384                ty: BindingType::Buffer {
1385                    has_dynamic_offset: false,
1386                    min_binding_size: None,
1387                    ty: BufferBindingType::Storage { read_only: true },
1388                },
1389            },
1390            BindGroupLayoutEntry {
1391                binding: 7,
1392                count: None,
1393                visibility: ShaderStages::COMPUTE,
1394                ty: BindingType::Buffer {
1395                    has_dynamic_offset: false,
1396                    min_binding_size: None,
1397                    ty: BufferBindingType::Storage { read_only: false },
1398                },
1399            },
1400        ],
1401    });
1402
1403    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1404        label: Some("find_proofs"),
1405        bind_group_layouts: &[Some(&bind_group_layout)],
1406        immediate_size: 0,
1407    });
1408
1409    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1410        compilation_options: PipelineCompilationOptions {
1411            constants: &[],
1412            zero_initialize_workgroup_memory: false,
1413        },
1414        cache: None,
1415        label: Some("find_proofs"),
1416        layout: Some(&pipeline_layout),
1417        module,
1418        entry_point: Some("find_proofs"),
1419    });
1420
1421    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1422        label: Some("find_proofs"),
1423        layout: &bind_group_layout,
1424        entries: &[
1425            BindGroupEntry {
1426                binding: 0,
1427                resource: table_2_positions_gpu.as_entire_binding(),
1428            },
1429            BindGroupEntry {
1430                binding: 1,
1431                resource: table_3_positions_gpu.as_entire_binding(),
1432            },
1433            BindGroupEntry {
1434                binding: 2,
1435                resource: table_4_positions_gpu.as_entire_binding(),
1436            },
1437            BindGroupEntry {
1438                binding: 3,
1439                resource: table_5_positions_gpu.as_entire_binding(),
1440            },
1441            BindGroupEntry {
1442                binding: 4,
1443                resource: table_6_positions_gpu.as_entire_binding(),
1444            },
1445            BindGroupEntry {
1446                binding: 5,
1447                resource: bucket_sizes_gpu.as_entire_binding(),
1448            },
1449            BindGroupEntry {
1450                binding: 6,
1451                resource: buckets_gpu.as_entire_binding(),
1452            },
1453            BindGroupEntry {
1454                binding: 7,
1455                resource: proofs_gpu.as_entire_binding(),
1456            },
1457        ],
1458    });
1459
1460    (bind_group, compute_pipeline)
1461}