Skip to main content

ab_proof_of_space_gpu/
host.rs

1#[cfg(all(test, not(miri)))]
2mod tests;
3
4use crate::shader::constants::{
5    MAX_BUCKET_SIZE, MAX_TABLE_SIZE, NUM_BUCKETS, NUM_MATCH_BUCKETS, NUM_S_BUCKETS,
6    REDUCED_MATCHES_COUNT,
7};
8use crate::shader::find_matches_and_compute_f7::{NUM_ELEMENTS_PER_S_BUCKET, ProofTargets};
9use crate::shader::find_proofs::ProofsHost;
10use crate::shader::types::{Metadata, Position, PositionR};
11use crate::shader::{compute_f1, find_proofs, select_shader_features_limits};
12use ab_chacha8::{ChaCha8Block, ChaCha8State, block_to_bytes};
13use ab_core_primitives::pieces::{PieceOffset, Record};
14use ab_core_primitives::pos::PosSeed;
15use ab_core_primitives::sectors::SectorId;
16use ab_erasure_coding::ErasureCoding;
17use ab_farmer_components::plotting::RecordsEncoder;
18use ab_farmer_components::sector::SectorContentsMap;
19use async_lock::Mutex as AsyncMutex;
20use futures::stream::FuturesOrdered;
21use futures::{StreamExt, TryStreamExt};
22use parking_lot::Mutex;
23use rayon::prelude::*;
24use rayon::{ThreadPool, ThreadPoolBuilder};
25use rclite::Arc;
26use std::num::NonZeroU8;
27use std::simd::Simd;
28use std::sync::Arc as StdArc;
29use std::sync::atomic::{AtomicBool, Ordering};
30use std::{fmt, iter};
31use tracing::{debug, warn};
32use wgpu::{
33    AdapterInfo, Backend, BackendOptions, Backends, BindGroup, BindGroupDescriptor, BindGroupEntry,
34    BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, Buffer, BufferAddress,
35    BufferAsyncError, BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor,
36    ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor, DeviceDescriptor,
37    DeviceType, Instance, InstanceDescriptor, InstanceFlags, MapMode, MemoryBudgetThresholds,
38    PipelineCompilationOptions, PipelineLayoutDescriptor, PollError, PollType, Queue,
39    RequestDeviceError, ShaderModule, ShaderRuntimeChecks, ShaderStages,
40};
41
42/// Proof creation error
43#[derive(Debug, thiserror::Error)]
44enum RecordEncodingError {
45    /// Too many records
46    #[error("Too many records: {0}")]
47    TooManyRecords(usize),
48    /// Proof creation failed previously and the device is now considered broken
49    #[error("Proof creation failed previously and the device is now considered broken")]
50    DeviceBroken,
51    /// Failed to map buffer
52    #[error("Failed to map buffer: {0}")]
53    BufferMapping(#[from] BufferAsyncError),
54    /// Poll error
55    #[error("Poll error: {0}")]
56    DevicePoll(#[from] PollError),
57}
58
59struct ProofsHostWrapper<'a> {
60    proofs: &'a ProofsHost,
61    proofs_host: &'a Buffer,
62}
63
64impl Drop for ProofsHostWrapper<'_> {
65    fn drop(&mut self) {
66        self.proofs_host.unmap();
67    }
68}
69
70/// Wrapper data structure encapsulating a single compatible device
71#[derive(Clone)]
72pub struct Device {
73    id: u32,
74    devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
75    adapter_info: AdapterInfo,
76}
77
78impl fmt::Debug for Device {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        f.debug_struct("Device")
81            .field("id", &self.id)
82            .field("name", &self.adapter_info.name)
83            .field("device_type", &self.adapter_info.device_type)
84            .field("driver", &self.adapter_info.driver)
85            .field("driver_info", &self.adapter_info.driver_info)
86            .field("backend", &self.adapter_info.backend)
87            .finish_non_exhaustive()
88    }
89}
90
91impl Device {
92    /// Returns [`Device`] for each available device
93    pub async fn enumerate<NOQ>(number_of_queues: NOQ) -> Vec<Self>
94    where
95        NOQ: Fn(DeviceType) -> NonZeroU8,
96    {
97        let backends = Backends::from_env().unwrap_or(Backends::METAL | Backends::VULKAN);
98        let instance = Instance::new(InstanceDescriptor {
99            backends,
100            flags: if cfg!(debug_assertions) {
101                InstanceFlags::debugging().with_env()
102            } else {
103                InstanceFlags::from_env_or_default()
104            },
105            memory_budget_thresholds: MemoryBudgetThresholds::default(),
106            backend_options: BackendOptions::from_env_or_default(),
107            display: None,
108        });
109
110        let adapters = instance.enumerate_adapters(backends).await;
111        let number_of_queues = &number_of_queues;
112
113        adapters
114            .into_iter()
115            .zip(0..)
116            .map(|(adapter, id)| async move {
117                let adapter_info = adapter.get_info();
118
119                let (shader, required_features, required_limits) =
120                    if let Some((shader, required_features, required_limits)) =
121                        select_shader_features_limits(&adapter)
122                    {
123                        debug!(
124                            %id,
125                            adapter_info = ?adapter_info,
126                            "Compatible adapter found"
127                        );
128
129                        (shader, required_features, required_limits)
130                    } else {
131                        debug!(
132                            %id,
133                            adapter_info = ?adapter_info,
134                            "Incompatible adapter found"
135                        );
136
137                        return None;
138                    };
139
140                // TODO: creation of multiple devices is a workaround for lack of support for
141                //  multiple queues: https://github.com/gfx-rs/wgpu/issues/1066
142                let devices = iter::repeat_with(|| async {
143                    let (device, queue) = adapter
144                        .request_device(&DeviceDescriptor {
145                            label: None,
146                            required_features,
147                            required_limits: required_limits.clone(),
148                            ..DeviceDescriptor::default()
149                        })
150                        .await
151                        .inspect_err(|error| {
152                            warn!(%id, ?adapter_info, %error, "Failed to request the device");
153                        })?;
154                    let module = if cfg!(debug_assertions) {
155                        device.create_shader_module(shader.clone())
156                    } else {
157                        // SAFETY: The shader is trusted, individual tests include correctness
158                        // checks and debug builds of this crate too only release version skips
159                        // checks for better runtime performance
160                        unsafe {
161                            device.create_shader_module_trusted(
162                                shader.clone(),
163                                ShaderRuntimeChecks::unchecked(),
164                            )
165                        }
166                    };
167
168                    Ok::<_, RequestDeviceError>((device, queue, module))
169                })
170                .take(usize::from(
171                    number_of_queues(adapter_info.device_type).get(),
172                ))
173                .collect::<FuturesOrdered<_>>()
174                .try_collect::<Vec<_>>()
175                .await
176                .ok()?;
177
178                Some(Self {
179                    id,
180                    devices,
181                    adapter_info,
182                })
183            })
184            .collect::<FuturesOrdered<_>>()
185            .filter_map(|device| async move { device })
186            .collect()
187            .await
188    }
189
190    /// Gpu ID
191    pub fn id(&self) -> u32 {
192        self.id
193    }
194
195    /// Device name
196    pub fn name(&self) -> &str {
197        &self.adapter_info.name
198    }
199
200    /// Device type
201    pub fn device_type(&self) -> DeviceType {
202        self.adapter_info.device_type
203    }
204
205    /// Driver
206    pub fn driver(&self) -> &str {
207        &self.adapter_info.driver
208    }
209
210    /// Driver info
211    pub fn driver_info(&self) -> &str {
212        &self.adapter_info.driver_info
213    }
214
215    /// Backend
216    pub fn backend(&self) -> Backend {
217        self.adapter_info.backend
218    }
219
220    pub fn instantiate(
221        &self,
222        erasure_coding: ErasureCoding,
223        global_mutex: StdArc<AsyncMutex<()>>,
224    ) -> anyhow::Result<GpuRecordsEncoder> {
225        GpuRecordsEncoder::new(
226            self.id,
227            self.devices.clone(),
228            self.adapter_info.clone(),
229            erasure_coding,
230            global_mutex,
231        )
232    }
233}
234
235pub struct GpuRecordsEncoder {
236    id: u32,
237    instances: Vec<Mutex<GpuRecordsEncoderInstance>>,
238    thread_pool: ThreadPool,
239    adapter_info: AdapterInfo,
240    erasure_coding: ErasureCoding,
241    global_mutex: StdArc<AsyncMutex<()>>,
242}
243
244impl fmt::Debug for GpuRecordsEncoder {
245    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246        f.debug_struct("GpuRecordsEncoder")
247            .field("id", &self.id)
248            .field("name", &self.adapter_info.name)
249            .field("device_type", &self.adapter_info.device_type)
250            .field("driver", &self.adapter_info.driver)
251            .field("driver_info", &self.adapter_info.driver_info)
252            .field("backend", &self.adapter_info.backend)
253            .finish_non_exhaustive()
254    }
255}
256
257impl RecordsEncoder for GpuRecordsEncoder {
258    // TODO: Run more than one encoding per device concurrently
259    fn encode_records(
260        &mut self,
261        sector_id: &SectorId,
262        records: &mut [Record],
263        abort_early: &AtomicBool,
264    ) -> anyhow::Result<SectorContentsMap> {
265        let mut sector_contents_map = SectorContentsMap::new(
266            u16::try_from(records.len())
267                .map_err(|_error| RecordEncodingError::TooManyRecords(records.len()))?,
268        );
269
270        let maybe_error = self.thread_pool.install(|| {
271            records
272                .par_iter_mut()
273                .zip(sector_contents_map.iter_record_chunks_used_mut())
274                .enumerate()
275                .find_map_any(|(piece_offset, (record, record_chunks_used))| {
276                    // Take mutex briefly to make sure encoding is allowed right now
277                    self.global_mutex.lock_blocking();
278
279                    let mut parity_record_chunks = Record::new_boxed();
280
281                    // TODO: Do erasure coding on the GPU
282                    // Erasure code source record chunks
283                    self.erasure_coding
284                        .extend(record.iter(), parity_record_chunks.iter_mut())
285                        .expect("Statically guaranteed valid inputs; qed");
286
287                    if abort_early.load(Ordering::Relaxed) {
288                        return None;
289                    }
290                    let seed =
291                        sector_id.derive_evaluation_seed(PieceOffset::from(piece_offset as u16));
292                    let thread_index = rayon::current_thread_index().unwrap_or_default();
293                    let mut encoder_instance = self.instances[thread_index]
294                        .try_lock()
295                        .expect("1:1 mapping between threads and devices; qed");
296                    let proofs = match encoder_instance.create_proofs(&seed) {
297                        Ok(proofs) => proofs,
298                        Err(error) => {
299                            return Some(error);
300                        }
301                    };
302                    let proofs = proofs.proofs;
303
304                    *record_chunks_used = proofs.found_proofs;
305
306                    // TODO: Record encoding on the GPU
307                    let mut num_found_proofs = 0_usize;
308                    for (s_buckets, found_proofs) in (0..Record::NUM_S_BUCKETS)
309                        .array_chunks::<{ u8::BITS as usize }>()
310                        .zip(record_chunks_used)
311                    {
312                        for (proof_offset, s_bucket) in s_buckets.into_iter().enumerate() {
313                            if num_found_proofs == Record::NUM_CHUNKS {
314                                // Enough proofs collected, clear the rest of the bits
315                                *found_proofs &=
316                                    u8::MAX.unbounded_shr(u8::BITS - proof_offset as u32);
317                                break;
318                            }
319                            if (*found_proofs & (1 << proof_offset)) != 0 {
320                                let record_chunk = if s_bucket < Record::NUM_CHUNKS {
321                                    record[s_bucket]
322                                } else {
323                                    parity_record_chunks[s_bucket - Record::NUM_CHUNKS]
324                                };
325
326                                record[num_found_proofs] = (Simd::from(record_chunk)
327                                    ^ Simd::from(*proofs.proofs[s_bucket].hash()))
328                                .to_array();
329                                num_found_proofs += 1;
330                            }
331                        }
332                    }
333
334                    None
335                })
336        });
337
338        if let Some(error) = maybe_error {
339            return Err(error.into());
340        }
341
342        Ok(sector_contents_map)
343    }
344}
345
346impl GpuRecordsEncoder {
347    fn new(
348        id: u32,
349        devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
350        adapter_info: AdapterInfo,
351        erasure_coding: ErasureCoding,
352        global_mutex: StdArc<AsyncMutex<()>>,
353    ) -> anyhow::Result<Self> {
354        let thread_pool = ThreadPoolBuilder::new()
355            .thread_name(move |thread_index| format!("pos-gpu-{id}.{thread_index}"))
356            .num_threads(devices.len())
357            .build()?;
358
359        Ok(Self {
360            id,
361            instances: devices
362                .into_iter()
363                .map(|(device, queue, module)| {
364                    Mutex::new(GpuRecordsEncoderInstance::new(device, queue, module))
365                })
366                .collect(),
367            thread_pool,
368            adapter_info,
369            erasure_coding,
370            global_mutex,
371        })
372    }
373}
374
/// GPU state for creating proofs on one logical device: the buffers, bind groups and compute
/// pipelines of every pipeline stage.
struct GpuRecordsEncoderInstance {
    /// Logical device that owns all resources below
    device: wgpu::Device,
    /// Queue that command buffers are submitted to
    queue: Queue,
    /// Error recorded by buffer mapping callbacks, checked after each submission
    mapping_error: Arc<Mutex<Option<BufferAsyncError>>>,
    /// Set while proof creation is in progress. Remains `true` after a failed attempt, after which
    /// the instance refuses further work
    tainted: bool,
    /// Host-writable staging buffer for the initial ChaCha8 state
    initial_state_host: Buffer,
    /// Uniform buffer with the initial ChaCha8 state, bound to the `compute_f1` stage
    initial_state_gpu: Buffer,
    /// Host-readable buffer that proofs are copied into
    proofs_host: Buffer,
    /// Storage buffer bound to the `find_proofs` stage, copied into `proofs_host` afterwards
    proofs_gpu: Buffer,
    // Bind group and compute pipeline pair for each stage
    bind_group_compute_f1: BindGroup,
    compute_pipeline_compute_f1: ComputePipeline,
    bind_group_sort_buckets_a: BindGroup,
    compute_pipeline_sort_buckets_a: ComputePipeline,
    bind_group_sort_buckets_b: BindGroup,
    compute_pipeline_sort_buckets_b: ComputePipeline,
    bind_group_find_matches_and_compute_f2: BindGroup,
    compute_pipeline_find_matches_and_compute_f2: ComputePipeline,
    bind_group_find_matches_and_compute_f3: BindGroup,
    compute_pipeline_find_matches_and_compute_f3: ComputePipeline,
    bind_group_find_matches_and_compute_f4: BindGroup,
    compute_pipeline_find_matches_and_compute_f4: ComputePipeline,
    bind_group_find_matches_and_compute_f5: BindGroup,
    compute_pipeline_find_matches_and_compute_f5: ComputePipeline,
    bind_group_find_matches_and_compute_f6: BindGroup,
    compute_pipeline_find_matches_and_compute_f6: ComputePipeline,
    bind_group_find_matches_and_compute_f7: BindGroup,
    compute_pipeline_find_matches_and_compute_f7: ComputePipeline,
    bind_group_find_proofs: BindGroup,
    compute_pipeline_find_proofs: ComputePipeline,
}
405
impl GpuRecordsEncoderInstance {
    /// Allocates every GPU buffer used by the proof creation pipeline. Then builds the bind group
    /// and compute pipeline for each stage from `module`.
    ///
    /// `initial_state_host` is created mapped, so the first [`Self::create_proofs`] call can write
    /// into it directly.
    fn new(device: wgpu::Device, queue: Queue, module: ShaderModule) -> Self {
        // Staging buffer for the initial ChaCha8 state, copied into `initial_state_gpu` at the
        // start of every `create_proofs` call
        let initial_state_host = device.create_buffer(&BufferDescriptor {
            label: Some("initial_state_host"),
            size: size_of::<ChaCha8Block>() as BufferAddress,
            usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
            mapped_at_creation: true,
        });

        let initial_state_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("initial_state_gpu"),
            size: initial_state_host.size(),
            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let bucket_sizes_gpu_buffer_size = size_of::<[u32; NUM_BUCKETS]>() as BufferAddress;
        let table_6_proof_targets_sizes_gpu_buffer_size =
            size_of::<[u32; NUM_S_BUCKETS]>() as BufferAddress;
        // TODO: Sizes are excessive, for `bucket_sizes_gpu` are less than `u16` and could use SWAR
        //  approach for storing bucket sizes. Similarly, `table_6_proof_targets_sizes_gpu` sizes
        //  are less than `u8` and could use SWAR too with even higher compression ratio
        let bucket_sizes_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("bucket_sizes_gpu"),
            // Large enough for either of its two uses (see the aliasing below)
            size: bucket_sizes_gpu_buffer_size.max(table_6_proof_targets_sizes_gpu_buffer_size),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });
        // Reuse the same buffer as `bucket_sizes_gpu`, they are not overlapping in use
        let table_6_proof_targets_sizes_gpu = bucket_sizes_gpu.clone();

        // `buckets_a_gpu` and `buckets_b_gpu` are used alternately by consecutive tables
        let buckets_a_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("buckets_a_gpu"),
            size: size_of::<[[PositionR; MAX_BUCKET_SIZE]; NUM_BUCKETS]>() as BufferAddress,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let buckets_b_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("buckets_b_gpu"),
            size: buckets_a_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        // Positions of tables 2..=6 get separate buffers of equal size. All of them are bound to
        // the `find_proofs` stage below
        let positions_f2_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f2_gpu"),
            size: size_of::<[[[Position; 2]; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>()
                as BufferAddress,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f3_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f3_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f4_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f4_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f5_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f5_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f6_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f6_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let metadatas_gpu_buffer_size =
            size_of::<[[Metadata; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>() as BufferAddress;
        let table_6_proof_targets_gpu_buffer_size = size_of::<
            [[ProofTargets; NUM_ELEMENTS_PER_S_BUCKET]; NUM_S_BUCKETS],
        >() as BufferAddress;
        let metadatas_a_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("metadatas_a_gpu"),
            // Large enough for either of its two uses (see the aliasing below)
            size: metadatas_gpu_buffer_size.max(table_6_proof_targets_gpu_buffer_size),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });
        // Reuse the same buffer as `metadatas_a_gpu`, they are not overlapping in use
        let table_6_proof_targets_gpu = metadatas_a_gpu.clone();

        let metadatas_b_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("metadatas_b_gpu"),
            size: metadatas_gpu_buffer_size,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        // Host-readable destination for the proofs, mapped for reading after each submission
        let proofs_host = device.create_buffer(&BufferDescriptor {
            label: Some("proofs_host"),
            size: size_of::<ProofsHost>() as BufferAddress,
            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let proofs_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("proofs_gpu"),
            size: proofs_host.size(),
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        // Bind groups and pipelines for every stage. Match finding for tables 2..=7 alternates
        // between the `a` and `b` bucket/metadata buffers
        let (bind_group_compute_f1, compute_pipeline_compute_f1) =
            bind_group_and_pipeline_compute_f1(
                &device,
                &module,
                &initial_state_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
            );
        let (bind_group_sort_buckets_a, compute_pipeline_sort_buckets_a) =
            bind_group_and_pipeline_sort_buckets(
                &device,
                &module,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
            );

        let (bind_group_sort_buckets_b, compute_pipeline_sort_buckets_b) =
            bind_group_and_pipeline_sort_buckets(
                &device,
                &module,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
            );

        let (bind_group_find_matches_and_compute_f2, compute_pipeline_find_matches_and_compute_f2) =
            bind_group_and_pipeline_find_matches_and_compute_f2(
                &device,
                &module,
                &buckets_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f2_gpu,
                &metadatas_b_gpu,
            );

        let (bind_group_find_matches_and_compute_f3, compute_pipeline_find_matches_and_compute_f3) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<3>(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
                &positions_f3_gpu,
                &metadatas_a_gpu,
            );

        let (bind_group_find_matches_and_compute_f4, compute_pipeline_find_matches_and_compute_f4) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<4>(
                &device,
                &module,
                &buckets_a_gpu,
                &metadatas_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f4_gpu,
                &metadatas_b_gpu,
            );

        let (bind_group_find_matches_and_compute_f5, compute_pipeline_find_matches_and_compute_f5) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<5>(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
                &positions_f5_gpu,
                &metadatas_a_gpu,
            );

        let (bind_group_find_matches_and_compute_f6, compute_pipeline_find_matches_and_compute_f6) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<6>(
                &device,
                &module,
                &buckets_a_gpu,
                &metadatas_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f6_gpu,
                &metadatas_b_gpu,
            );

        // Table 7 is bound to `table_6_proof_targets*_gpu`, which alias `bucket_sizes_gpu` and
        // `metadatas_a_gpu` respectively
        let (bind_group_find_matches_and_compute_f7, compute_pipeline_find_matches_and_compute_f7) =
            bind_group_and_pipeline_find_matches_and_compute_f7(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &table_6_proof_targets_sizes_gpu,
                &table_6_proof_targets_gpu,
            );

        let (bind_group_find_proofs, compute_pipeline_find_proofs) =
            bind_group_and_pipeline_find_proofs(
                &device,
                &module,
                &positions_f2_gpu,
                &positions_f3_gpu,
                &positions_f4_gpu,
                &positions_f5_gpu,
                &positions_f6_gpu,
                &table_6_proof_targets_sizes_gpu,
                &table_6_proof_targets_gpu,
                &proofs_gpu,
            );

        Self {
            device,
            queue,
            mapping_error: Arc::new(Mutex::new(None)),
            tainted: false,
            initial_state_host,
            initial_state_gpu,
            proofs_host,
            proofs_gpu,
            bind_group_compute_f1,
            compute_pipeline_compute_f1,
            bind_group_sort_buckets_a,
            compute_pipeline_sort_buckets_a,
            bind_group_sort_buckets_b,
            compute_pipeline_sort_buckets_b,
            bind_group_find_matches_and_compute_f2,
            compute_pipeline_find_matches_and_compute_f2,
            bind_group_find_matches_and_compute_f3,
            compute_pipeline_find_matches_and_compute_f3,
            bind_group_find_matches_and_compute_f4,
            compute_pipeline_find_matches_and_compute_f4,
            bind_group_find_matches_and_compute_f5,
            compute_pipeline_find_matches_and_compute_f5,
            bind_group_find_matches_and_compute_f6,
            compute_pipeline_find_matches_and_compute_f6,
            bind_group_find_matches_and_compute_f7,
            compute_pipeline_find_matches_and_compute_f7,
            bind_group_find_proofs,
            compute_pipeline_find_proofs,
        }
    }

    /// Creates proofs of space for `seed` by running the whole pipeline on the GPU, blocking until
    /// it completes.
    ///
    /// The returned [`ProofsHostWrapper`] borrows the mapped `proofs_host` buffer and unmaps it
    /// when dropped.
    ///
    /// # Errors
    /// - [`RecordEncodingError::DeviceBroken`] if a previous call failed. A failed call leaves
    ///   `tainted` set, so every later call fails this way
    /// - [`RecordEncodingError::DevicePoll`] if waiting for the submission fails
    /// - [`RecordEncodingError::BufferMapping`] if mapping either host buffer fails
    fn create_proofs(
        &mut self,
        seed: &PosSeed,
    ) -> Result<ProofsHostWrapper<'_>, RecordEncodingError> {
        if self.tainted {
            return Err(RecordEncodingError::DeviceBroken);
        }
        // Cleared only on success at the end, so any early return leaves the instance unusable
        // (buffers may be left in an unexpected mapping state)
        self.tainted = true;

        let mut encoder = self
            .device
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("create_proofs"),
            });

        // Mapped initially and re-mapped at the end of the computation
        self.initial_state_host
            .get_mapped_range_mut(..)
            .copy_from_slice(&block_to_bytes(
                &ChaCha8State::init(seed, &[0; _]).to_repr(),
            ));
        self.initial_state_host.unmap();

        encoder.copy_buffer_to_buffer(
            &self.initial_state_host,
            0,
            &self.initial_state_gpu,
            0,
            self.initial_state_host.size(),
        );

        // Stages run in order: f1, then for each of tables 2..=7 a bucket sort followed by match
        // finding (alternating between the `a` and `b` buffer sets), and finally `find_proofs`
        {
            let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("create_proofs"),
                timestamp_writes: None,
            });

            cpass.set_bind_group(0, &self.bind_group_compute_f1, &[]);
            cpass.set_pipeline(&self.compute_pipeline_compute_f1);
            cpass.dispatch_workgroups(
                MAX_TABLE_SIZE
                    .div_ceil(compute_f1::WORKGROUP_SIZE * compute_f1::ELEMENTS_PER_INVOCATION),
                1,
                1,
            );

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f2, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f2);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f3, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f3);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f4, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f4);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f5, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f5);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f6, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f6);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f7, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f7);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            // NOTE(review): truncating division, assumes `NUM_S_BUCKETS` is a multiple of
            // `find_proofs::WORKGROUP_SIZE` (other dispatches use `div_ceil`) - confirm
            cpass.set_bind_group(0, &self.bind_group_find_proofs, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_proofs);
            cpass.dispatch_workgroups(NUM_S_BUCKETS as u32 / find_proofs::WORKGROUP_SIZE, 1, 1);
        }

        encoder.copy_buffer_to_buffer(
            &self.proofs_gpu,
            0,
            &self.proofs_host,
            0,
            self.proofs_host.size(),
        );

        // Map initial state for writes for the next iteration
        encoder.map_buffer_on_submit(&self.initial_state_host, MapMode::Write, .., {
            let mapping_error = Arc::clone(&self.mapping_error);

            move |r| {
                if let Err(error) = r {
                    mapping_error.lock().replace(error);
                }
            }
        });
        // Map proofs for reading by the caller
        encoder.map_buffer_on_submit(&self.proofs_host, MapMode::Read, .., {
            let mapping_error = Arc::clone(&self.mapping_error);

            move |r| {
                if let Err(error) = r {
                    mapping_error.lock().replace(error);
                }
            }
        });

        let submission_index = self.queue.submit([encoder.finish()]);

        // Block until this submission is done
        self.device.poll(PollType::Wait {
            submission_index: Some(submission_index),
            timeout: None,
        })?;

        // Presumably the map callbacks have been invoked by the blocking poll above. Surface any
        // mapping failure they recorded
        if let Some(error) = self.mapping_error.lock().take() {
            return Err(RecordEncodingError::BufferMapping(error));
        }

        let proofs = {
            // NOTE(review): the `BufferView` temporary is dropped right after taking the pointer.
            // The reference relies on the buffer itself staying mapped (until
            // `ProofsHostWrapper` unmaps it) and on the mapped range being sufficiently aligned
            // for `ProofsHost` - confirm
            let proofs_host_ptr = self
                .proofs_host
                .get_mapped_range(..)
                .as_ptr()
                .cast::<ProofsHost>();
            // SAFETY: Initialized on the GPU
            unsafe { &*proofs_host_ptr }
        };

        self.tainted = false;

        Ok(ProofsHostWrapper {
            proofs,
            proofs_host: &self.proofs_host,
        })
    }
}
817
818fn bind_group_and_pipeline_compute_f1(
819    device: &wgpu::Device,
820    module: &ShaderModule,
821    initial_state_gpu: &Buffer,
822    bucket_sizes_gpu: &Buffer,
823    buckets_gpu: &Buffer,
824) -> (BindGroup, ComputePipeline) {
825    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
826        label: Some("compute_f1"),
827        entries: &[
828            BindGroupLayoutEntry {
829                binding: 0,
830                count: None,
831                visibility: ShaderStages::COMPUTE,
832                ty: BindingType::Buffer {
833                    has_dynamic_offset: false,
834                    min_binding_size: None,
835                    ty: BufferBindingType::Uniform,
836                },
837            },
838            BindGroupLayoutEntry {
839                binding: 1,
840                count: None,
841                visibility: ShaderStages::COMPUTE,
842                ty: BindingType::Buffer {
843                    has_dynamic_offset: false,
844                    min_binding_size: None,
845                    ty: BufferBindingType::Storage { read_only: false },
846                },
847            },
848            BindGroupLayoutEntry {
849                binding: 2,
850                count: None,
851                visibility: ShaderStages::COMPUTE,
852                ty: BindingType::Buffer {
853                    has_dynamic_offset: false,
854                    min_binding_size: None,
855                    ty: BufferBindingType::Storage { read_only: false },
856                },
857            },
858        ],
859    });
860
861    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
862        label: Some("compute_f1"),
863        bind_group_layouts: &[Some(&bind_group_layout)],
864        immediate_size: 0,
865    });
866
867    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
868        compilation_options: PipelineCompilationOptions {
869            constants: &[],
870            zero_initialize_workgroup_memory: false,
871        },
872        cache: None,
873        label: Some("compute_f1"),
874        layout: Some(&pipeline_layout),
875        module,
876        entry_point: Some("compute_f1"),
877    });
878
879    let bind_group = device.create_bind_group(&BindGroupDescriptor {
880        label: Some("compute_f1"),
881        layout: &bind_group_layout,
882        entries: &[
883            BindGroupEntry {
884                binding: 0,
885                resource: initial_state_gpu.as_entire_binding(),
886            },
887            BindGroupEntry {
888                binding: 1,
889                resource: bucket_sizes_gpu.as_entire_binding(),
890            },
891            BindGroupEntry {
892                binding: 2,
893                resource: buckets_gpu.as_entire_binding(),
894            },
895        ],
896    });
897
898    (bind_group, compute_pipeline)
899}
900
901fn bind_group_and_pipeline_sort_buckets(
902    device: &wgpu::Device,
903    module: &ShaderModule,
904    bucket_sizes_gpu: &Buffer,
905    buckets_gpu: &Buffer,
906) -> (BindGroup, ComputePipeline) {
907    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
908        label: Some("sort_buckets"),
909        entries: &[
910            BindGroupLayoutEntry {
911                binding: 0,
912                count: None,
913                visibility: ShaderStages::COMPUTE,
914                ty: BindingType::Buffer {
915                    has_dynamic_offset: false,
916                    min_binding_size: None,
917                    ty: BufferBindingType::Storage { read_only: false },
918                },
919            },
920            BindGroupLayoutEntry {
921                binding: 1,
922                count: None,
923                visibility: ShaderStages::COMPUTE,
924                ty: BindingType::Buffer {
925                    has_dynamic_offset: false,
926                    min_binding_size: None,
927                    ty: BufferBindingType::Storage { read_only: false },
928                },
929            },
930        ],
931    });
932
933    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
934        label: Some("sort_buckets"),
935        bind_group_layouts: &[Some(&bind_group_layout)],
936        immediate_size: 0,
937    });
938
939    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
940        compilation_options: PipelineCompilationOptions {
941            constants: &[],
942            zero_initialize_workgroup_memory: false,
943        },
944        cache: None,
945        label: Some("sort_buckets"),
946        layout: Some(&pipeline_layout),
947        module,
948        entry_point: Some("sort_buckets"),
949    });
950
951    let bind_group = device.create_bind_group(&BindGroupDescriptor {
952        label: Some("sort_buckets"),
953        layout: &bind_group_layout,
954        entries: &[
955            BindGroupEntry {
956                binding: 0,
957                resource: bucket_sizes_gpu.as_entire_binding(),
958            },
959            BindGroupEntry {
960                binding: 1,
961                resource: buckets_gpu.as_entire_binding(),
962            },
963        ],
964    });
965
966    (bind_group, compute_pipeline)
967}
968
969fn bind_group_and_pipeline_find_matches_and_compute_f2(
970    device: &wgpu::Device,
971    module: &ShaderModule,
972    parent_buckets_gpu: &Buffer,
973    bucket_sizes_gpu: &Buffer,
974    buckets_gpu: &Buffer,
975    positions_gpu: &Buffer,
976    metadatas_gpu: &Buffer,
977) -> (BindGroup, ComputePipeline) {
978    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
979        label: Some("find_matches_and_compute_f2"),
980        entries: &[
981            BindGroupLayoutEntry {
982                binding: 0,
983                count: None,
984                visibility: ShaderStages::COMPUTE,
985                ty: BindingType::Buffer {
986                    has_dynamic_offset: false,
987                    min_binding_size: None,
988                    ty: BufferBindingType::Storage { read_only: true },
989                },
990            },
991            BindGroupLayoutEntry {
992                binding: 1,
993                count: None,
994                visibility: ShaderStages::COMPUTE,
995                ty: BindingType::Buffer {
996                    has_dynamic_offset: false,
997                    min_binding_size: None,
998                    ty: BufferBindingType::Storage { read_only: false },
999                },
1000            },
1001            BindGroupLayoutEntry {
1002                binding: 2,
1003                count: None,
1004                visibility: ShaderStages::COMPUTE,
1005                ty: BindingType::Buffer {
1006                    has_dynamic_offset: false,
1007                    min_binding_size: None,
1008                    ty: BufferBindingType::Storage { read_only: false },
1009                },
1010            },
1011            BindGroupLayoutEntry {
1012                binding: 3,
1013                count: None,
1014                visibility: ShaderStages::COMPUTE,
1015                ty: BindingType::Buffer {
1016                    has_dynamic_offset: false,
1017                    min_binding_size: None,
1018                    ty: BufferBindingType::Storage { read_only: false },
1019                },
1020            },
1021            BindGroupLayoutEntry {
1022                binding: 4,
1023                count: None,
1024                visibility: ShaderStages::COMPUTE,
1025                ty: BindingType::Buffer {
1026                    has_dynamic_offset: false,
1027                    min_binding_size: None,
1028                    ty: BufferBindingType::Storage { read_only: false },
1029                },
1030            },
1031        ],
1032    });
1033
1034    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1035        label: Some("find_matches_and_compute_f2"),
1036        bind_group_layouts: &[Some(&bind_group_layout)],
1037        immediate_size: 0,
1038    });
1039
1040    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1041        compilation_options: PipelineCompilationOptions {
1042            constants: &[],
1043            zero_initialize_workgroup_memory: true,
1044        },
1045        cache: None,
1046        label: Some("find_matches_and_compute_f2"),
1047        layout: Some(&pipeline_layout),
1048        module,
1049        entry_point: Some("find_matches_and_compute_f2"),
1050    });
1051
1052    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1053        label: Some("find_matches_and_compute_f2"),
1054        layout: &bind_group_layout,
1055        entries: &[
1056            BindGroupEntry {
1057                binding: 0,
1058                resource: parent_buckets_gpu.as_entire_binding(),
1059            },
1060            BindGroupEntry {
1061                binding: 1,
1062                resource: bucket_sizes_gpu.as_entire_binding(),
1063            },
1064            BindGroupEntry {
1065                binding: 2,
1066                resource: buckets_gpu.as_entire_binding(),
1067            },
1068            BindGroupEntry {
1069                binding: 3,
1070                resource: positions_gpu.as_entire_binding(),
1071            },
1072            BindGroupEntry {
1073                binding: 4,
1074                resource: metadatas_gpu.as_entire_binding(),
1075            },
1076        ],
1077    });
1078
1079    (bind_group, compute_pipeline)
1080}
1081
1082#[expect(
1083    clippy::too_many_arguments,
1084    reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1085)]
1086fn bind_group_and_pipeline_find_matches_and_compute_fn<const TABLE_NUMBER: u8>(
1087    device: &wgpu::Device,
1088    module: &ShaderModule,
1089    parent_buckets_gpu: &Buffer,
1090    parent_metadatas_gpu: &Buffer,
1091    bucket_sizes_gpu: &Buffer,
1092    buckets_gpu: &Buffer,
1093    positions_gpu: &Buffer,
1094    metadatas_gpu: &Buffer,
1095) -> (BindGroup, ComputePipeline) {
1096    let label = format!("find_matches_and_compute_f{TABLE_NUMBER}");
1097    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1098        label: Some(&label),
1099        entries: &[
1100            BindGroupLayoutEntry {
1101                binding: 0,
1102                count: None,
1103                visibility: ShaderStages::COMPUTE,
1104                ty: BindingType::Buffer {
1105                    has_dynamic_offset: false,
1106                    min_binding_size: None,
1107                    ty: BufferBindingType::Storage { read_only: true },
1108                },
1109            },
1110            BindGroupLayoutEntry {
1111                binding: 1,
1112                count: None,
1113                visibility: ShaderStages::COMPUTE,
1114                ty: BindingType::Buffer {
1115                    has_dynamic_offset: false,
1116                    min_binding_size: None,
1117                    ty: BufferBindingType::Storage { read_only: true },
1118                },
1119            },
1120            BindGroupLayoutEntry {
1121                binding: 2,
1122                count: None,
1123                visibility: ShaderStages::COMPUTE,
1124                ty: BindingType::Buffer {
1125                    has_dynamic_offset: false,
1126                    min_binding_size: None,
1127                    ty: BufferBindingType::Storage { read_only: false },
1128                },
1129            },
1130            BindGroupLayoutEntry {
1131                binding: 3,
1132                count: None,
1133                visibility: ShaderStages::COMPUTE,
1134                ty: BindingType::Buffer {
1135                    has_dynamic_offset: false,
1136                    min_binding_size: None,
1137                    ty: BufferBindingType::Storage { read_only: false },
1138                },
1139            },
1140            BindGroupLayoutEntry {
1141                binding: 4,
1142                count: None,
1143                visibility: ShaderStages::COMPUTE,
1144                ty: BindingType::Buffer {
1145                    has_dynamic_offset: false,
1146                    min_binding_size: None,
1147                    ty: BufferBindingType::Storage { read_only: false },
1148                },
1149            },
1150            BindGroupLayoutEntry {
1151                binding: 5,
1152                count: None,
1153                visibility: ShaderStages::COMPUTE,
1154                ty: BindingType::Buffer {
1155                    has_dynamic_offset: false,
1156                    min_binding_size: None,
1157                    ty: BufferBindingType::Storage { read_only: false },
1158                },
1159            },
1160        ],
1161    });
1162
1163    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1164        label: Some(&label),
1165        bind_group_layouts: &[Some(&bind_group_layout)],
1166        immediate_size: 0,
1167    });
1168
1169    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1170        compilation_options: PipelineCompilationOptions {
1171            constants: &[],
1172            zero_initialize_workgroup_memory: true,
1173        },
1174        cache: None,
1175        label: Some(&label),
1176        layout: Some(&pipeline_layout),
1177        module,
1178        entry_point: Some(&format!("find_matches_and_compute_f{TABLE_NUMBER}")),
1179    });
1180
1181    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1182        label: Some(&label),
1183        layout: &bind_group_layout,
1184        entries: &[
1185            BindGroupEntry {
1186                binding: 0,
1187                resource: parent_buckets_gpu.as_entire_binding(),
1188            },
1189            BindGroupEntry {
1190                binding: 1,
1191                resource: parent_metadatas_gpu.as_entire_binding(),
1192            },
1193            BindGroupEntry {
1194                binding: 2,
1195                resource: bucket_sizes_gpu.as_entire_binding(),
1196            },
1197            BindGroupEntry {
1198                binding: 3,
1199                resource: buckets_gpu.as_entire_binding(),
1200            },
1201            BindGroupEntry {
1202                binding: 4,
1203                resource: positions_gpu.as_entire_binding(),
1204            },
1205            BindGroupEntry {
1206                binding: 5,
1207                resource: metadatas_gpu.as_entire_binding(),
1208            },
1209        ],
1210    });
1211
1212    (bind_group, compute_pipeline)
1213}
1214
1215fn bind_group_and_pipeline_find_matches_and_compute_f7(
1216    device: &wgpu::Device,
1217    module: &ShaderModule,
1218    parent_buckets_gpu: &Buffer,
1219    parent_metadatas_gpu: &Buffer,
1220    table_6_proof_targets_sizes_gpu: &Buffer,
1221    table_6_proof_targets_gpu: &Buffer,
1222) -> (BindGroup, ComputePipeline) {
1223    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1224        label: Some("find_matches_and_compute_f7"),
1225        entries: &[
1226            BindGroupLayoutEntry {
1227                binding: 0,
1228                count: None,
1229                visibility: ShaderStages::COMPUTE,
1230                ty: BindingType::Buffer {
1231                    has_dynamic_offset: false,
1232                    min_binding_size: None,
1233                    ty: BufferBindingType::Storage { read_only: true },
1234                },
1235            },
1236            BindGroupLayoutEntry {
1237                binding: 1,
1238                count: None,
1239                visibility: ShaderStages::COMPUTE,
1240                ty: BindingType::Buffer {
1241                    has_dynamic_offset: false,
1242                    min_binding_size: None,
1243                    ty: BufferBindingType::Storage { read_only: true },
1244                },
1245            },
1246            BindGroupLayoutEntry {
1247                binding: 2,
1248                count: None,
1249                visibility: ShaderStages::COMPUTE,
1250                ty: BindingType::Buffer {
1251                    has_dynamic_offset: false,
1252                    min_binding_size: None,
1253                    ty: BufferBindingType::Storage { read_only: false },
1254                },
1255            },
1256            BindGroupLayoutEntry {
1257                binding: 3,
1258                count: None,
1259                visibility: ShaderStages::COMPUTE,
1260                ty: BindingType::Buffer {
1261                    has_dynamic_offset: false,
1262                    min_binding_size: None,
1263                    ty: BufferBindingType::Storage { read_only: false },
1264                },
1265            },
1266        ],
1267    });
1268
1269    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1270        label: Some("find_matches_and_compute_f7"),
1271        bind_group_layouts: &[Some(&bind_group_layout)],
1272        immediate_size: 0,
1273    });
1274
1275    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1276        compilation_options: PipelineCompilationOptions {
1277            constants: &[],
1278            zero_initialize_workgroup_memory: true,
1279        },
1280        cache: None,
1281        label: Some("find_matches_and_compute_f7"),
1282        layout: Some(&pipeline_layout),
1283        module,
1284        entry_point: Some("find_matches_and_compute_f7"),
1285    });
1286
1287    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1288        label: Some("find_matches_and_compute_f7"),
1289        layout: &bind_group_layout,
1290        entries: &[
1291            BindGroupEntry {
1292                binding: 0,
1293                resource: parent_buckets_gpu.as_entire_binding(),
1294            },
1295            BindGroupEntry {
1296                binding: 1,
1297                resource: parent_metadatas_gpu.as_entire_binding(),
1298            },
1299            BindGroupEntry {
1300                binding: 2,
1301                resource: table_6_proof_targets_sizes_gpu.as_entire_binding(),
1302            },
1303            BindGroupEntry {
1304                binding: 3,
1305                resource: table_6_proof_targets_gpu.as_entire_binding(),
1306            },
1307        ],
1308    });
1309
1310    (bind_group, compute_pipeline)
1311}
1312
1313#[expect(
1314    clippy::too_many_arguments,
1315    reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1316)]
1317fn bind_group_and_pipeline_find_proofs(
1318    device: &wgpu::Device,
1319    module: &ShaderModule,
1320    table_2_positions_gpu: &Buffer,
1321    table_3_positions_gpu: &Buffer,
1322    table_4_positions_gpu: &Buffer,
1323    table_5_positions_gpu: &Buffer,
1324    table_6_positions_gpu: &Buffer,
1325    bucket_sizes_gpu: &Buffer,
1326    buckets_gpu: &Buffer,
1327    proofs_gpu: &Buffer,
1328) -> (BindGroup, ComputePipeline) {
1329    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1330        label: Some("find_proofs"),
1331        entries: &[
1332            BindGroupLayoutEntry {
1333                binding: 0,
1334                count: None,
1335                visibility: ShaderStages::COMPUTE,
1336                ty: BindingType::Buffer {
1337                    has_dynamic_offset: false,
1338                    min_binding_size: None,
1339                    ty: BufferBindingType::Storage { read_only: true },
1340                },
1341            },
1342            BindGroupLayoutEntry {
1343                binding: 1,
1344                count: None,
1345                visibility: ShaderStages::COMPUTE,
1346                ty: BindingType::Buffer {
1347                    has_dynamic_offset: false,
1348                    min_binding_size: None,
1349                    ty: BufferBindingType::Storage { read_only: true },
1350                },
1351            },
1352            BindGroupLayoutEntry {
1353                binding: 2,
1354                count: None,
1355                visibility: ShaderStages::COMPUTE,
1356                ty: BindingType::Buffer {
1357                    has_dynamic_offset: false,
1358                    min_binding_size: None,
1359                    ty: BufferBindingType::Storage { read_only: true },
1360                },
1361            },
1362            BindGroupLayoutEntry {
1363                binding: 3,
1364                count: None,
1365                visibility: ShaderStages::COMPUTE,
1366                ty: BindingType::Buffer {
1367                    has_dynamic_offset: false,
1368                    min_binding_size: None,
1369                    ty: BufferBindingType::Storage { read_only: true },
1370                },
1371            },
1372            BindGroupLayoutEntry {
1373                binding: 4,
1374                count: None,
1375                visibility: ShaderStages::COMPUTE,
1376                ty: BindingType::Buffer {
1377                    has_dynamic_offset: false,
1378                    min_binding_size: None,
1379                    ty: BufferBindingType::Storage { read_only: true },
1380                },
1381            },
1382            BindGroupLayoutEntry {
1383                binding: 5,
1384                count: None,
1385                visibility: ShaderStages::COMPUTE,
1386                ty: BindingType::Buffer {
1387                    has_dynamic_offset: false,
1388                    min_binding_size: None,
1389                    ty: BufferBindingType::Storage { read_only: false },
1390                },
1391            },
1392            BindGroupLayoutEntry {
1393                binding: 6,
1394                count: None,
1395                visibility: ShaderStages::COMPUTE,
1396                ty: BindingType::Buffer {
1397                    has_dynamic_offset: false,
1398                    min_binding_size: None,
1399                    ty: BufferBindingType::Storage { read_only: true },
1400                },
1401            },
1402            BindGroupLayoutEntry {
1403                binding: 7,
1404                count: None,
1405                visibility: ShaderStages::COMPUTE,
1406                ty: BindingType::Buffer {
1407                    has_dynamic_offset: false,
1408                    min_binding_size: None,
1409                    ty: BufferBindingType::Storage { read_only: false },
1410                },
1411            },
1412        ],
1413    });
1414
1415    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1416        label: Some("find_proofs"),
1417        bind_group_layouts: &[Some(&bind_group_layout)],
1418        immediate_size: 0,
1419    });
1420
1421    let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1422        compilation_options: PipelineCompilationOptions {
1423            constants: &[],
1424            zero_initialize_workgroup_memory: false,
1425        },
1426        cache: None,
1427        label: Some("find_proofs"),
1428        layout: Some(&pipeline_layout),
1429        module,
1430        entry_point: Some("find_proofs"),
1431    });
1432
1433    let bind_group = device.create_bind_group(&BindGroupDescriptor {
1434        label: Some("find_proofs"),
1435        layout: &bind_group_layout,
1436        entries: &[
1437            BindGroupEntry {
1438                binding: 0,
1439                resource: table_2_positions_gpu.as_entire_binding(),
1440            },
1441            BindGroupEntry {
1442                binding: 1,
1443                resource: table_3_positions_gpu.as_entire_binding(),
1444            },
1445            BindGroupEntry {
1446                binding: 2,
1447                resource: table_4_positions_gpu.as_entire_binding(),
1448            },
1449            BindGroupEntry {
1450                binding: 3,
1451                resource: table_5_positions_gpu.as_entire_binding(),
1452            },
1453            BindGroupEntry {
1454                binding: 4,
1455                resource: table_6_positions_gpu.as_entire_binding(),
1456            },
1457            BindGroupEntry {
1458                binding: 5,
1459                resource: bucket_sizes_gpu.as_entire_binding(),
1460            },
1461            BindGroupEntry {
1462                binding: 6,
1463                resource: buckets_gpu.as_entire_binding(),
1464            },
1465            BindGroupEntry {
1466                binding: 7,
1467                resource: proofs_gpu.as_entire_binding(),
1468            },
1469        ],
1470    });
1471
1472    (bind_group, compute_pipeline)
1473}