1#[cfg(all(test, not(miri)))]
2mod tests;
3
4use crate::shader::constants::{
5 MAX_BUCKET_SIZE, MAX_TABLE_SIZE, NUM_BUCKETS, NUM_MATCH_BUCKETS, NUM_S_BUCKETS,
6 REDUCED_MATCHES_COUNT,
7};
8use crate::shader::find_matches_and_compute_f7::{NUM_ELEMENTS_PER_S_BUCKET, ProofTargets};
9use crate::shader::find_proofs::ProofsHost;
10use crate::shader::types::{Metadata, Position, PositionR};
11use crate::shader::{compute_f1, find_proofs, select_shader_features_limits};
12use ab_chacha8::{ChaCha8Block, ChaCha8State, block_to_bytes};
13use ab_core_primitives::pieces::{PieceOffset, Record};
14use ab_core_primitives::pos::PosSeed;
15use ab_core_primitives::sectors::SectorId;
16use ab_erasure_coding::ErasureCoding;
17use ab_farmer_components::plotting::RecordsEncoder;
18use ab_farmer_components::sector::SectorContentsMap;
19use async_lock::Mutex as AsyncMutex;
20use futures::stream::FuturesOrdered;
21use futures::{StreamExt, TryStreamExt};
22use parking_lot::Mutex;
23use rayon::prelude::*;
24use rayon::{ThreadPool, ThreadPoolBuilder};
25use rclite::Arc;
26use std::fmt;
27use std::num::NonZeroU8;
28use std::simd::Simd;
29use std::sync::Arc as StdArc;
30use std::sync::atomic::{AtomicBool, Ordering};
31use tracing::{debug, warn};
32use wgpu::{
33 AdapterInfo, Backend, BackendOptions, Backends, BindGroup, BindGroupDescriptor, BindGroupEntry,
34 BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, Buffer, BufferAddress,
35 BufferAsyncError, BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor,
36 ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor, DeviceDescriptor,
37 DeviceType, Instance, InstanceDescriptor, InstanceFlags, MapMode, MemoryBudgetThresholds,
38 PipelineCompilationOptions, PipelineLayoutDescriptor, PollError, PollType, Queue,
39 RequestDeviceError, ShaderModule, ShaderStages,
40};
41
/// Errors that can occur while encoding records with [`GpuRecordsEncoder`].
#[derive(Debug, thiserror::Error)]
pub enum RecordEncodingError {
    /// The number of records passed in does not fit into `u16`
    #[error("Too many records: {0}")]
    TooManyRecords(usize),
    /// An earlier proof creation on this device instance failed part-way, so the instance refuses
    /// to do any further work
    #[error("Proof creation failed previously and the device is now considered broken")]
    DeviceBroken,
    /// Mapping a host-visible buffer (initial state or proofs) failed
    #[error("Failed to map buffer: {0}")]
    BufferMapping(#[from] BufferAsyncError),
    /// Waiting for the submitted GPU work failed
    #[error("Poll error: {0}")]
    DevicePoll(#[from] PollError),
}
58
/// Proofs read back from the GPU, borrowed directly from the mapped memory of `proofs_host`.
///
/// `proofs` points into the mapping of `proofs_host`, and that buffer is unmapped when this
/// wrapper is dropped. NOTE(review): `proofs` is a plain `&'a ProofsHost` and can be copied out
/// of the wrapper. Using such a copy after the wrapper is dropped would read unmapped memory, so
/// keep the wrapper alive for as long as the reference is used.
struct ProofsHostWrapper<'a> {
    proofs: &'a ProofsHost,
    proofs_host: &'a Buffer,
}

impl Drop for ProofsHostWrapper<'_> {
    fn drop(&mut self) {
        // Release the read mapping. The next `create_proofs` call copies into this buffer and
        // maps it again.
        self.proofs_host.unmap();
    }
}
69
/// A GPU adapter that is compatible with the proof-of-space shaders, together with the logical
/// devices created from it by [`Device::enumerate`].
#[derive(Clone)]
pub struct Device {
    /// Position of the adapter in the enumeration order (incompatible adapters also consume ids)
    id: u32,
    /// Logical devices, each with its own queue and compiled shader module
    devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
    adapter_info: AdapterInfo,
}
77
78impl fmt::Debug for Device {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 f.debug_struct("Device")
81 .field("id", &self.id)
82 .field("name", &self.adapter_info.name)
83 .field("device_type", &self.adapter_info.device_type)
84 .field("driver", &self.adapter_info.driver)
85 .field("driver_info", &self.adapter_info.driver_info)
86 .field("backend", &self.adapter_info.backend)
87 .finish_non_exhaustive()
88 }
89}
90
91impl Device {
92 pub async fn enumerate<NOQ>(number_of_queues: NOQ) -> Vec<Self>
94 where
95 NOQ: Fn(DeviceType) -> NonZeroU8,
96 {
97 let backends = Backends::from_env().unwrap_or(Backends::METAL | Backends::VULKAN);
98 let instance = Instance::new(&InstanceDescriptor {
99 backends,
100 flags: if cfg!(debug_assertions) {
101 InstanceFlags::debugging().with_env()
102 } else {
103 InstanceFlags::from_env_or_default()
104 },
105 memory_budget_thresholds: MemoryBudgetThresholds::default(),
106 backend_options: BackendOptions::from_env_or_default(),
107 });
108
109 let adapters = instance.enumerate_adapters(backends);
110 let number_of_queues = &number_of_queues;
111
112 adapters
113 .into_iter()
114 .zip(0..)
115 .map(|(adapter, id)| async move {
116 let adapter_info = adapter.get_info();
117
118 let (shader, required_features, required_limits) =
119 match select_shader_features_limits(&adapter) {
120 Some((shader, required_features, required_limits)) => {
121 debug!(
122 %id,
123 adapter_info = ?adapter_info,
124 "Compatible adapter found"
125 );
126
127 (shader, required_features, required_limits)
128 }
129 None => {
130 debug!(
131 %id,
132 adapter_info = ?adapter_info,
133 "Incompatible adapter found"
134 );
135
136 return None;
137 }
138 };
139
140 let devices = (0..number_of_queues(adapter_info.device_type).get())
143 .map(|_| async {
144 let (device, queue) = adapter
145 .request_device(&DeviceDescriptor {
146 label: None,
147 required_features,
148 required_limits: required_limits.clone(),
149 ..DeviceDescriptor::default()
150 })
151 .await
152 .inspect_err(|error| {
153 warn!(%id, ?adapter_info, %error, "Failed to request the device");
154 })?;
155 let module = device.create_shader_module(shader.clone());
156
157 Ok::<_, RequestDeviceError>((device, queue, module))
158 })
159 .collect::<FuturesOrdered<_>>()
160 .try_collect::<Vec<_>>()
161 .await
162 .ok()?;
163
164 Some(Self {
165 id,
166 devices,
167 adapter_info,
168 })
169 })
170 .collect::<FuturesOrdered<_>>()
171 .filter_map(|device| async move { device })
172 .collect()
173 .await
174 }
175
176 pub fn id(&self) -> u32 {
178 self.id
179 }
180
181 pub fn name(&self) -> &str {
183 &self.adapter_info.name
184 }
185
186 pub fn device_type(&self) -> DeviceType {
188 self.adapter_info.device_type
189 }
190
191 pub fn driver(&self) -> &str {
193 &self.adapter_info.driver
194 }
195
196 pub fn driver_info(&self) -> &str {
198 &self.adapter_info.driver_info
199 }
200
201 pub fn backend(&self) -> Backend {
203 self.adapter_info.backend
204 }
205
206 pub fn instantiate(
207 &self,
208 erasure_coding: ErasureCoding,
209 global_mutex: StdArc<AsyncMutex<()>>,
210 ) -> anyhow::Result<GpuRecordsEncoder> {
211 GpuRecordsEncoder::new(
212 self.id,
213 self.devices.clone(),
214 self.adapter_info.clone(),
215 erasure_coding,
216 global_mutex,
217 )
218 }
219}
220
/// Records encoder that computes proofs of space on the logical devices of one GPU adapter.
///
/// There is one encoder instance per logical device. The dedicated thread pool has exactly as
/// many threads, and each pool thread uses the instance at its own thread index.
pub struct GpuRecordsEncoder {
    id: u32,
    /// One instance per logical device, indexed by rayon thread index
    instances: Vec<Mutex<GpuRecordsEncoderInstance>>,
    thread_pool: ThreadPool,
    adapter_info: AdapterInfo,
    erasure_coding: ErasureCoding,
    /// Locked briefly before each record is encoded. NOTE(review): presumably lets other code
    /// pause encoding by holding it; confirm the intent with callers.
    global_mutex: StdArc<AsyncMutex<()>>,
}
229
230impl fmt::Debug for GpuRecordsEncoder {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 f.debug_struct("GpuRecordsEncoder")
233 .field("id", &self.id)
234 .field("name", &self.adapter_info.name)
235 .field("device_type", &self.adapter_info.device_type)
236 .field("driver", &self.adapter_info.driver)
237 .field("driver_info", &self.adapter_info.driver_info)
238 .field("backend", &self.adapter_info.backend)
239 .finish_non_exhaustive()
240 }
241}
242
impl RecordsEncoder for GpuRecordsEncoder {
    /// Encodes `records` in place for `sector_id` and returns the resulting sector contents map.
    ///
    /// Records are processed in parallel on the encoder's thread pool. For each record:
    /// - parity chunks are produced with erasure coding;
    /// - proofs are computed on the GPU instance bound to the current pool thread;
    /// - chunks whose s-bucket has a proof are XOR-ed with the proof hash and compacted to the
    ///   front of the record, up to `Record::NUM_CHUNKS` chunks.
    ///
    /// The bitmap of used chunks is stored in the sector contents map.
    ///
    /// If `abort_early` is set, records that have not passed the abort check are left unmodified.
    /// The map is still returned as `Ok`.
    ///
    /// # Errors
    /// - [`RecordEncodingError::TooManyRecords`] if the number of records does not fit into `u16`
    /// - the first proof creation error observed by any thread
    fn encode_records(
        &mut self,
        sector_id: &SectorId,
        records: &mut [Record],
        abort_early: &AtomicBool,
    ) -> anyhow::Result<SectorContentsMap> {
        let mut sector_contents_map = SectorContentsMap::new(
            u16::try_from(records.len())
                .map_err(|_| RecordEncodingError::TooManyRecords(records.len()))?,
        );

        let maybe_error = self.thread_pool.install(|| {
            records
                .par_iter_mut()
                .zip(sector_contents_map.iter_record_chunks_used_mut())
                .enumerate()
                .find_map_any(|(piece_offset, (record, record_chunks_used))| {
                    // The guard is dropped at the end of this statement. This only waits until
                    // the mutex is free; it does not hold the mutex while encoding.
                    self.global_mutex.lock_blocking();

                    let mut parity_record_chunks = Record::new_boxed();

                    self.erasure_coding
                        .extend(record.iter(), parity_record_chunks.iter_mut())
                        .expect("Statically guaranteed valid inputs; qed");

                    if abort_early.load(Ordering::Relaxed) {
                        return None;
                    }
                    // Cannot truncate: `records.len()` was checked to fit into `u16` above
                    let seed =
                        sector_id.derive_evaluation_seed(PieceOffset::from(piece_offset as u16));
                    // The pool has exactly one thread per instance, so this instance is never
                    // contended
                    let thread_index = rayon::current_thread_index().unwrap_or_default();
                    let mut encoder_instance = self.instances[thread_index]
                        .try_lock()
                        .expect("1:1 mapping between threads and devices; qed");
                    let proofs = match encoder_instance.create_proofs(&seed) {
                        Ok(proofs) => proofs,
                        Err(error) => {
                            return Some(error);
                        }
                    };
                    // The wrapper is only shadowed here, not dropped. It lives until the end of
                    // this closure, which keeps `proofs_host` mapped while the reference is used.
                    let proofs = proofs.proofs;

                    // Start from the GPU bitmap of s-buckets that have proofs; it is trimmed
                    // below once enough chunks have been encoded
                    *record_chunks_used = proofs.found_proofs;

                    let mut num_found_proofs = 0_usize;
                    // Each bitmap byte covers 8 consecutive s-buckets, bit `i` for the `i`-th one
                    for (s_buckets, found_proofs) in (0..Record::NUM_S_BUCKETS)
                        .array_chunks::<{ u8::BITS as usize }>()
                        .zip(record_chunks_used)
                    {
                        for (proof_offset, s_bucket) in s_buckets.into_iter().enumerate() {
                            if num_found_proofs == Record::NUM_CHUNKS {
                                // Enough chunks: keep only bits below `proof_offset`. For later
                                // bytes `proof_offset` is 0, and the 8-bit shift clears them.
                                *found_proofs &=
                                    u8::MAX.unbounded_shr(u8::BITS - proof_offset as u32);
                                break;
                            }
                            if (*found_proofs & (1 << proof_offset)) != 0 {
                                // The first `Record::NUM_CHUNKS` s-buckets are source chunks,
                                // the rest are parity chunks
                                let record_chunk = if s_bucket < Record::NUM_CHUNKS {
                                    record[s_bucket]
                                } else {
                                    parity_record_chunks[s_bucket - Record::NUM_CHUNKS]
                                };

                                // Compacting in place is safe: `num_found_proofs <= s_bucket`,
                                // so unread source chunks are never overwritten
                                record[num_found_proofs] = (Simd::from(record_chunk)
                                    ^ Simd::from(*proofs.proofs[s_bucket].hash()))
                                .to_array();
                                num_found_proofs += 1;
                            }
                        }
                    }

                    None
                })
        });

        if let Some(error) = maybe_error {
            return Err(error.into());
        }

        Ok(sector_contents_map)
    }
}
331
332impl GpuRecordsEncoder {
333 fn new(
334 id: u32,
335 devices: Vec<(wgpu::Device, Queue, ShaderModule)>,
336 adapter_info: AdapterInfo,
337 erasure_coding: ErasureCoding,
338 global_mutex: StdArc<AsyncMutex<()>>,
339 ) -> anyhow::Result<Self> {
340 let thread_pool = ThreadPoolBuilder::new()
341 .thread_name(move |thread_index| format!("pos-gpu-{id}.{thread_index}"))
342 .num_threads(devices.len())
343 .build()?;
344
345 Ok(Self {
346 id,
347 instances: devices
348 .into_iter()
349 .map(|(device, queue, module)| {
350 Mutex::new(GpuRecordsEncoderInstance::new(device, queue, module))
351 })
352 .collect(),
353 thread_pool,
354 adapter_info,
355 erasure_coding,
356 global_mutex,
357 })
358 }
359}
360
/// Per-logical-device state for computing proofs.
///
/// Holds the GPU buffers and, for every stage of the proof pipeline, a bind group and compute
/// pipeline.
struct GpuRecordsEncoderInstance {
    device: wgpu::Device,
    queue: Queue,
    /// Error from a buffer-mapping callback registered in `create_proofs`. A later failure
    /// overwrites an earlier one.
    mapping_error: Arc<Mutex<Option<BufferAsyncError>>>,
    /// Set at the start of `create_proofs` and cleared only on success. While set, every call
    /// returns `RecordEncodingError::DeviceBroken`.
    tainted: bool,
    /// Host-writable staging buffer for the ChaCha8 initial state
    initial_state_host: Buffer,
    /// Uniform buffer the initial state is copied into before `compute_f1` runs
    initial_state_gpu: Buffer,
    /// Host-readable copy of `proofs_gpu`
    proofs_host: Buffer,
    /// Storage buffer the `find_proofs` stage writes into
    proofs_gpu: Buffer,
    bind_group_compute_f1: BindGroup,
    compute_pipeline_compute_f1: ComputePipeline,
    /// Sorts the `a` bucket set
    bind_group_sort_buckets_a: BindGroup,
    compute_pipeline_sort_buckets_a: ComputePipeline,
    /// Sorts the `b` bucket set
    bind_group_sort_buckets_b: BindGroup,
    compute_pipeline_sort_buckets_b: ComputePipeline,
    bind_group_find_matches_and_compute_f2: BindGroup,
    compute_pipeline_find_matches_and_compute_f2: ComputePipeline,
    bind_group_find_matches_and_compute_f3: BindGroup,
    compute_pipeline_find_matches_and_compute_f3: ComputePipeline,
    bind_group_find_matches_and_compute_f4: BindGroup,
    compute_pipeline_find_matches_and_compute_f4: ComputePipeline,
    bind_group_find_matches_and_compute_f5: BindGroup,
    compute_pipeline_find_matches_and_compute_f5: ComputePipeline,
    bind_group_find_matches_and_compute_f6: BindGroup,
    compute_pipeline_find_matches_and_compute_f6: ComputePipeline,
    bind_group_find_matches_and_compute_f7: BindGroup,
    compute_pipeline_find_matches_and_compute_f7: ComputePipeline,
    bind_group_find_proofs: BindGroup,
    compute_pipeline_find_proofs: ComputePipeline,
}
391
392impl GpuRecordsEncoderInstance {
393 fn new(device: wgpu::Device, queue: Queue, module: ShaderModule) -> Self {
394 let initial_state_host = device.create_buffer(&BufferDescriptor {
395 label: Some("initial_state_host"),
396 size: size_of::<ChaCha8Block>() as BufferAddress,
397 usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
398 mapped_at_creation: true,
399 });
400
401 let initial_state_gpu = device.create_buffer(&BufferDescriptor {
402 label: Some("initial_state_gpu"),
403 size: initial_state_host.size(),
404 usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
405 mapped_at_creation: false,
406 });
407
408 let bucket_sizes_gpu_buffer_size = size_of::<[u32; NUM_BUCKETS]>() as BufferAddress;
409 let table_6_proof_targets_sizes_gpu_buffer_size =
410 size_of::<[u32; NUM_S_BUCKETS]>() as BufferAddress;
411 let bucket_sizes_gpu = device.create_buffer(&BufferDescriptor {
415 label: Some("bucket_sizes_gpu"),
416 size: bucket_sizes_gpu_buffer_size.max(table_6_proof_targets_sizes_gpu_buffer_size),
417 usage: BufferUsages::STORAGE,
418 mapped_at_creation: false,
419 });
420 let table_6_proof_targets_sizes_gpu = bucket_sizes_gpu.clone();
422
423 let buckets_a_gpu = device.create_buffer(&BufferDescriptor {
424 label: Some("buckets_a_gpu"),
425 size: size_of::<[[PositionR; MAX_BUCKET_SIZE]; NUM_BUCKETS]>() as BufferAddress,
426 usage: BufferUsages::STORAGE,
427 mapped_at_creation: false,
428 });
429
430 let buckets_b_gpu = device.create_buffer(&BufferDescriptor {
431 label: Some("buckets_b_gpu"),
432 size: buckets_a_gpu.size(),
433 usage: BufferUsages::STORAGE,
434 mapped_at_creation: false,
435 });
436
437 let positions_f2_gpu = device.create_buffer(&BufferDescriptor {
438 label: Some("positions_f2_gpu"),
439 size: size_of::<[[[Position; 2]; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>()
440 as BufferAddress,
441 usage: BufferUsages::STORAGE,
442 mapped_at_creation: false,
443 });
444
445 let positions_f3_gpu = device.create_buffer(&BufferDescriptor {
446 label: Some("positions_f3_gpu"),
447 size: positions_f2_gpu.size(),
448 usage: BufferUsages::STORAGE,
449 mapped_at_creation: false,
450 });
451
452 let positions_f4_gpu = device.create_buffer(&BufferDescriptor {
453 label: Some("positions_f4_gpu"),
454 size: positions_f2_gpu.size(),
455 usage: BufferUsages::STORAGE,
456 mapped_at_creation: false,
457 });
458
459 let positions_f5_gpu = device.create_buffer(&BufferDescriptor {
460 label: Some("positions_f5_gpu"),
461 size: positions_f2_gpu.size(),
462 usage: BufferUsages::STORAGE,
463 mapped_at_creation: false,
464 });
465
466 let positions_f6_gpu = device.create_buffer(&BufferDescriptor {
467 label: Some("positions_f6_gpu"),
468 size: positions_f2_gpu.size(),
469 usage: BufferUsages::STORAGE,
470 mapped_at_creation: false,
471 });
472
473 let metadatas_gpu_buffer_size =
474 size_of::<[[Metadata; REDUCED_MATCHES_COUNT]; NUM_MATCH_BUCKETS]>() as BufferAddress;
475 let table_6_proof_targets_gpu_buffer_size = size_of::<
476 [[ProofTargets; NUM_ELEMENTS_PER_S_BUCKET]; NUM_S_BUCKETS],
477 >() as BufferAddress;
478 let metadatas_a_gpu = device.create_buffer(&BufferDescriptor {
479 label: Some("metadatas_a_gpu"),
480 size: metadatas_gpu_buffer_size.max(table_6_proof_targets_gpu_buffer_size),
481 usage: BufferUsages::STORAGE,
482 mapped_at_creation: false,
483 });
484 let table_6_proof_targets_gpu = metadatas_a_gpu.clone();
486
487 let metadatas_b_gpu = device.create_buffer(&BufferDescriptor {
488 label: Some("metadatas_b_gpu"),
489 size: metadatas_gpu_buffer_size,
490 usage: BufferUsages::STORAGE,
491 mapped_at_creation: false,
492 });
493
494 let proofs_host = device.create_buffer(&BufferDescriptor {
495 label: Some("proofs_host"),
496 size: size_of::<ProofsHost>() as BufferAddress,
497 usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
498 mapped_at_creation: false,
499 });
500
501 let proofs_gpu = device.create_buffer(&BufferDescriptor {
502 label: Some("proofs_gpu"),
503 size: proofs_host.size(),
504 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
505 mapped_at_creation: false,
506 });
507
508 let (bind_group_compute_f1, compute_pipeline_compute_f1) =
509 bind_group_and_pipeline_compute_f1(
510 &device,
511 &module,
512 &initial_state_gpu,
513 &bucket_sizes_gpu,
514 &buckets_a_gpu,
515 );
516 let (bind_group_sort_buckets_a, compute_pipeline_sort_buckets_a) =
517 bind_group_and_pipeline_sort_buckets(
518 &device,
519 &module,
520 &bucket_sizes_gpu,
521 &buckets_a_gpu,
522 );
523
524 let (bind_group_sort_buckets_b, compute_pipeline_sort_buckets_b) =
525 bind_group_and_pipeline_sort_buckets(
526 &device,
527 &module,
528 &bucket_sizes_gpu,
529 &buckets_b_gpu,
530 );
531
532 let (bind_group_find_matches_and_compute_f2, compute_pipeline_find_matches_and_compute_f2) =
533 bind_group_and_pipeline_find_matches_and_compute_f2(
534 &device,
535 &module,
536 &buckets_a_gpu,
537 &bucket_sizes_gpu,
538 &buckets_b_gpu,
539 &positions_f2_gpu,
540 &metadatas_b_gpu,
541 );
542
543 let (bind_group_find_matches_and_compute_f3, compute_pipeline_find_matches_and_compute_f3) =
544 bind_group_and_pipeline_find_matches_and_compute_fn::<3>(
545 &device,
546 &module,
547 &buckets_b_gpu,
548 &metadatas_b_gpu,
549 &bucket_sizes_gpu,
550 &buckets_a_gpu,
551 &positions_f3_gpu,
552 &metadatas_a_gpu,
553 );
554
555 let (bind_group_find_matches_and_compute_f4, compute_pipeline_find_matches_and_compute_f4) =
556 bind_group_and_pipeline_find_matches_and_compute_fn::<4>(
557 &device,
558 &module,
559 &buckets_a_gpu,
560 &metadatas_a_gpu,
561 &bucket_sizes_gpu,
562 &buckets_b_gpu,
563 &positions_f4_gpu,
564 &metadatas_b_gpu,
565 );
566
567 let (bind_group_find_matches_and_compute_f5, compute_pipeline_find_matches_and_compute_f5) =
568 bind_group_and_pipeline_find_matches_and_compute_fn::<5>(
569 &device,
570 &module,
571 &buckets_b_gpu,
572 &metadatas_b_gpu,
573 &bucket_sizes_gpu,
574 &buckets_a_gpu,
575 &positions_f5_gpu,
576 &metadatas_a_gpu,
577 );
578
579 let (bind_group_find_matches_and_compute_f6, compute_pipeline_find_matches_and_compute_f6) =
580 bind_group_and_pipeline_find_matches_and_compute_fn::<6>(
581 &device,
582 &module,
583 &buckets_a_gpu,
584 &metadatas_a_gpu,
585 &bucket_sizes_gpu,
586 &buckets_b_gpu,
587 &positions_f6_gpu,
588 &metadatas_b_gpu,
589 );
590
591 let (bind_group_find_matches_and_compute_f7, compute_pipeline_find_matches_and_compute_f7) =
592 bind_group_and_pipeline_find_matches_and_compute_f7(
593 &device,
594 &module,
595 &buckets_b_gpu,
596 &metadatas_b_gpu,
597 &table_6_proof_targets_sizes_gpu,
598 &table_6_proof_targets_gpu,
599 );
600
601 let (bind_group_find_proofs, compute_pipeline_find_proofs) =
602 bind_group_and_pipeline_find_proofs(
603 &device,
604 &module,
605 &positions_f2_gpu,
606 &positions_f3_gpu,
607 &positions_f4_gpu,
608 &positions_f5_gpu,
609 &positions_f6_gpu,
610 &table_6_proof_targets_sizes_gpu,
611 &table_6_proof_targets_gpu,
612 &proofs_gpu,
613 );
614
615 Self {
616 device,
617 queue,
618 mapping_error: Arc::new(Mutex::new(None)),
619 tainted: false,
620 initial_state_host,
621 initial_state_gpu,
622 proofs_host,
623 proofs_gpu,
624 bind_group_compute_f1,
625 compute_pipeline_compute_f1,
626 bind_group_sort_buckets_a,
627 compute_pipeline_sort_buckets_a,
628 bind_group_sort_buckets_b,
629 compute_pipeline_sort_buckets_b,
630 bind_group_find_matches_and_compute_f2,
631 compute_pipeline_find_matches_and_compute_f2,
632 bind_group_find_matches_and_compute_f3,
633 compute_pipeline_find_matches_and_compute_f3,
634 bind_group_find_matches_and_compute_f4,
635 compute_pipeline_find_matches_and_compute_f4,
636 bind_group_find_matches_and_compute_f5,
637 compute_pipeline_find_matches_and_compute_f5,
638 bind_group_find_matches_and_compute_f6,
639 compute_pipeline_find_matches_and_compute_f6,
640 bind_group_find_matches_and_compute_f7,
641 compute_pipeline_find_matches_and_compute_f7,
642 bind_group_find_proofs,
643 compute_pipeline_find_proofs,
644 }
645 }
646
647 fn create_proofs(
648 &mut self,
649 seed: &PosSeed,
650 ) -> Result<ProofsHostWrapper<'_>, RecordEncodingError> {
651 if self.tainted {
652 return Err(RecordEncodingError::DeviceBroken);
653 }
654 self.tainted = true;
655
656 let mut encoder = self
657 .device
658 .create_command_encoder(&CommandEncoderDescriptor {
659 label: Some("create_proofs"),
660 });
661
662 self.initial_state_host
664 .get_mapped_range_mut(..)
665 .copy_from_slice(&block_to_bytes(
666 &ChaCha8State::init(seed, &[0; _]).to_repr(),
667 ));
668 self.initial_state_host.unmap();
669
670 encoder.copy_buffer_to_buffer(
671 &self.initial_state_host,
672 0,
673 &self.initial_state_gpu,
674 0,
675 self.initial_state_host.size(),
676 );
677
678 {
679 let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
680 label: Some("create_proofs"),
681 timestamp_writes: None,
682 });
683
684 cpass.set_bind_group(0, &self.bind_group_compute_f1, &[]);
685 cpass.set_pipeline(&self.compute_pipeline_compute_f1);
686 cpass.dispatch_workgroups(
687 MAX_TABLE_SIZE
688 .div_ceil(compute_f1::WORKGROUP_SIZE * compute_f1::ELEMENTS_PER_INVOCATION),
689 1,
690 1,
691 );
692
693 cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
694 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
695 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
696
697 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f2, &[]);
698 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f2);
699 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
700
701 cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
702 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
703 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
704
705 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f3, &[]);
706 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f3);
707 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
708
709 cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
710 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
711 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
712
713 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f4, &[]);
714 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f4);
715 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
716
717 cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
718 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
719 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
720
721 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f5, &[]);
722 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f5);
723 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
724
725 cpass.set_bind_group(0, &self.bind_group_sort_buckets_a, &[]);
726 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_a);
727 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
728
729 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f6, &[]);
730 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f6);
731 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
732
733 cpass.set_bind_group(0, &self.bind_group_sort_buckets_b, &[]);
734 cpass.set_pipeline(&self.compute_pipeline_sort_buckets_b);
735 cpass.dispatch_workgroups(NUM_BUCKETS as u32, 1, 1);
736
737 cpass.set_bind_group(0, &self.bind_group_find_matches_and_compute_f7, &[]);
738 cpass.set_pipeline(&self.compute_pipeline_find_matches_and_compute_f7);
739 cpass.dispatch_workgroups(NUM_MATCH_BUCKETS as u32, 1, 1);
740
741 cpass.set_bind_group(0, &self.bind_group_find_proofs, &[]);
742 cpass.set_pipeline(&self.compute_pipeline_find_proofs);
743 cpass.dispatch_workgroups(NUM_S_BUCKETS as u32 / find_proofs::WORKGROUP_SIZE, 1, 1);
744 }
745
746 encoder.copy_buffer_to_buffer(
747 &self.proofs_gpu,
748 0,
749 &self.proofs_host,
750 0,
751 self.proofs_host.size(),
752 );
753
754 encoder.map_buffer_on_submit(&self.initial_state_host, MapMode::Write, .., {
756 let mapping_error = Arc::clone(&self.mapping_error);
757
758 move |r| {
759 if let Err(error) = r {
760 mapping_error.lock().replace(error);
761 }
762 }
763 });
764 encoder.map_buffer_on_submit(&self.proofs_host, MapMode::Read, .., {
765 let mapping_error = Arc::clone(&self.mapping_error);
766
767 move |r| {
768 if let Err(error) = r {
769 mapping_error.lock().replace(error);
770 }
771 }
772 });
773
774 let submission_index = self.queue.submit([encoder.finish()]);
775
776 self.device.poll(PollType::Wait {
777 submission_index: Some(submission_index),
778 timeout: None,
779 })?;
780
781 if let Some(error) = self.mapping_error.lock().take() {
782 return Err(RecordEncodingError::BufferMapping(error));
783 }
784
785 let proofs = {
786 let proofs_host_ptr = self
787 .proofs_host
788 .get_mapped_range(..)
789 .as_ptr()
790 .cast::<ProofsHost>();
791 unsafe { &*proofs_host_ptr }
793 };
794
795 self.tainted = false;
796
797 Ok(ProofsHostWrapper {
798 proofs,
799 proofs_host: &self.proofs_host,
800 })
801 }
802}
803
804fn bind_group_and_pipeline_compute_f1(
805 device: &wgpu::Device,
806 module: &ShaderModule,
807 initial_state_gpu: &Buffer,
808 bucket_sizes_gpu: &Buffer,
809 buckets_gpu: &Buffer,
810) -> (BindGroup, ComputePipeline) {
811 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
812 label: Some("compute_f1"),
813 entries: &[
814 BindGroupLayoutEntry {
815 binding: 0,
816 count: None,
817 visibility: ShaderStages::COMPUTE,
818 ty: BindingType::Buffer {
819 has_dynamic_offset: false,
820 min_binding_size: None,
821 ty: BufferBindingType::Uniform,
822 },
823 },
824 BindGroupLayoutEntry {
825 binding: 1,
826 count: None,
827 visibility: ShaderStages::COMPUTE,
828 ty: BindingType::Buffer {
829 has_dynamic_offset: false,
830 min_binding_size: None,
831 ty: BufferBindingType::Storage { read_only: false },
832 },
833 },
834 BindGroupLayoutEntry {
835 binding: 2,
836 count: None,
837 visibility: ShaderStages::COMPUTE,
838 ty: BindingType::Buffer {
839 has_dynamic_offset: false,
840 min_binding_size: None,
841 ty: BufferBindingType::Storage { read_only: false },
842 },
843 },
844 ],
845 });
846
847 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
848 label: Some("compute_f1"),
849 bind_group_layouts: &[&bind_group_layout],
850 push_constant_ranges: &[],
851 });
852
853 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
854 compilation_options: PipelineCompilationOptions {
855 constants: &[],
856 zero_initialize_workgroup_memory: false,
857 },
858 cache: None,
859 label: Some("compute_f1"),
860 layout: Some(&pipeline_layout),
861 module,
862 entry_point: Some("compute_f1"),
863 });
864
865 let bind_group = device.create_bind_group(&BindGroupDescriptor {
866 label: Some("compute_f1"),
867 layout: &bind_group_layout,
868 entries: &[
869 BindGroupEntry {
870 binding: 0,
871 resource: initial_state_gpu.as_entire_binding(),
872 },
873 BindGroupEntry {
874 binding: 1,
875 resource: bucket_sizes_gpu.as_entire_binding(),
876 },
877 BindGroupEntry {
878 binding: 2,
879 resource: buckets_gpu.as_entire_binding(),
880 },
881 ],
882 });
883
884 (bind_group, compute_pipeline)
885}
886
887fn bind_group_and_pipeline_sort_buckets(
888 device: &wgpu::Device,
889 module: &ShaderModule,
890 bucket_sizes_gpu: &Buffer,
891 buckets_gpu: &Buffer,
892) -> (BindGroup, ComputePipeline) {
893 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
894 label: Some("sort_buckets"),
895 entries: &[
896 BindGroupLayoutEntry {
897 binding: 0,
898 count: None,
899 visibility: ShaderStages::COMPUTE,
900 ty: BindingType::Buffer {
901 has_dynamic_offset: false,
902 min_binding_size: None,
903 ty: BufferBindingType::Storage { read_only: false },
904 },
905 },
906 BindGroupLayoutEntry {
907 binding: 1,
908 count: None,
909 visibility: ShaderStages::COMPUTE,
910 ty: BindingType::Buffer {
911 has_dynamic_offset: false,
912 min_binding_size: None,
913 ty: BufferBindingType::Storage { read_only: false },
914 },
915 },
916 ],
917 });
918
919 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
920 label: Some("sort_buckets"),
921 bind_group_layouts: &[&bind_group_layout],
922 push_constant_ranges: &[],
923 });
924
925 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
926 compilation_options: PipelineCompilationOptions {
927 constants: &[],
928 zero_initialize_workgroup_memory: false,
929 },
930 cache: None,
931 label: Some("sort_buckets"),
932 layout: Some(&pipeline_layout),
933 module,
934 entry_point: Some("sort_buckets"),
935 });
936
937 let bind_group = device.create_bind_group(&BindGroupDescriptor {
938 label: Some("sort_buckets"),
939 layout: &bind_group_layout,
940 entries: &[
941 BindGroupEntry {
942 binding: 0,
943 resource: bucket_sizes_gpu.as_entire_binding(),
944 },
945 BindGroupEntry {
946 binding: 1,
947 resource: buckets_gpu.as_entire_binding(),
948 },
949 ],
950 });
951
952 (bind_group, compute_pipeline)
953}
954
955fn bind_group_and_pipeline_find_matches_and_compute_f2(
956 device: &wgpu::Device,
957 module: &ShaderModule,
958 parent_buckets_gpu: &Buffer,
959 bucket_sizes_gpu: &Buffer,
960 buckets_gpu: &Buffer,
961 positions_gpu: &Buffer,
962 metadatas_gpu: &Buffer,
963) -> (BindGroup, ComputePipeline) {
964 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
965 label: Some("find_matches_and_compute_f2"),
966 entries: &[
967 BindGroupLayoutEntry {
968 binding: 0,
969 count: None,
970 visibility: ShaderStages::COMPUTE,
971 ty: BindingType::Buffer {
972 has_dynamic_offset: false,
973 min_binding_size: None,
974 ty: BufferBindingType::Storage { read_only: true },
975 },
976 },
977 BindGroupLayoutEntry {
978 binding: 1,
979 count: None,
980 visibility: ShaderStages::COMPUTE,
981 ty: BindingType::Buffer {
982 has_dynamic_offset: false,
983 min_binding_size: None,
984 ty: BufferBindingType::Storage { read_only: false },
985 },
986 },
987 BindGroupLayoutEntry {
988 binding: 2,
989 count: None,
990 visibility: ShaderStages::COMPUTE,
991 ty: BindingType::Buffer {
992 has_dynamic_offset: false,
993 min_binding_size: None,
994 ty: BufferBindingType::Storage { read_only: false },
995 },
996 },
997 BindGroupLayoutEntry {
998 binding: 3,
999 count: None,
1000 visibility: ShaderStages::COMPUTE,
1001 ty: BindingType::Buffer {
1002 has_dynamic_offset: false,
1003 min_binding_size: None,
1004 ty: BufferBindingType::Storage { read_only: false },
1005 },
1006 },
1007 BindGroupLayoutEntry {
1008 binding: 4,
1009 count: None,
1010 visibility: ShaderStages::COMPUTE,
1011 ty: BindingType::Buffer {
1012 has_dynamic_offset: false,
1013 min_binding_size: None,
1014 ty: BufferBindingType::Storage { read_only: false },
1015 },
1016 },
1017 ],
1018 });
1019
1020 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1021 label: Some("find_matches_and_compute_f2"),
1022 bind_group_layouts: &[&bind_group_layout],
1023 push_constant_ranges: &[],
1024 });
1025
1026 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1027 compilation_options: PipelineCompilationOptions {
1028 constants: &[],
1029 zero_initialize_workgroup_memory: true,
1030 },
1031 cache: None,
1032 label: Some("find_matches_and_compute_f2"),
1033 layout: Some(&pipeline_layout),
1034 module,
1035 entry_point: Some("find_matches_and_compute_f2"),
1036 });
1037
1038 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1039 label: Some("find_matches_and_compute_f2"),
1040 layout: &bind_group_layout,
1041 entries: &[
1042 BindGroupEntry {
1043 binding: 0,
1044 resource: parent_buckets_gpu.as_entire_binding(),
1045 },
1046 BindGroupEntry {
1047 binding: 1,
1048 resource: bucket_sizes_gpu.as_entire_binding(),
1049 },
1050 BindGroupEntry {
1051 binding: 2,
1052 resource: buckets_gpu.as_entire_binding(),
1053 },
1054 BindGroupEntry {
1055 binding: 3,
1056 resource: positions_gpu.as_entire_binding(),
1057 },
1058 BindGroupEntry {
1059 binding: 4,
1060 resource: metadatas_gpu.as_entire_binding(),
1061 },
1062 ],
1063 });
1064
1065 (bind_group, compute_pipeline)
1066}
1067
1068#[expect(
1069 clippy::too_many_arguments,
1070 reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1071)]
1072fn bind_group_and_pipeline_find_matches_and_compute_fn<const TABLE_NUMBER: u8>(
1073 device: &wgpu::Device,
1074 module: &ShaderModule,
1075 parent_buckets_gpu: &Buffer,
1076 parent_metadatas_gpu: &Buffer,
1077 bucket_sizes_gpu: &Buffer,
1078 buckets_gpu: &Buffer,
1079 positions_gpu: &Buffer,
1080 metadatas_gpu: &Buffer,
1081) -> (BindGroup, ComputePipeline) {
1082 let label = format!("find_matches_and_compute_f{TABLE_NUMBER}");
1083 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1084 label: Some(&label),
1085 entries: &[
1086 BindGroupLayoutEntry {
1087 binding: 0,
1088 count: None,
1089 visibility: ShaderStages::COMPUTE,
1090 ty: BindingType::Buffer {
1091 has_dynamic_offset: false,
1092 min_binding_size: None,
1093 ty: BufferBindingType::Storage { read_only: true },
1094 },
1095 },
1096 BindGroupLayoutEntry {
1097 binding: 1,
1098 count: None,
1099 visibility: ShaderStages::COMPUTE,
1100 ty: BindingType::Buffer {
1101 has_dynamic_offset: false,
1102 min_binding_size: None,
1103 ty: BufferBindingType::Storage { read_only: true },
1104 },
1105 },
1106 BindGroupLayoutEntry {
1107 binding: 2,
1108 count: None,
1109 visibility: ShaderStages::COMPUTE,
1110 ty: BindingType::Buffer {
1111 has_dynamic_offset: false,
1112 min_binding_size: None,
1113 ty: BufferBindingType::Storage { read_only: false },
1114 },
1115 },
1116 BindGroupLayoutEntry {
1117 binding: 3,
1118 count: None,
1119 visibility: ShaderStages::COMPUTE,
1120 ty: BindingType::Buffer {
1121 has_dynamic_offset: false,
1122 min_binding_size: None,
1123 ty: BufferBindingType::Storage { read_only: false },
1124 },
1125 },
1126 BindGroupLayoutEntry {
1127 binding: 4,
1128 count: None,
1129 visibility: ShaderStages::COMPUTE,
1130 ty: BindingType::Buffer {
1131 has_dynamic_offset: false,
1132 min_binding_size: None,
1133 ty: BufferBindingType::Storage { read_only: false },
1134 },
1135 },
1136 BindGroupLayoutEntry {
1137 binding: 5,
1138 count: None,
1139 visibility: ShaderStages::COMPUTE,
1140 ty: BindingType::Buffer {
1141 has_dynamic_offset: false,
1142 min_binding_size: None,
1143 ty: BufferBindingType::Storage { read_only: false },
1144 },
1145 },
1146 ],
1147 });
1148
1149 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1150 label: Some(&label),
1151 bind_group_layouts: &[&bind_group_layout],
1152 push_constant_ranges: &[],
1153 });
1154
1155 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1156 compilation_options: PipelineCompilationOptions {
1157 constants: &[],
1158 zero_initialize_workgroup_memory: true,
1159 },
1160 cache: None,
1161 label: Some(&label),
1162 layout: Some(&pipeline_layout),
1163 module,
1164 entry_point: Some(&format!("find_matches_and_compute_f{TABLE_NUMBER}")),
1165 });
1166
1167 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1168 label: Some(&label),
1169 layout: &bind_group_layout,
1170 entries: &[
1171 BindGroupEntry {
1172 binding: 0,
1173 resource: parent_buckets_gpu.as_entire_binding(),
1174 },
1175 BindGroupEntry {
1176 binding: 1,
1177 resource: parent_metadatas_gpu.as_entire_binding(),
1178 },
1179 BindGroupEntry {
1180 binding: 2,
1181 resource: bucket_sizes_gpu.as_entire_binding(),
1182 },
1183 BindGroupEntry {
1184 binding: 3,
1185 resource: buckets_gpu.as_entire_binding(),
1186 },
1187 BindGroupEntry {
1188 binding: 4,
1189 resource: positions_gpu.as_entire_binding(),
1190 },
1191 BindGroupEntry {
1192 binding: 5,
1193 resource: metadatas_gpu.as_entire_binding(),
1194 },
1195 ],
1196 });
1197
1198 (bind_group, compute_pipeline)
1199}
1200
1201fn bind_group_and_pipeline_find_matches_and_compute_f7(
1202 device: &wgpu::Device,
1203 module: &ShaderModule,
1204 parent_buckets_gpu: &Buffer,
1205 parent_metadatas_gpu: &Buffer,
1206 table_6_proof_targets_sizes_gpu: &Buffer,
1207 table_6_proof_targets_gpu: &Buffer,
1208) -> (BindGroup, ComputePipeline) {
1209 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1210 label: Some("find_matches_and_compute_f7"),
1211 entries: &[
1212 BindGroupLayoutEntry {
1213 binding: 0,
1214 count: None,
1215 visibility: ShaderStages::COMPUTE,
1216 ty: BindingType::Buffer {
1217 has_dynamic_offset: false,
1218 min_binding_size: None,
1219 ty: BufferBindingType::Storage { read_only: true },
1220 },
1221 },
1222 BindGroupLayoutEntry {
1223 binding: 1,
1224 count: None,
1225 visibility: ShaderStages::COMPUTE,
1226 ty: BindingType::Buffer {
1227 has_dynamic_offset: false,
1228 min_binding_size: None,
1229 ty: BufferBindingType::Storage { read_only: true },
1230 },
1231 },
1232 BindGroupLayoutEntry {
1233 binding: 2,
1234 count: None,
1235 visibility: ShaderStages::COMPUTE,
1236 ty: BindingType::Buffer {
1237 has_dynamic_offset: false,
1238 min_binding_size: None,
1239 ty: BufferBindingType::Storage { read_only: false },
1240 },
1241 },
1242 BindGroupLayoutEntry {
1243 binding: 3,
1244 count: None,
1245 visibility: ShaderStages::COMPUTE,
1246 ty: BindingType::Buffer {
1247 has_dynamic_offset: false,
1248 min_binding_size: None,
1249 ty: BufferBindingType::Storage { read_only: false },
1250 },
1251 },
1252 ],
1253 });
1254
1255 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1256 label: Some("find_matches_and_compute_f7"),
1257 bind_group_layouts: &[&bind_group_layout],
1258 push_constant_ranges: &[],
1259 });
1260
1261 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1262 compilation_options: PipelineCompilationOptions {
1263 constants: &[],
1264 zero_initialize_workgroup_memory: true,
1265 },
1266 cache: None,
1267 label: Some("find_matches_and_compute_f7"),
1268 layout: Some(&pipeline_layout),
1269 module,
1270 entry_point: Some("find_matches_and_compute_f7"),
1271 });
1272
1273 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1274 label: Some("find_matches_and_compute_f7"),
1275 layout: &bind_group_layout,
1276 entries: &[
1277 BindGroupEntry {
1278 binding: 0,
1279 resource: parent_buckets_gpu.as_entire_binding(),
1280 },
1281 BindGroupEntry {
1282 binding: 1,
1283 resource: parent_metadatas_gpu.as_entire_binding(),
1284 },
1285 BindGroupEntry {
1286 binding: 2,
1287 resource: table_6_proof_targets_sizes_gpu.as_entire_binding(),
1288 },
1289 BindGroupEntry {
1290 binding: 3,
1291 resource: table_6_proof_targets_gpu.as_entire_binding(),
1292 },
1293 ],
1294 });
1295
1296 (bind_group, compute_pipeline)
1297}
1298
1299#[expect(
1300 clippy::too_many_arguments,
1301 reason = "Both I/O and Vulkan stuff together take a lot of arguments"
1302)]
1303fn bind_group_and_pipeline_find_proofs(
1304 device: &wgpu::Device,
1305 module: &ShaderModule,
1306 table_2_positions_gpu: &Buffer,
1307 table_3_positions_gpu: &Buffer,
1308 table_4_positions_gpu: &Buffer,
1309 table_5_positions_gpu: &Buffer,
1310 table_6_positions_gpu: &Buffer,
1311 bucket_sizes_gpu: &Buffer,
1312 buckets_gpu: &Buffer,
1313 proofs_gpu: &Buffer,
1314) -> (BindGroup, ComputePipeline) {
1315 let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
1316 label: Some("find_proofs"),
1317 entries: &[
1318 BindGroupLayoutEntry {
1319 binding: 0,
1320 count: None,
1321 visibility: ShaderStages::COMPUTE,
1322 ty: BindingType::Buffer {
1323 has_dynamic_offset: false,
1324 min_binding_size: None,
1325 ty: BufferBindingType::Storage { read_only: true },
1326 },
1327 },
1328 BindGroupLayoutEntry {
1329 binding: 1,
1330 count: None,
1331 visibility: ShaderStages::COMPUTE,
1332 ty: BindingType::Buffer {
1333 has_dynamic_offset: false,
1334 min_binding_size: None,
1335 ty: BufferBindingType::Storage { read_only: true },
1336 },
1337 },
1338 BindGroupLayoutEntry {
1339 binding: 2,
1340 count: None,
1341 visibility: ShaderStages::COMPUTE,
1342 ty: BindingType::Buffer {
1343 has_dynamic_offset: false,
1344 min_binding_size: None,
1345 ty: BufferBindingType::Storage { read_only: true },
1346 },
1347 },
1348 BindGroupLayoutEntry {
1349 binding: 3,
1350 count: None,
1351 visibility: ShaderStages::COMPUTE,
1352 ty: BindingType::Buffer {
1353 has_dynamic_offset: false,
1354 min_binding_size: None,
1355 ty: BufferBindingType::Storage { read_only: true },
1356 },
1357 },
1358 BindGroupLayoutEntry {
1359 binding: 4,
1360 count: None,
1361 visibility: ShaderStages::COMPUTE,
1362 ty: BindingType::Buffer {
1363 has_dynamic_offset: false,
1364 min_binding_size: None,
1365 ty: BufferBindingType::Storage { read_only: true },
1366 },
1367 },
1368 BindGroupLayoutEntry {
1369 binding: 5,
1370 count: None,
1371 visibility: ShaderStages::COMPUTE,
1372 ty: BindingType::Buffer {
1373 has_dynamic_offset: false,
1374 min_binding_size: None,
1375 ty: BufferBindingType::Storage { read_only: false },
1376 },
1377 },
1378 BindGroupLayoutEntry {
1379 binding: 6,
1380 count: None,
1381 visibility: ShaderStages::COMPUTE,
1382 ty: BindingType::Buffer {
1383 has_dynamic_offset: false,
1384 min_binding_size: None,
1385 ty: BufferBindingType::Storage { read_only: true },
1386 },
1387 },
1388 BindGroupLayoutEntry {
1389 binding: 7,
1390 count: None,
1391 visibility: ShaderStages::COMPUTE,
1392 ty: BindingType::Buffer {
1393 has_dynamic_offset: false,
1394 min_binding_size: None,
1395 ty: BufferBindingType::Storage { read_only: false },
1396 },
1397 },
1398 ],
1399 });
1400
1401 let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
1402 label: Some("find_proofs"),
1403 bind_group_layouts: &[&bind_group_layout],
1404 push_constant_ranges: &[],
1405 });
1406
1407 let compute_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
1408 compilation_options: PipelineCompilationOptions {
1409 constants: &[],
1410 zero_initialize_workgroup_memory: false,
1411 },
1412 cache: None,
1413 label: Some("find_proofs"),
1414 layout: Some(&pipeline_layout),
1415 module,
1416 entry_point: Some("find_proofs"),
1417 });
1418
1419 let bind_group = device.create_bind_group(&BindGroupDescriptor {
1420 label: Some("find_proofs"),
1421 layout: &bind_group_layout,
1422 entries: &[
1423 BindGroupEntry {
1424 binding: 0,
1425 resource: table_2_positions_gpu.as_entire_binding(),
1426 },
1427 BindGroupEntry {
1428 binding: 1,
1429 resource: table_3_positions_gpu.as_entire_binding(),
1430 },
1431 BindGroupEntry {
1432 binding: 2,
1433 resource: table_4_positions_gpu.as_entire_binding(),
1434 },
1435 BindGroupEntry {
1436 binding: 3,
1437 resource: table_5_positions_gpu.as_entire_binding(),
1438 },
1439 BindGroupEntry {
1440 binding: 4,
1441 resource: table_6_positions_gpu.as_entire_binding(),
1442 },
1443 BindGroupEntry {
1444 binding: 5,
1445 resource: bucket_sizes_gpu.as_entire_binding(),
1446 },
1447 BindGroupEntry {
1448 binding: 6,
1449 resource: buckets_gpu.as_entire_binding(),
1450 },
1451 BindGroupEntry {
1452 binding: 7,
1453 resource: proofs_gpu.as_entire_binding(),
1454 },
1455 ],
1456 });
1457
1458 (bind_group, compute_pipeline)
1459}