// Unit tests (loaded from a separate `tests` module file); compiled only for test builds and
// skipped entirely under Miri. NOTE(review): presumably excluded from Miri because they need a real
// wgpu adapter — confirm.
#[cfg(all(test, not(miri)))]
mod tests;
3
4use crate::shader::constants::{
5 MAX_BUCKET_SIZE, MAX_TABLE_SIZE, NUM_BUCKETS, NUM_MATCH_BUCKETS, NUM_S_BUCKETS,
6 REDUCED_MATCHES_COUNT,
7};
8use crate::shader::find_matches_and_compute_f7::{NUM_ELEMENTS_PER_S_BUCKET, ProofTargets};
9use crate::shader::find_proofs::ProofsHost;
10use crate::shader::types::{Metadata, Position, PositionR};
11use crate::shader::{compute_f1, find_proofs, select_shader_features_limits};
12use ab_chacha8::{ChaCha8Block, ChaCha8State, block_to_bytes};
13use ab_core_primitives::pieces::{PieceOffset, Record};
14use ab_core_primitives::pos::PosSeed;
15use ab_core_primitives::sectors::SectorId;
16use ab_erasure_coding::ErasureCoding;
17use ab_farmer_components::plotting::RecordsEncoder;
18use ab_farmer_components::sector::SectorContentsMap;
19use async_lock::Mutex as AsyncMutex;
20use futures::stream::FuturesOrdered;
21use futures::{StreamExt, TryStreamExt};
22use parking_lot::Mutex;
23use rayon::prelude::*;
24use rayon::{ThreadPool, ThreadPoolBuilder};
25use rclite::Arc;
26use std::fmt;
27use std::num::NonZeroU8;
28use std::simd::Simd;
29use std::sync::Arc as StdArc;
30use std::sync::atomic::{AtomicBool, Ordering};
31use tracing::{debug, warn};
32use wgpu::{
33 AdapterInfo, Backend, BackendOptions, Backends, BindGroup, BindGroupDescriptor, BindGroupEntry,
34 BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, Buffer, BufferAddress,
35 BufferAsyncError, BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor,
36 ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor, DeviceDescriptor,
37 DeviceType, Instance, InstanceDescriptor, InstanceFlags, MapMode, MemoryBudgetThresholds,
38 PipelineCompilationOptions, PipelineLayoutDescriptor, PollError, PollType, Queue,
39 RequestDeviceError, ShaderModule, ShaderStages,
40};
41
42#[derive(Debug, thiserror::Error)]
44enum RecordEncodingError {
45 #[error("Too many records: {0}")]
47 TooManyRecords(usize),
48 #[error("Proof creation failed previously and the device is now considered broken")]
50 DeviceBroken,
51 #[error("Failed to map buffer: {0}")]
53 BufferMapping(#[from] BufferAsyncError),
54 #[error("Poll error: {0}")]
56 DevicePoll(#[from] PollError),
57}
58
59struct ProofsHostWrapper<'a> {
60 proofs: &'a ProofsHost,
61 proofs_host: &'a Buffer,
62}
63
64impl Drop for ProofsHostWrapper<'_> {
65 fn drop(&mut self) {
66 self.proofs_host.unmap();
67 }
68}
69
/// A wgpu adapter that is compatible with the proof-of-space shaders, together with the logical
/// devices requested from it by [`Device::enumerate()`].
#[derive(Clone)]
pub struct Device {
    // Index of the adapter in wgpu's enumeration order (incompatible adapters are counted too,
    // since the index is assigned before filtering)
    id: u32,
    // One `(device, queue, compiled shader module)` triple per queue requested for this adapter
    devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
    adapter_info: AdapterInfo,
}
77
78impl fmt::Debug for Device {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 f.debug_struct("Device")
81 .field("id", &self.id)
82 .field("name", &self.adapter_info.name)
83 .field("device_type", &self.adapter_info.device_type)
84 .field("driver", &self.adapter_info.driver)
85 .field("driver_info", &self.adapter_info.driver_info)
86 .field("backend", &self.adapter_info.backend)
87 .finish_non_exhaustive()
88 }
89}
90
91impl Device {
92 pub async fn enumerate<NOQ>(number_of_queues: NOQ) -> Vec<Self>
94 where
95 NOQ: Fn(DeviceType) -> NonZeroU8,
96 {
97 let backends = Backends::from_env().unwrap_or(Backends::METAL | Backends::VULKAN);
98 let instance = Instance::new(InstanceDescriptor {
99 backends,
100 flags: if cfg!(debug_assertions) {
101 InstanceFlags::debugging().with_env()
102 } else {
103 InstanceFlags::from_env_or_default()
104 },
105 memory_budget_thresholds: MemoryBudgetThresholds::default(),
106 backend_options: BackendOptions::from_env_or_default(),
107 display: None,
108 });
109
110 let adapters = instance.enumerate_adapters(backends).await;
111 let number_of_queues = &number_of_queues;
112
113 adapters
114 .into_iter()
115 .zip(0..)
116 .map(|(adapter, id)| async move {
117 let adapter_info = adapter.get_info();
118
119 let (shader, required_features, required_limits) =
120 match select_shader_features_limits(&adapter) {
121 Some((shader, required_features, required_limits)) => {
122 debug!(
123 %id,
124 adapter_info = ?adapter_info,
125 "Compatible adapter found"
126 );
127
128 (shader, required_features, required_limits)
129 }
130 None => {
131 debug!(
132 %id,
133 adapter_info = ?adapter_info,
134 "Incompatible adapter found"
135 );
136
137 return None;
138 }
139 };
140
141 let devices = (0..number_of_queues(adapter_info.device_type).get())
144 .map(|_| async {
145 let (device, queue) = adapter
146 .request_device(&DeviceDescriptor {
147 label: None,
148 required_features,
149 required_limits: required_limits.clone(),
150 ..DeviceDescriptor::default()
151 })
152 .await
153 .inspect_err(|error| {
154 warn!(%id, ?adapter_info, %error, "Failed to request the device");
155 })?;
156 let module = device.create_shader_module(shader.clone());
157
158 Ok::<_, RequestDeviceError>((device, queue, module))
159 })
160 .collect::<FuturesOrdered<_>>()
161 .try_collect::<Vec<_>>()
162 .await
163 .ok()?;
164
165 Some(Self {
166 id,
167 devices,
168 adapter_info,
169 })
170 })
171 .collect::<FuturesOrdered<_>>()
172 .filter_map(|device| async move { device })
173 .collect()
174 .await
175 }
176
177 pub fn id(&self) -> u32 {
179 self.id
180 }
181
182 pub fn name(&self) -> &str {
184 &self.adapter_info.name
185 }
186
187 pub fn device_type(&self) -> DeviceType {
189 self.adapter_info.device_type
190 }
191
192 pub fn driver(&self) -> &str {
194 &self.adapter_info.driver
195 }
196
197 pub fn driver_info(&self) -> &str {
199 &self.adapter_info.driver_info
200 }
201
202 pub fn backend(&self) -> Backend {
204 self.adapter_info.backend
205 }
206
207 pub fn instantiate(
208 &self,
209 erasure_coding: ErasureCoding,
210 global_mutex: StdArc<AsyncMutex<()>>,
211 ) -> anyhow::Result<GpuRecordsEncoder> {
212 GpuRecordsEncoder::new(
213 self.id,
214 self.devices.clone(),
215 self.adapter_info.clone(),
216 erasure_coding,
217 global_mutex,
218 )
219 }
220}
221
/// GPU-backed [`RecordsEncoder`] that spreads records across the logical devices of one adapter.
pub struct GpuRecordsEncoder {
    id: u32,
    // One instance per logical device, indexed by the current thread index of `thread_pool`
    instances: Vec<Mutex<GpuRecordsEncoderInstance>>,
    // Built with `num_threads(devices.len())` so each pool thread owns exactly one instance
    thread_pool: ThreadPool,
    adapter_info: AdapterInfo,
    erasure_coding: ErasureCoding,
    // Acquired and immediately released before each record is encoded. NOTE(review): presumably
    // lets the owner pause encoding by holding the lock — confirm with callers
    global_mutex: StdArc<AsyncMutex<()>>,
}
230
231impl fmt::Debug for GpuRecordsEncoder {
232 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233 f.debug_struct("GpuRecordsEncoder")
234 .field("id", &self.id)
235 .field("name", &self.adapter_info.name)
236 .field("device_type", &self.adapter_info.device_type)
237 .field("driver", &self.adapter_info.driver)
238 .field("driver_info", &self.adapter_info.driver_info)
239 .field("backend", &self.adapter_info.backend)
240 .finish_non_exhaustive()
241 }
242}
243
impl RecordsEncoder for GpuRecordsEncoder {
    /// Encodes `records` in place for `sector_id` and returns which record chunks were used.
    ///
    /// Each record is processed on one thread of the dedicated pool:
    /// 1. Parity chunks are produced by erasure coding.
    /// 2. Proofs are created on that thread's GPU instance.
    /// 3. Every chunk that has a proof is XOR-ed with the proof hash and compacted to the front of
    ///    the record.
    ///
    /// If `abort_early` is set, the remaining records are skipped after erasure coding and the
    /// function still returns `Ok` with a partially filled map. NOTE(review): presumably the caller
    /// checks `abort_early` and discards the result — confirm.
    ///
    /// # Errors
    ///
    /// Fails if there are more than `u16::MAX` records, or if proof creation fails on any device.
    fn encode_records(
        &mut self,
        sector_id: &SectorId,
        records: &mut [Record],
        abort_early: &AtomicBool,
    ) -> anyhow::Result<SectorContentsMap> {
        let mut sector_contents_map = SectorContentsMap::new(
            u16::try_from(records.len())
                .map_err(|_| RecordEncodingError::TooManyRecords(records.len()))?,
        );

        let maybe_error = self.thread_pool.install(|| {
            records
                .par_iter_mut()
                .zip(sector_contents_map.iter_record_chunks_used_mut())
                .enumerate()
                .find_map_any(|(piece_offset, (record, record_chunks_used))| {
                    // The guard is a temporary dropped at the end of this statement, so the mutex
                    // is acquired and released right away: this only waits while someone else
                    // holds it, it does not serialize encoding
                    self.global_mutex.lock_blocking();

                    let mut parity_record_chunks = Record::new_boxed();

                    self.erasure_coding
                        .extend(record.iter(), parity_record_chunks.iter_mut())
                        .expect("Statically guaranteed valid inputs; qed");

                    if abort_early.load(Ordering::Relaxed) {
                        return None;
                    }
                    // `piece_offset` fits into `u16` because `records.len()` was checked above
                    let seed =
                        sector_id.derive_evaluation_seed(PieceOffset::from(piece_offset as u16));
                    // The pool has one thread per instance, so every thread uses its own instance
                    // and `try_lock()` can't observe contention
                    let thread_index = rayon::current_thread_index().unwrap_or_default();
                    let mut encoder_instance = self.instances[thread_index]
                        .try_lock()
                        .expect("1:1 mapping between threads and devices; qed");
                    let proofs = match encoder_instance.create_proofs(&seed) {
                        Ok(proofs) => proofs,
                        Err(error) => {
                            return Some(error);
                        }
                    };
                    // Shadowing does not drop the wrapper: it stays alive (keeping `proofs_host`
                    // mapped) until the end of this closure
                    let proofs = proofs.proofs;

                    *record_chunks_used = proofs.found_proofs;

                    // Compact encoded chunks to the front of `record`. Each s-bucket with its bit
                    // set in `found_proofs` contributes one chunk. That chunk is taken from `record`
                    // (index < NUM_CHUNKS) or from the parity chunks, and XOR-ed with the proof
                    // hash.
                    //
                    // Writes go to `record[num_found_proofs]`. Since `num_found_proofs <= s_bucket`
                    // at every step, `record[s_bucket]` is always read before it can be overwritten.
                    let mut num_found_proofs = 0_usize;
                    for (s_buckets, found_proofs) in (0..Record::NUM_S_BUCKETS)
                        .array_chunks::<{ u8::BITS as usize }>()
                        .zip(record_chunks_used)
                    {
                        for (proof_offset, s_bucket) in s_buckets.into_iter().enumerate() {
                            if num_found_proofs == Record::NUM_CHUNKS {
                                // Enough chunks: keep only bits below `proof_offset` in this byte.
                                // With `proof_offset == 0` the shift is 8 and `unbounded_shr`
                                // yields 0, so every following byte gets fully cleared too.
                                *found_proofs &=
                                    u8::MAX.unbounded_shr(u8::BITS - proof_offset as u32);
                                break;
                            }
                            if (*found_proofs & (1 << proof_offset)) != 0 {
                                let record_chunk = if s_bucket < Record::NUM_CHUNKS {
                                    record[s_bucket]
                                } else {
                                    parity_record_chunks[s_bucket - Record::NUM_CHUNKS]
                                };

                                record[num_found_proofs] = (Simd::from(record_chunk)
                                    ^ Simd::from(*proofs.proofs[s_bucket].hash()))
                                .to_array();
                                num_found_proofs += 1;
                            }
                        }
                    }

                    None
                })
        });

        if let Some(error) = maybe_error {
            return Err(error.into());
        }

        Ok(sector_contents_map)
    }
}
332
333impl GpuRecordsEncoder {
334 fn new(
335 id: u32,
336 devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
337 adapter_info: AdapterInfo,
338 erasure_coding: ErasureCoding,
339 global_mutex: StdArc<AsyncMutex<()>>,
340 ) -> anyhow::Result<Self> {
341 let thread_pool = ThreadPoolBuilder::new()
342 .thread_name(move |thread_index| format!("pos-gpu-{id}.{thread_index}"))
343 .num_threads(devices.len())
344 .build()?;
345
346 Ok(Self {
347 id,
348 instances: devices
349 .into_iter()
350 .map(|(device, queue, module)| {
351 Mutex::new(GpuRecordsEncoderInstance::new(device, queue, module))
352 })
353 .collect(),
354 thread_pool,
355 adapter_info,
356 erasure_coding,
357 global_mutex,
358 })
359 }
360}
361
/// All GPU state needed to create proofs on one logical device: buffers, bind groups and compute
/// pipelines for every stage, wired together in `GpuRecordsEncoderInstance::new()`.
struct GpuRecordsEncoderInstance {
    device: wgpu::Device,
    queue: Queue,
    // Filled by `map_buffer_on_submit` callbacks when re-mapping a host buffer fails
    mapping_error: Arc<Mutex<Option<BufferAsyncError>>>,
    // Set for the duration of `create_proofs()`; stays set if it fails, making the instance unusable
    tainted: bool,
    // Host-writable staging buffer (MAP_WRITE | COPY_SRC) for the ChaCha8 initial state
    initial_state_host: Buffer,
    // Uniform buffer the initial state is copied into before `compute_f1` runs
    initial_state_gpu: Buffer,
    // Host-readable buffer (MAP_READ | COPY_DST) that proofs are copied into
    proofs_host: Buffer,
    // Storage buffer that `find_proofs` writes proofs into
    proofs_gpu: Buffer,
    bind_group_compute_f1: BindGroup,
    compute_pipeline_compute_f1: ComputePipeline,
    // `_a` sorts `buckets_a_gpu`, `_b` sorts `buckets_b_gpu`
    bind_group_sort_buckets_a: BindGroup,
    compute_pipeline_sort_buckets_a: ComputePipeline,
    bind_group_sort_buckets_b: BindGroup,
    compute_pipeline_sort_buckets_b: ComputePipeline,
    bind_group_find_matches_and_compute_f2: BindGroup,
    compute_pipeline_find_matches_and_compute_f2: ComputePipeline,
    bind_group_find_matches_and_compute_f3: BindGroup,
    compute_pipeline_find_matches_and_compute_f3: ComputePipeline,
    bind_group_find_matches_and_compute_f4: BindGroup,
    compute_pipeline_find_matches_and_compute_f4: ComputePipeline,
    bind_group_find_matches_and_compute_f5: BindGroup,
    compute_pipeline_find_matches_and_compute_f5: ComputePipeline,
    bind_group_find_matches_and_compute_f6: BindGroup,
    compute_pipeline_find_matches_and_compute_f6: ComputePipeline,
    bind_group_find_matches_and_compute_f7: BindGroup,
    compute_pipeline_find_matches_and_compute_f7: ComputePipeline,
    bind_group_find_proofs: BindGroup,
    compute_pipeline_find_proofs: ComputePipeline,
}
392
impl GpuRecordsEncoderInstance {
    /// Allocates every buffer and builds every bind group/pipeline used by `create_proofs()`.
    ///
    /// `initial_state_host` is created already mapped, so the first `create_proofs()` call can
    /// write the initial state into it right away.
    fn new(device: wgpu::Device, queue: Queue, module: ShaderModule) -> Self {
        // Host-writable staging buffer holding one `ChaCha8Block` of initial state
        let initial_state_host = device.create_buffer(&BufferDescriptor {
            label: Some("initial_state_host"),
            size: size_of::<ChaCha8Block>() as BufferAddress,
            usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
            mapped_at_creation: true,
        });

        // Uniform copy of the initial state read by `compute_f1`
        let initial_state_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("initial_state_gpu"),
            size: initial_state_host.size(),
            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Bucket sizes and table 6 proof target sizes share one allocation, sized for the larger
        // of the two. `table_6_proof_targets_sizes_gpu` is a second handle to the same buffer.
        // NOTE(review): presumably safe because bucket sizes are no longer needed once
        // `find_matches_and_compute_f7` starts — confirm against the shaders.
        let bucket_sizes_gpu_buffer_size = size_of::<[u32; NUM_BUCKETS]>() as BufferAddress;
        let table_6_proof_targets_sizes_gpu_buffer_size =
            size_of::<[u32; NUM_S_BUCKETS]>() as BufferAddress;
        let bucket_sizes_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("bucket_sizes_gpu"),
            size: bucket_sizes_gpu_buffer_size.max(table_6_proof_targets_sizes_gpu_buffer_size),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });
        let table_6_proof_targets_sizes_gpu = bucket_sizes_gpu.clone();

        // Two equally sized bucket buffers; the stages below alternate between them
        let buckets_a_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("buckets_a_gpu"),
            size: size_of::<[[PositionR; MAX_BUCKET_SIZE]; NUM_BUCKETS]>() as BufferAddress,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let buckets_b_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("buckets_b_gpu"),
            size: buckets_a_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        // Per-table positions (tables 2 to 6, all the same size); every one of them is bound
        // again by `find_proofs` at the very end
        let positions_f2_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f2_gpu"),
            size: size_of::<[[[Position; 2]; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>()
                as BufferAddress,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f3_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f3_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f4_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f4_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f5_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f5_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        let positions_f6_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("positions_f6_gpu"),
            size: positions_f2_gpu.size(),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        // `metadatas_a_gpu` also hosts the table 6 proof targets (sized for the larger of the two);
        // `table_6_proof_targets_gpu` is a second handle to the same buffer. NOTE(review): relies
        // on metadatas in `a` no longer being needed once `find_matches_and_compute_f7` runs,
        // since f7 reads `metadatas_b_gpu` — confirm.
        let metadatas_gpu_buffer_size =
            size_of::<[[Metadata; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>() as BufferAddress;
        let table_6_proof_targets_gpu_buffer_size = size_of::<
            [[ProofTargets; NUM_ELEMENTS_PER_S_BUCKET]; NUM_S_BUCKETS],
        >() as BufferAddress;
        let metadatas_a_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("metadatas_a_gpu"),
            size: metadatas_gpu_buffer_size.max(table_6_proof_targets_gpu_buffer_size),
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });
        let table_6_proof_targets_gpu = metadatas_a_gpu.clone();

        let metadatas_b_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("metadatas_b_gpu"),
            size: metadatas_gpu_buffer_size,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        // Host-readable buffer that receives a copy of `proofs_gpu` after each submission
        let proofs_host = device.create_buffer(&BufferDescriptor {
            label: Some("proofs_host"),
            size: size_of::<ProofsHost>() as BufferAddress,
            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let proofs_gpu = device.create_buffer(&BufferDescriptor {
            label: Some("proofs_gpu"),
            size: proofs_host.size(),
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        // Stage wiring. `compute_f1` fills buckets `a`. Each `find_matches_and_compute_fN` then
        // reads the buckets/metadatas produced by the previous stage and writes the other buffer
        // of each pair (a -> b -> a -> ...), plus its own positions buffer.
        let (bind_group_compute_f1, compute_pipeline_compute_f1) =
            bind_group_and_pipeline_compute_f1(
                &device,
                &module,
                &initial_state_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
            );
        let (bind_group_sort_buckets_a, compute_pipeline_sort_buckets_a) =
            bind_group_and_pipeline_sort_buckets(
                &device,
                &module,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
            );

        let (bind_group_sort_buckets_b, compute_pipeline_sort_buckets_b) =
            bind_group_and_pipeline_sort_buckets(
                &device,
                &module,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
            );

        let (bind_group_find_matches_and_compute_f2, compute_pipeline_find_matches_and_compute_f2) =
            bind_group_and_pipeline_find_matches_and_compute_f2(
                &device,
                &module,
                &buckets_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f2_gpu,
                &metadatas_b_gpu,
            );

        let (bind_group_find_matches_and_compute_f3, compute_pipeline_find_matches_and_compute_f3) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<3>(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
                &positions_f3_gpu,
                &metadatas_a_gpu,
            );

        let (bind_group_find_matches_and_compute_f4, compute_pipeline_find_matches_and_compute_f4) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<4>(
                &device,
                &module,
                &buckets_a_gpu,
                &metadatas_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f4_gpu,
                &metadatas_b_gpu,
            );

        let (bind_group_find_matches_and_compute_f5, compute_pipeline_find_matches_and_compute_f5) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<5>(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &bucket_sizes_gpu,
                &buckets_a_gpu,
                &positions_f5_gpu,
                &metadatas_a_gpu,
            );

        let (bind_group_find_matches_and_compute_f6, compute_pipeline_find_matches_and_compute_f6) =
            bind_group_and_pipeline_find_matches_and_compute_fn::<6>(
                &device,
                &module,
                &buckets_a_gpu,
                &metadatas_a_gpu,
                &bucket_sizes_gpu,
                &buckets_b_gpu,
                &positions_f6_gpu,
                &metadatas_b_gpu,
            );

        // Table 7 writes proof targets into the aliased `bucket_sizes_gpu`/`metadatas_a_gpu`
        // allocations (through their `table_6_proof_targets*` handles)
        let (bind_group_find_matches_and_compute_f7, compute_pipeline_find_matches_and_compute_f7) =
            bind_group_and_pipeline_find_matches_and_compute_f7(
                &device,
                &module,
                &buckets_b_gpu,
                &metadatas_b_gpu,
                &table_6_proof_targets_sizes_gpu,
                &table_6_proof_targets_gpu,
            );

        let (bind_group_find_proofs, compute_pipeline_find_proofs) =
            bind_group_and_pipeline_find_proofs(
                &device,
                &module,
                &positions_f2_gpu,
                &positions_f3_gpu,
                &positions_f4_gpu,
                &positions_f5_gpu,
                &positions_f6_gpu,
                &table_6_proof_targets_sizes_gpu,
                &table_6_proof_targets_gpu,
                &proofs_gpu,
            );

        Self {
            device,
            queue,
            mapping_error: Arc::new(Mutex::new(None)),
            tainted: false,
            initial_state_host,
            initial_state_gpu,
            proofs_host,
            proofs_gpu,
            bind_group_compute_f1,
            compute_pipeline_compute_f1,
            bind_group_sort_buckets_a,
            compute_pipeline_sort_buckets_a,
            bind_group_sort_buckets_b,
            compute_pipeline_sort_buckets_b,
            bind_group_find_matches_and_compute_f2,
            compute_pipeline_find_matches_and_compute_f2,
            bind_group_find_matches_and_compute_f3,
            compute_pipeline_find_matches_and_compute_f3,
            bind_group_find_matches_and_compute_f4,
            compute_pipeline_find_matches_and_compute_f4,
            bind_group_find_matches_and_compute_f5,
            compute_pipeline_find_matches_and_compute_f5,
            bind_group_find_matches_and_compute_f6,
            compute_pipeline_find_matches_and_compute_f6,
            bind_group_find_matches_and_compute_f7,
            compute_pipeline_find_matches_and_compute_f7,
            bind_group_find_proofs,
            compute_pipeline_find_proofs,
        }
    }

    /// Runs the whole proof-generation pipeline for `seed` and returns a view of the proofs.
    ///
    /// Blocks the calling thread until the submitted GPU work completes. The returned wrapper
    /// borrows the mapped `proofs_host` buffer and unmaps it when dropped.
    ///
    /// # Errors
    ///
    /// - [`RecordEncodingError::DeviceBroken`] if an earlier call did not complete successfully.
    /// - [`RecordEncodingError::DevicePoll`] if waiting for the submission fails.
    /// - [`RecordEncodingError::BufferMapping`] if re-mapping either host buffer fails.
    ///
    /// After any error the instance stays tainted.
    fn create_proofs(
        &mut self,
        seed: &PosSeed,
    ) -> Result<ProofsHostWrapper<'_>, RecordEncodingError> {
        // The flag is cleared only on the success path at the end, so any early return leaves
        // the instance permanently unusable (buffer mapping state would be unknown)
        if self.tainted {
            return Err(RecordEncodingError::DeviceBroken);
        }
        self.tainted = true;

        let mut encoder = self
            .device
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("create_proofs"),
            });

        // `initial_state_host` must be mapped at this point: it is created mapped and re-mapped
        // for writing at the end of every successful submission (see `map_buffer_on_submit` below)
        self.initial_state_host
            .get_mapped_range_mut(..)
            .copy_from_slice(&block_to_bytes(
                &ChaCha8State::init(seed, &[0; _]).to_repr(),
            ));
        self.initial_state_host.unmap();

        encoder.copy_buffer_to_buffer(
            &self.initial_state_host,
            0,
            &self.initial_state_gpu,
            0,
            self.initial_state_host.size(),
        );

        // Stage order:
        // f1 -> a, sort a, f2 a -> b, sort b, f3 b -> a, sort a, f4 a -> b, sort b, f5 b -> a,
        // sort a, f6 a -> b, sort b, f7 (from b), find_proofs
        {
            let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("create_proofs"),
                timestamp_writes: None,
            });

            cpass.set_bind_group(0, &self.bind_group_compute_f1, &[]);
            cpass.set_pipeline(&self.compute_pipeline_compute_f1);
            cpass.dispatch_workgroups(
                MAX_TABLE_SIZE
                    .div_ceil(compute_f1::WORKGROUP_SIZE * compute_f1::ELEMENTS_PER_INVOCATION),
                1,
                1,
            );

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f2, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f2);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f3, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f3);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f4, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f4);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f5, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f5);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f6, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f6);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
            cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
            cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);

            cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f7, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f7);
            cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);

            // NOTE(review): truncating division — assumes `NUM_S_BUCKETS` is a multiple of
            // `find_proofs::WORKGROUP_SIZE`, otherwise trailing s-buckets would not be covered
            cpass.set_bind_group(0, &self.bind_group_find_proofs, &[]);
            cpass.set_pipeline(&self.compute_pipeline_find_proofs);
            cpass.dispatch_workgroups(NUM_S_BUCKETS as u32 / find_proofs::WORKGROUP_SIZE, 1, 1);
        }

        encoder.copy_buffer_to_buffer(
            &self.proofs_gpu,
            0,
            &self.proofs_host,
            0,
            self.proofs_host.size(),
        );

        // Re-map both host buffers once the submission completes: `initial_state_host` for the
        // next call, `proofs_host` for reading results. Callbacks only record failures, and a
        // later failure overwrites an earlier one.
        encoder.map_buffer_on_submit(&self.initial_state_host, MapMode::Write, .., {
            let mapping_error = Arc::clone(&self.mapping_error);

            move |r| {
                if let Err(error) = r {
                    mapping_error.lock().replace(error);
                }
            }
        });
        encoder.map_buffer_on_submit(&self.proofs_host, MapMode::Read, .., {
            let mapping_error = Arc::clone(&self.mapping_error);

            move |r| {
                if let Err(error) = r {
                    mapping_error.lock().replace(error);
                }
            }
        });

        let submission_index = self.queue.submit([encoder.finish()]);

        // Block without a timeout until this submission is done. NOTE(review): the mapping
        // callbacks are expected to have run by the time this returns — confirm with wgpu docs.
        self.device.poll(PollType::Wait {
            submission_index: Some(submission_index),
            timeout: None,
        })?;

        if let Some(error) = self.mapping_error.lock().take() {
            return Err(RecordEncodingError::BufferMapping(error));
        }

        let proofs = {
            let proofs_host_ptr = self
                .proofs_host
                .get_mapped_range(..)
                .as_ptr()
                .cast::<ProofsHost>();
            // SAFETY: NOTE(review) — this relies on all of the following:
            // - `proofs_host` is exactly `size_of::<ProofsHost>()` bytes (see `new()`).
            // - The mapped pointer is sufficiently aligned for `ProofsHost`, and any bit pattern
            //   is a valid `ProofsHost`.
            // - The buffer stays mapped while the reference lives. `ProofsHostWrapper` unmaps
            //   only on drop, and it borrows `self` so nothing else can unmap earlier.
            // The `BufferView` temporary was already dropped at the end of the statement above;
            // only the mapping itself keeps this memory valid.
            unsafe { &*proofs_host_ptr }
        };

        self.tainted = false;

        Ok(ProofsHostWrapper {
            proofs,
            proofs_host: &self.proofs_host,
        })
    }
}
804
805fn bind_group_and_pipeline_compute_f1(
806 device: &wgpu::Device,
807 module: &ShaderModule,
808 initial_state_gpu: &Buffer,
809 bucket_sizes_gpu: &Buffer,
810 buckets_gpu: &Buffer,
811) -> (BindGroup, ComputePipeline) {
812 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
813 label: Some("compute_f1"),
814 entries: &[
815 BindGroupLayoutEntry {
816 binding: 0,
817 count: None,
818 visibility: ShaderStages::COMPUTE,
819 ty: BindingType::Buffer {
820 has_dynamic_offset: false,
821 min_binding_size: None,
822 ty: BufferBindingType::Uniform,
823 },
824 },
825 BindGroupLayoutEntry {
826 binding: 1,
827 count: None,
828 visibility: ShaderStages::COMPUTE,
829 ty: BindingType::Buffer {
830 has_dynamic_offset: false,
831 min_binding_size: None,
832 ty: BufferBindingType::Storage { read_only: false },
833 },
834 },
835 BindGroupLayoutEntry {
836 binding: 2,
837 count: None,
838 visibility: ShaderStages::COMPUTE,
839 ty: BindingType::Buffer {
840 has_dynamic_offset: false,
841 min_binding_size: None,
842 ty: BufferBindingType::Storage { read_only: false },
843 },
844 },
845 ],
846 });
847
848 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
849 label: Some("compute_f1"),
850 bind_group_layouts: &[Some(&bind_group_layout)],
851 immediate_size: 0,
852 });
853
854 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
855 compilation_options: PipelineCompilationOptions {
856 constants: &[],
857 zero_initialize_workgroup_memory: false,
858 },
859 cache: None,
860 label: Some("compute_f1"),
861 layout: Some(&pipeline_layout),
862 module,
863 entry_point: Some("compute_f1"),
864 });
865
866 let bind_group = device.create_bind_group(&BindGroupDescriptor {
867 label: Some("compute_f1"),
868 layout: &bind_group_layout,
869 entries: &[
870 BindGroupEntry {
871 binding: 0,
872 resource: initial_state_gpu.as_entire_binding(),
873 },
874 BindGroupEntry {
875 binding: 1,
876 resource: bucket_sizes_gpu.as_entire_binding(),
877 },
878 BindGroupEntry {
879 binding: 2,
880 resource: buckets_gpu.as_entire_binding(),
881 },
882 ],
883 });
884
885 (bind_group, compute_pipeline)
886}
887
888fn bind_group_and_pipeline_sort_buckets(
889 device: &wgpu::Device,
890 module: &ShaderModule,
891 bucket_sizes_gpu: &Buffer,
892 buckets_gpu: &Buffer,
893) -> (BindGroup, ComputePipeline) {
894 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
895 label: Some("sort_buckets"),
896 entries: &[
897 BindGroupLayoutEntry {
898 binding: 0,
899 count: None,
900 visibility: ShaderStages::COMPUTE,
901 ty: BindingType::Buffer {
902 has_dynamic_offset: false,
903 min_binding_size: None,
904 ty: BufferBindingType::Storage { read_only: false },
905 },
906 },
907 BindGroupLayoutEntry {
908 binding: 1,
909 count: None,
910 visibility: ShaderStages::COMPUTE,
911 ty: BindingType::Buffer {
912 has_dynamic_offset: false,
913 min_binding_size: None,
914 ty: BufferBindingType::Storage { read_only: false },
915 },
916 },
917 ],
918 });
919
920 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
921 label: Some("sort_buckets"),
922 bind_group_layouts: &[Some(&bind_group_layout)],
923 immediate_size: 0,
924 });
925
926 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
927 compilation_options: PipelineCompilationOptions {
928 constants: &[],
929 zero_initialize_workgroup_memory: false,
930 },
931 cache: None,
932 label: Some("sort_buckets"),
933 layout: Some(&pipeline_layout),
934 module,
935 entry_point: Some("sort_buckets"),
936 });
937
938 let bind_group = device.create_bind_group(&BindGroupDescriptor {
939 label: Some("sort_buckets"),
940 layout: &bind_group_layout,
941 entries: &[
942 BindGroupEntry {
943 binding: 0,
944 resource: bucket_sizes_gpu.as_entire_binding(),
945 },
946 BindGroupEntry {
947 binding: 1,
948 resource: buckets_gpu.as_entire_binding(),
949 },
950 ],
951 });
952
953 (bind_group, compute_pipeline)
954}
955
956fn bind_group_and_pipeline_find_matches_and_compute_f2(
957 device: &wgpu::Device,
958 module: &ShaderModule,
959 parent_buckets_gpu: &Buffer,
960 bucket_sizes_gpu: &Buffer,
961 buckets_gpu: &Buffer,
962 positions_gpu: &Buffer,
963 metadatas_gpu: &Buffer,
964) -> (BindGroup, ComputePipeline) {
965 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
966 label: Some("find_matches_and_compute_f2"),
967 entries: &[
968 BindGroupLayoutEntry {
969 binding: 0,
970 count: None,
971 visibility: ShaderStages::COMPUTE,
972 ty: BindingType::Buffer {
973 has_dynamic_offset: false,
974 min_binding_size: None,
975 ty: BufferBindingType::Storage { read_only: true },
976 },
977 },
978 BindGroupLayoutEntry {
979 binding: 1,
980 count: None,
981 visibility: ShaderStages::COMPUTE,
982 ty: BindingType::Buffer {
983 has_dynamic_offset: false,
984 min_binding_size: None,
985 ty: BufferBindingType::Storage { read_only: false },
986 },
987 },
988 BindGroupLayoutEntry {
989 binding: 2,
990 count: None,
991 visibility: ShaderStages::COMPUTE,
992 ty: BindingType::Buffer {
993 has_dynamic_offset: false,
994 min_binding_size: None,
995 ty: BufferBindingType::Storage { read_only: false },
996 },
997 },
998 BindGroupLayoutEntry {
999 binding: 3,
1000 count: None,
1001 visibility: ShaderStages::COMPUTE,
1002 ty: BindingType::Buffer {
1003 has_dynamic_offset: false,
1004 min_binding_size: None,
1005 ty: BufferBindingType::Storage { read_only: false },
1006 },
1007 },
1008 BindGroupLayoutEntry {
1009 binding: 4,
1010 count: None,
1011 visibility: ShaderStages::COMPUTE,
1012 ty: BindingType::Buffer {
1013 has_dynamic_offset: false,
1014 min_binding_size: None,
1015 ty: BufferBindingType::Storage { read_only: false },
1016 },
1017 },
1018 ],
1019 });
1020
1021 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1022 label: Some("find_matches_and_compute_f2"),
1023 bind_group_layouts: &[Some(&bind_group_layout)],
1024 immediate_size: 0,
1025 });
1026
1027 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1028 compilation_options: PipelineCompilationOptions {
1029 constants: &[],
1030 zero_initialize_workgroup_memory: true,
1031 },
1032 cache: None,
1033 label: Some("find_matches_and_compute_f2"),
1034 layout: Some(&pipeline_layout),
1035 module,
1036 entry_point: Some("find_matches_and_compute_f2"),
1037 });
1038
1039 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1040 label: Some("find_matches_and_compute_f2"),
1041 layout: &bind_group_layout,
1042 entries: &[
1043 BindGroupEntry {
1044 binding: 0,
1045 resource: parent_buckets_gpu.as_entire_binding(),
1046 },
1047 BindGroupEntry {
1048 binding: 1,
1049 resource: bucket_sizes_gpu.as_entire_binding(),
1050 },
1051 BindGroupEntry {
1052 binding: 2,
1053 resource: buckets_gpu.as_entire_binding(),
1054 },
1055 BindGroupEntry {
1056 binding: 3,
1057 resource: positions_gpu.as_entire_binding(),
1058 },
1059 BindGroupEntry {
1060 binding: 4,
1061 resource: metadatas_gpu.as_entire_binding(),
1062 },
1063 ],
1064 });
1065
1066 (bind_group, compute_pipeline)
1067}
1068
1069#[expect(
1070 clippy::too_many_arguments,
1071 reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1072)]
1073fn bind_group_and_pipeline_find_matches_and_compute_fn<const TABLE_NUMBER: u8>(
1074 device: &wgpu::Device,
1075 module: &ShaderModule,
1076 parent_buckets_gpu: &Buffer,
1077 parent_metadatas_gpu: &Buffer,
1078 bucket_sizes_gpu: &Buffer,
1079 buckets_gpu: &Buffer,
1080 positions_gpu: &Buffer,
1081 metadatas_gpu: &Buffer,
1082) -> (BindGroup, ComputePipeline) {
1083 let label = format!("find_matches_and_compute_f{TABLE_NUMBER}");
1084 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1085 label: Some(&label),
1086 entries: &[
1087 BindGroupLayoutEntry {
1088 binding: 0,
1089 count: None,
1090 visibility: ShaderStages::COMPUTE,
1091 ty: BindingType::Buffer {
1092 has_dynamic_offset: false,
1093 min_binding_size: None,
1094 ty: BufferBindingType::Storage { read_only: true },
1095 },
1096 },
1097 BindGroupLayoutEntry {
1098 binding: 1,
1099 count: None,
1100 visibility: ShaderStages::COMPUTE,
1101 ty: BindingType::Buffer {
1102 has_dynamic_offset: false,
1103 min_binding_size: None,
1104 ty: BufferBindingType::Storage { read_only: true },
1105 },
1106 },
1107 BindGroupLayoutEntry {
1108 binding: 2,
1109 count: None,
1110 visibility: ShaderStages::COMPUTE,
1111 ty: BindingType::Buffer {
1112 has_dynamic_offset: false,
1113 min_binding_size: None,
1114 ty: BufferBindingType::Storage { read_only: false },
1115 },
1116 },
1117 BindGroupLayoutEntry {
1118 binding: 3,
1119 count: None,
1120 visibility: ShaderStages::COMPUTE,
1121 ty: BindingType::Buffer {
1122 has_dynamic_offset: false,
1123 min_binding_size: None,
1124 ty: BufferBindingType::Storage { read_only: false },
1125 },
1126 },
1127 BindGroupLayoutEntry {
1128 binding: 4,
1129 count: None,
1130 visibility: ShaderStages::COMPUTE,
1131 ty: BindingType::Buffer {
1132 has_dynamic_offset: false,
1133 min_binding_size: None,
1134 ty: BufferBindingType::Storage { read_only: false },
1135 },
1136 },
1137 BindGroupLayoutEntry {
1138 binding: 5,
1139 count: None,
1140 visibility: ShaderStages::COMPUTE,
1141 ty: BindingType::Buffer {
1142 has_dynamic_offset: false,
1143 min_binding_size: None,
1144 ty: BufferBindingType::Storage { read_only: false },
1145 },
1146 },
1147 ],
1148 });
1149
1150 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1151 label: Some(&label),
1152 bind_group_layouts: &[Some(&bind_group_layout)],
1153 immediate_size: 0,
1154 });
1155
1156 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1157 compilation_options: PipelineCompilationOptions {
1158 constants: &[],
1159 zero_initialize_workgroup_memory: true,
1160 },
1161 cache: None,
1162 label: Some(&label),
1163 layout: Some(&pipeline_layout),
1164 module,
1165 entry_point: Some(&format!("find_matches_and_compute_f{TABLE_NUMBER}")),
1166 });
1167
1168 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1169 label: Some(&label),
1170 layout: &bind_group_layout,
1171 entries: &[
1172 BindGroupEntry {
1173 binding: 0,
1174 resource: parent_buckets_gpu.as_entire_binding(),
1175 },
1176 BindGroupEntry {
1177 binding: 1,
1178 resource: parent_metadatas_gpu.as_entire_binding(),
1179 },
1180 BindGroupEntry {
1181 binding: 2,
1182 resource: bucket_sizes_gpu.as_entire_binding(),
1183 },
1184 BindGroupEntry {
1185 binding: 3,
1186 resource: buckets_gpu.as_entire_binding(),
1187 },
1188 BindGroupEntry {
1189 binding: 4,
1190 resource: positions_gpu.as_entire_binding(),
1191 },
1192 BindGroupEntry {
1193 binding: 5,
1194 resource: metadatas_gpu.as_entire_binding(),
1195 },
1196 ],
1197 });
1198
1199 (bind_group, compute_pipeline)
1200}
1201
1202fn bind_group_and_pipeline_find_matches_and_compute_f7(
1203 device: &wgpu::Device,
1204 module: &ShaderModule,
1205 parent_buckets_gpu: &Buffer,
1206 parent_metadatas_gpu: &Buffer,
1207 table_6_proof_targets_sizes_gpu: &Buffer,
1208 table_6_proof_targets_gpu: &Buffer,
1209) -> (BindGroup, ComputePipeline) {
1210 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1211 label: Some("find_matches_and_compute_f7"),
1212 entries: &[
1213 BindGroupLayoutEntry {
1214 binding: 0,
1215 count: None,
1216 visibility: ShaderStages::COMPUTE,
1217 ty: BindingType::Buffer {
1218 has_dynamic_offset: false,
1219 min_binding_size: None,
1220 ty: BufferBindingType::Storage { read_only: true },
1221 },
1222 },
1223 BindGroupLayoutEntry {
1224 binding: 1,
1225 count: None,
1226 visibility: ShaderStages::COMPUTE,
1227 ty: BindingType::Buffer {
1228 has_dynamic_offset: false,
1229 min_binding_size: None,
1230 ty: BufferBindingType::Storage { read_only: true },
1231 },
1232 },
1233 BindGroupLayoutEntry {
1234 binding: 2,
1235 count: None,
1236 visibility: ShaderStages::COMPUTE,
1237 ty: BindingType::Buffer {
1238 has_dynamic_offset: false,
1239 min_binding_size: None,
1240 ty: BufferBindingType::Storage { read_only: false },
1241 },
1242 },
1243 BindGroupLayoutEntry {
1244 binding: 3,
1245 count: None,
1246 visibility: ShaderStages::COMPUTE,
1247 ty: BindingType::Buffer {
1248 has_dynamic_offset: false,
1249 min_binding_size: None,
1250 ty: BufferBindingType::Storage { read_only: false },
1251 },
1252 },
1253 ],
1254 });
1255
1256 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1257 label: Some("find_matches_and_compute_f7"),
1258 bind_group_layouts: &[Some(&bind_group_layout)],
1259 immediate_size: 0,
1260 });
1261
1262 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1263 compilation_options: PipelineCompilationOptions {
1264 constants: &[],
1265 zero_initialize_workgroup_memory: true,
1266 },
1267 cache: None,
1268 label: Some("find_matches_and_compute_f7"),
1269 layout: Some(&pipeline_layout),
1270 module,
1271 entry_point: Some("find_matches_and_compute_f7"),
1272 });
1273
1274 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1275 label: Some("find_matches_and_compute_f7"),
1276 layout: &bind_group_layout,
1277 entries: &[
1278 BindGroupEntry {
1279 binding: 0,
1280 resource: parent_buckets_gpu.as_entire_binding(),
1281 },
1282 BindGroupEntry {
1283 binding: 1,
1284 resource: parent_metadatas_gpu.as_entire_binding(),
1285 },
1286 BindGroupEntry {
1287 binding: 2,
1288 resource: table_6_proof_targets_sizes_gpu.as_entire_binding(),
1289 },
1290 BindGroupEntry {
1291 binding: 3,
1292 resource: table_6_proof_targets_gpu.as_entire_binding(),
1293 },
1294 ],
1295 });
1296
1297 (bind_group, compute_pipeline)
1298}
1299
1300#[expect(
1301 clippy::too_many_arguments,
1302 reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1303)]
1304fn bind_group_and_pipeline_find_proofs(
1305 device: &wgpu::Device,
1306 module: &ShaderModule,
1307 table_2_positions_gpu: &Buffer,
1308 table_3_positions_gpu: &Buffer,
1309 table_4_positions_gpu: &Buffer,
1310 table_5_positions_gpu: &Buffer,
1311 table_6_positions_gpu: &Buffer,
1312 bucket_sizes_gpu: &Buffer,
1313 buckets_gpu: &Buffer,
1314 proofs_gpu: &Buffer,
1315) -> (BindGroup, ComputePipeline) {
1316 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1317 label: Some("find_proofs"),
1318 entries: &[
1319 BindGroupLayoutEntry {
1320 binding: 0,
1321 count: None,
1322 visibility: ShaderStages::COMPUTE,
1323 ty: BindingType::Buffer {
1324 has_dynamic_offset: false,
1325 min_binding_size: None,
1326 ty: BufferBindingType::Storage { read_only: true },
1327 },
1328 },
1329 BindGroupLayoutEntry {
1330 binding: 1,
1331 count: None,
1332 visibility: ShaderStages::COMPUTE,
1333 ty: BindingType::Buffer {
1334 has_dynamic_offset: false,
1335 min_binding_size: None,
1336 ty: BufferBindingType::Storage { read_only: true },
1337 },
1338 },
1339 BindGroupLayoutEntry {
1340 binding: 2,
1341 count: None,
1342 visibility: ShaderStages::COMPUTE,
1343 ty: BindingType::Buffer {
1344 has_dynamic_offset: false,
1345 min_binding_size: None,
1346 ty: BufferBindingType::Storage { read_only: true },
1347 },
1348 },
1349 BindGroupLayoutEntry {
1350 binding: 3,
1351 count: None,
1352 visibility: ShaderStages::COMPUTE,
1353 ty: BindingType::Buffer {
1354 has_dynamic_offset: false,
1355 min_binding_size: None,
1356 ty: BufferBindingType::Storage { read_only: true },
1357 },
1358 },
1359 BindGroupLayoutEntry {
1360 binding: 4,
1361 count: None,
1362 visibility: ShaderStages::COMPUTE,
1363 ty: BindingType::Buffer {
1364 has_dynamic_offset: false,
1365 min_binding_size: None,
1366 ty: BufferBindingType::Storage { read_only: true },
1367 },
1368 },
1369 BindGroupLayoutEntry {
1370 binding: 5,
1371 count: None,
1372 visibility: ShaderStages::COMPUTE,
1373 ty: BindingType::Buffer {
1374 has_dynamic_offset: false,
1375 min_binding_size: None,
1376 ty: BufferBindingType::Storage { read_only: false },
1377 },
1378 },
1379 BindGroupLayoutEntry {
1380 binding: 6,
1381 count: None,
1382 visibility: ShaderStages::COMPUTE,
1383 ty: BindingType::Buffer {
1384 has_dynamic_offset: false,
1385 min_binding_size: None,
1386 ty: BufferBindingType::Storage { read_only: true },
1387 },
1388 },
1389 BindGroupLayoutEntry {
1390 binding: 7,
1391 count: None,
1392 visibility: ShaderStages::COMPUTE,
1393 ty: BindingType::Buffer {
1394 has_dynamic_offset: false,
1395 min_binding_size: None,
1396 ty: BufferBindingType::Storage { read_only: false },
1397 },
1398 },
1399 ],
1400 });
1401
1402 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1403 label: Some("find_proofs"),
1404 bind_group_layouts: &[Some(&bind_group_layout)],
1405 immediate_size: 0,
1406 });
1407
1408 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1409 compilation_options: PipelineCompilationOptions {
1410 constants: &[],
1411 zero_initialize_workgroup_memory: false,
1412 },
1413 cache: None,
1414 label: Some("find_proofs"),
1415 layout: Some(&pipeline_layout),
1416 module,
1417 entry_point: Some("find_proofs"),
1418 });
1419
1420 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1421 label: Some("find_proofs"),
1422 layout: &bind_group_layout,
1423 entries: &[
1424 BindGroupEntry {
1425 binding: 0,
1426 resource: table_2_positions_gpu.as_entire_binding(),
1427 },
1428 BindGroupEntry {
1429 binding: 1,
1430 resource: table_3_positions_gpu.as_entire_binding(),
1431 },
1432 BindGroupEntry {
1433 binding: 2,
1434 resource: table_4_positions_gpu.as_entire_binding(),
1435 },
1436 BindGroupEntry {
1437 binding: 3,
1438 resource: table_5_positions_gpu.as_entire_binding(),
1439 },
1440 BindGroupEntry {
1441 binding: 4,
1442 resource: table_6_positions_gpu.as_entire_binding(),
1443 },
1444 BindGroupEntry {
1445 binding: 5,
1446 resource: bucket_sizes_gpu.as_entire_binding(),
1447 },
1448 BindGroupEntry {
1449 binding: 6,
1450 resource: buckets_gpu.as_entire_binding(),
1451 },
1452 BindGroupEntry {
1453 binding: 7,
1454 resource: proofs_gpu.as_entire_binding(),
1455 },
1456 ],
1457 });
1458
1459 (bind_group, compute_pipeline)
1460}